int4_weight_only get_plain() weights are padded #2249

Open
jiqing-feng opened this issue May 23, 2025 · 7 comments

@jiqing-feng

I tried to quantize a model with int4_weight_only and wanted to get the plain weight, but found that the weight has been padded. To reproduce it, run the following script:

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM

model_name = "JackFram/llama-68m"
quantization_config = TorchAoConfig("int4_weight_only")
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda:0", quantization_config=quantization_config
)
# get_plain() returns (int_data, scale, zero_point); index [0] is the int data
print(quantized_model.model.layers[0].self_attn.q_proj.weight.tensor_impl.get_plain()[0].shape)
print(quantized_model.model.layers[0].self_attn.q_proj.weight.tensor_impl.get_plain()[0])

Output:

(768, 1024)
tensor([[11, 12,  8,  ...,  0,  0,  0],
        [ 5,  6,  5,  ...,  0,  0,  0],
        [ 5,  7,  7,  ...,  0,  0,  0],
        ...,
        [ 7,  5,  2,  ...,  0,  0,  0],
        [ 6,  1,  7,  ...,  0,  0,  0],
        [ 8, 11,  9,  ...,  0,  0,  0]], device='cuda:0', dtype=torch.int32)

The original shape should be (768, 768), but the plain weight shape is (768, 1024). Can we add a step that removes the padding in the get_plain() function?
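
Until get_plain() strips the padding itself, one possible workaround is to slice the plain int data back to the layer's logical shape. This is only a hedged sketch that reuses quantized_model from the script above and assumes the quantized nn.Linear still reports its original in_features/out_features:

# Workaround sketch (not an official torchao API): slice off the padding,
# assuming the nn.Linear module still carries the logical shape.
q_proj = quantized_model.model.layers[0].self_attn.q_proj
int_data, scale, zero_point = q_proj.weight.tensor_impl.get_plain()

out_features, in_features = q_proj.out_features, q_proj.in_features
unpadded = int_data[:out_features, :in_features]
print(unpadded.shape)  # expected: torch.Size([768, 768])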

@jiqing-feng
Author

Hi @jainapurva @HDCharles. Could you help take a look? I see you upstreamed this code.

@jiqing-feng
Author

The motivation is that I have a CUDA-quantized model and want to load it on CPU or XPU. The data layout differs across devices. I was wondering if we could have a common layout implementation that can be applied across different devices.

@jiqing-feng
Author


Hi @jerryzh168. Do you have any comments on that? The proposal is that we save the torchao model in a general format and decide the data_layout / tensor_impl when loading. Do you know if we can support this feature?
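
For context on why this matters: with the current API, the packed layout is fixed at quantization time and is device-specific. A rough sketch of today's behavior, assuming int4_weight_only accepts a layout argument and that Int4CPULayout is available in torchao.dtypes (names may differ across torchao versions):

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import Int4CPULayout

def quantize_int4_for(model, device: str):
    # The packed layout is chosen here, at quantization time, per device;
    # the proposal above would instead pick it when the checkpoint is loaded.
    if device == "cpu":
        quantize_(model, int4_weight_only(layout=Int4CPULayout()))
    else:
        model = model.to(device)
        quantize_(model, int4_weight_only())  # default CUDA (tinygemm) layout
    return model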

@HDCharles
Contributor

HDCharles commented May 28, 2025

Hey

get_plain is defined for the layout as the data, scale, and zero_point that are being stored. We don't capture the original shape without padding. The expectation is that you quantize and run the model on the same device; however, we do support having the model on CPU and then moving the layers to CUDA as they get quantized.

Normally, if you have a CPU model you can do quantize_(model, config, device='cuda') and it will move each layer to CUDA as it quantizes it. I don't think this functionality is available in Hugging Face, though, since they have their own model-loading system; I don't think they support loading on CPU and then quantizing on CUDA.
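
A minimal sketch of that flow, assuming a torchao version where quantize_ accepts a device argument:

import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int4_weight_only

# Load the full-precision model on CPU.
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", torch_dtype=torch.bfloat16)

# Each layer is moved to CUDA right before it is quantized, so the whole
# bf16 model never has to be resident on the GPU at once.
quantize_(model, int4_weight_only(), device="cuda")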

We've discussed how annoying the UX is for int4 CUDA/CPU, and I believe @jerryzh168 was planning to implement it soon; I'm not sure about the ETA.

@jiqing-feng
Author

jiqing-feng commented May 28, 2025

Thanks @HDCharles! Referring to your comment 2: can I expect that torchao will support the feature I proposed in the future, i.e. loading a quantized model on different device types (CPU/XPU/CUDA)?

@HDCharles
Contributor

Yes. The current gap is more about transferring from one device to another rather than loading, but that would be possible.

@jerryzh168
Contributor

jerryzh168 commented May 28, 2025

@jiqing-feng For padding, I feel we should just drop it. I'm planning to replace the current int4wo kernel, which is powered by tinygemm and requires external padding, with the gemlite kernels from @mobicham, since they don't require padding; the main reason is that padding causes additional issues like slicing: #2174

for a "standard" layout, I think it will just be plain layout, we can store that and implement tensor.to(layout=?) to convert between the plain layout and the desired layout. for int4 specifically, the plain layout should also be packed to save space, we haven't decided what it should look like, but we could start with what's used by

class UIntXWeightOnlyConfig(AOBaseConfig):
I feel.

Another related aspect is device. I think ideally we can implement all these conversions on CPU, but if that's not possible, we can always move the plain layout to the target device (like CUDA or XPU) and then run the conversion (packing).

Also cc @metascroy, as this is related to the retractability discussion as well.

Another consideration here is that the requirement for using to to convert between different layouts is that they all use the same quantization algorithm. I know people are currently extending the quantization algorithm of the tinygemm int4 kernels; I think we could change that to something more general, since tinygemm quantize/dequantize is very specific to the tinygemm kernel (it even includes a dtype conversion). We could use either:

  1. choose_qparams_affine (with integer zero_point and 0 exactly representable; see the sketch after this list)
  2. choose_qparams_affine_dont_preserve_zero, where float zero is not exactly representable (zero_point is also int)
  3. another variant of choose_qparams_affine that's needed by XPU
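
To make option 1 concrete, here is a hedged sketch of groupwise int4 quantization through choose_qparams_affine / quantize_affine. Exact signatures vary across torchao versions, so treat this as illustrative rather than canonical:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

w = torch.randn(768, 768, dtype=torch.bfloat16)
block_size = (1, 128)  # groupwise along the input dimension

# Option 1: integer zero_point with zero exactly representable.
scale, zero_point = choose_qparams_affine(
    w, MappingType.ASYMMETRIC, block_size, torch.int32, quant_min=0, quant_max=15
)
w_int4 = quantize_affine(w, block_size, scale, zero_point, torch.int32, quant_min=0, quant_max=15)
print(w_int4.shape, int(w_int4.min()), int(w_int4.max()))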
