Commit 4b10c8c3 authored by drbh's avatar drbh
Browse files

fix: improve scales change and revert conditional

parent ab4d480d
......@@ -38,9 +38,12 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
# if scales is a scalar (0D tensor), convert it to a 1D tensor
if scales.dim() == 0:
scales = scales.unsqueeze(0)
scales = scales.unsqueeze(0)
# repack weights for Marlin if a single scale is provided
if scales.size(0) == 1:
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment