
Commit 3e27750

Enable TP=3 with int4 checkpoint for WOQ (#3400)
1 parent b24885d commit 3e27750

File tree

3 files changed: +256 −28 lines changed


examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 134 additions & 14 deletions
@@ -223,6 +223,27 @@ def get_int_from_env(env_keys, default):
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
+tp_grain_size = 64
+if args.ipex_weight_only_quantization and args.low_precision_checkpoint != "":
+    pathname = args.low_precision_checkpoint
+    assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
+    if os.path.isdir(pathname):
+        try:
+            with open(pathname + "/config.json") as f:
+                quant_model_config = json.load(f)
+            tp_grain_size = int(
+                quant_model_config["quantization_config"]["group_size"]
+            )
+        except Exception as e:
+            print("Failed to get group_size from config.json")
+    elif args.group_size > 0:
+        tp_grain_size = args.group_size
+    else:
+        print(
+            "Warning: cannot get group_size from config.json or --group-size, "
+            "using default value 64 for tp_grain_size"
+        )
+
 
 class HuggingFaceModel(BaseLM):
     _DEFAULT_MAX_LENGTH = 2048
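
Note: when the int4 checkpoint is a directory, the group size above is read from its config.json. A minimal sketch of that lookup against a hypothetical quantization_config (the field values below are illustrative, not taken from this commit):

# Hypothetical config.json content of an int4 AWQ/GPTQ checkpoint directory;
# only "quantization_config"/"group_size" (and, later in the script, "quant_method")
# are relied on, the other fields are illustrative.
quant_model_config = {
    "quantization_config": {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
    }
}

tp_grain_size = 64  # fallback when the group size cannot be determined
try:
    tp_grain_size = int(quant_model_config["quantization_config"]["group_size"])
except Exception:
    print("Failed to get group_size from config.json")
print(tp_grain_size)  # -> 128 for this example
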
@@ -399,6 +420,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module
@@ -537,10 +561,13 @@ def write_checkpoints_json():
                    num_heads = model.config.num_attention_heads
                    rank = local_rank
 
-                   layers_split_by_N = [
+                   mha_layers_split_by_N = [
                        "q_proj",
                        "k_proj",
                        "v_proj",
+                   ]
+                   # mlp is split with grain size = tp_grain_size
+                   mlp_layers_split_by_N = [
                        "gate_proj",
                        "up_proj",
                        "fc_in",
@@ -549,23 +576,26 @@ def write_checkpoints_json():
                        "w1",
                        "w3",
                    ]
-                   layers_split_by_K = [
+                   mha_layers_split_by_K = [
                        "o_proj",
+                       "out_proj",
+                   ]
+                   # mlp is split with grain size = tp_grain_size
+                   mlp_layers_split_by_K = [
                        "down_proj",
                        "fc_out",
                        "fc2",
-                       "out_proj",
                        "dense",
                        "dense_4h_to_h",
                        "w2",
                    ]
+                   # lm_head is split with grain size = tp_grain_size
                    lm_head_layers = ["lm_head"]  # split by K but not quantized
                    quantization_method = quant_model_config["quantization_config"][
                        "quant_method"
                    ]
                    head_range = [0]
                    head_per_rank = num_heads // world_size
-
                    for i in range(0, world_size):
                        head_this_rank = head_per_rank
                        if i < num_heads % world_size:
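
Note: taken together, these lists encode the tensor-parallel split policy that the key loop below applies to the int4 checkpoint. A compact summary (a sketch only; the lists are abridged to the names visible in this diff, and the full lists live in the file itself):

# Sketch of the split policy applied in the loop below (not part of the diff).
split_policy = {
    # split along N (output features), sliced per attention head via head_range
    "mha_split_by_N": ["q_proj", "k_proj", "v_proj"],
    # split along N in whole grains of tp_grain_size
    "mlp_split_by_N": ["gate_proj", "up_proj", "fc_in", "w1", "w3"],
    # split along K (input features)
    "mha_split_by_K": ["o_proj", "out_proj"],
    # split along K in whole grains of tp_grain_size
    "mlp_split_by_K": ["down_proj", "fc_out", "fc2", "dense", "dense_4h_to_h", "w2"],
    # not quantized; split along K in whole grains of tp_grain_size
    "lm_head": ["lm_head"],
}
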
@@ -578,7 +608,7 @@ def write_checkpoints_json():
                        )
                        if "bias" in key:
                            continue
-                       if any(substring in key for substring in layers_split_by_N):
+                       if any(substring in key for substring in mha_layers_split_by_N):
                            data = low_precision_checkpoint_dict[key]
                            if quantization_method == "awq":
                                # awq qweight: [K, N // 8]
@@ -592,7 +622,48 @@ def write_checkpoints_json():
                                raise AssertionError(
                                    f"{quantization_method} is not supported yet."
                                )
-                       if any(substring in key for substring in layers_split_by_K):
+                       elif any(
+                           substring in key for substring in mlp_layers_split_by_N
+                       ):
+                           data = low_precision_checkpoint_dict[key]
+                           if quantization_method == "awq":
+                               # awq qweight: [K, N // 8]
+                               # awq scales: [K // G, N]
+                               # awq qzeros: [K // G, N // 8]
+                               if "scales" in key:
+                                   assert (
+                                       data.shape[1] % tp_grain_size == 0
+                                   ), "N must be divisible by tp_grain_size"
+                                   grains = data.shape[1] // tp_grain_size
+                                   dim = tp_grain_size
+                               else:
+                                   assert (
+                                       data.shape[1] * 8
+                                   ) % tp_grain_size == 0, (
+                                       "N must be divisible by tp_grain_size"
+                                   )
+                                   grains = data.shape[1] // (tp_grain_size // 8)
+                                   dim = tp_grain_size // 8
+                               grains_per_rank = grains // world_size
+                               grains_rem = grains % world_size
+                               grains_start = grains_per_rank * local_rank + min(
+                                   local_rank, grains_rem
+                               )
+                               grains_end = (
+                                   grains_start
+                                   + grains_per_rank
+                                   + (1 if local_rank < grains_rem else 0)
+                               )
+                               low_precision_checkpoint_dict[key] = data[
+                                   :, grains_start * dim : grains_end * dim
+                               ]
+                           else:
+                               raise AssertionError(
+                                   f"{quantization_method} is not supported yet."
+                               )
+                       elif any(
+                           substring in key for substring in mha_layers_split_by_K
+                       ):
                            data = low_precision_checkpoint_dict[key]
                            if quantization_method == "awq":
                                # awq qweight: [K, N // 8]
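
Note: the grain arithmetic above is what makes TP=3 work with an int4 checkpoint: the split is done in whole grains of tp_grain_size, so each rank's shard stays aligned to the quantization group size even when the grain count is not divisible by the world size. A small self-contained sketch with illustrative sizes (not taken from the commit):

# Standalone sketch of the grain-based N split used above (illustrative sizes).
tp_grain_size = 64      # quantization group size from the checkpoint
N = 11008               # e.g. an MLP intermediate size
world_size = 3          # TP=3

grains = N // tp_grain_size             # 172 grains of 64 columns
grains_per_rank = grains // world_size  # 57
grains_rem = grains % world_size        # 1, so the first rank gets one extra grain

shards = []
for local_rank in range(world_size):
    grains_start = grains_per_rank * local_rank + min(local_rank, grains_rem)
    grains_end = grains_start + grains_per_rank + (1 if local_rank < grains_rem else 0)
    shards.append((grains_start * tp_grain_size, grains_end * tp_grain_size))

print(shards)
# [(0, 3712), (3712, 7360), (7360, 11008)]: widths 3712/3648/3648, all multiples of 64

# The shards tile the full dimension with no gap or overlap.
assert shards[0][0] == 0 and shards[-1][1] == N
assert all(shards[i][1] == shards[i + 1][0] for i in range(world_size - 1))
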
@@ -612,18 +683,61 @@ def write_checkpoints_json():
                                raise AssertionError(
                                    f"{quantization_method} is not supported yet."
                                )
-                       if any(substring in key for substring in lm_head_layers):
+                       elif any(
+                           substring in key for substring in mlp_layers_split_by_K
+                       ):
+                           data = low_precision_checkpoint_dict[key]
+                           if quantization_method == "awq":
+                               # awq qweight: [K, N // 8]
+                               # awq scales: [K // G, N]
+                               # awq qzeros: [K // G, N // 8]
+                               if "qweight" in key:
+                                   assert (
+                                       data.shape[0] % tp_grain_size == 0
+                                   ), "K must be divisible by tp_grain_size"
+                                   grains = data.shape[0] // tp_grain_size
+                                   dim = tp_grain_size
+                               else:
+                                   grains = data.shape[0]
+                                   dim = 1
+                               grains_per_rank = grains // world_size
+                               grains_rem = grains % world_size
+                               grains_start = grains_per_rank * local_rank + min(
+                                   local_rank, grains_rem
+                               )
+                               grains_end = (
+                                   grains_start
+                                   + grains_per_rank
+                                   + (1 if local_rank < grains_rem else 0)
+                               )
+                               low_precision_checkpoint_dict[key] = data[
+                                   grains_start * dim : grains_end * dim
+                               ]
+                           else:
+                               raise AssertionError(
+                                   f"{quantization_method} is not supported yet."
+                               )
+                       elif any(substring in key for substring in lm_head_layers):
                            # lm_head: [N, K] (not quantized)
                            # Same for both AWQ and GPTQ
                            data = low_precision_checkpoint_dict[key]
-                           if data.shape[-1] % head_range[-1] == 0:
-                               dim = data.shape[-1] // head_range[-1]
-                           else:
-                               dim = data.shape[-1] // world_size
-                           q_head_start = local_rank
-                           q_head_end = local_rank + 1
+                           assert (
+                               data.shape[1] % tp_grain_size == 0
+                           ), "K must be divisible by tp_grain_size"
+                           grains = data.shape[1] // tp_grain_size
+                           dim = tp_grain_size
+                           grains_per_rank = grains // world_size
+                           grains_rem = grains % world_size
+                           grains_start = grains_per_rank * local_rank + min(
+                               local_rank, grains_rem
+                           )
+                           grains_end = (
+                               grains_start
+                               + grains_per_rank
+                               + (1 if local_rank < grains_rem else 0)
+                           )
                            low_precision_checkpoint_dict[key] = data[
-                               :, q_head_start * dim : q_head_end * dim
+                               :, grains_start * dim : grains_end * dim
                            ]
                    low_precision_checkpoint = (
                        low_precision_checkpoint_dict,
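
Note: in the K direction the granularity differs per tensor: an AWQ qweight is [K, N // 8] and is cut along dim 0 in grains of tp_grain_size rows, while scales and qzeros are [K // G, ...] with one row per quantization group and are therefore cut row by row (dim = 1); the unquantized lm_head weight is [N, K] and is cut along its last dimension in whole grains. A tiny check with illustrative sizes (hidden size 4096, tp_grain_size 64, TP=3):

# Illustrative check of the K-direction split for TP=3 (sizes are examples only).
tp_grain_size = 64
K = 4096            # e.g. hidden size feeding down_proj / lm_head
world_size = 3

def grain_slices(length, grain):
    """Per-rank (start, end) element ranges, cut in whole grains."""
    grains = length // grain
    per_rank, rem = divmod(grains, world_size)
    slices = []
    for rank in range(world_size):
        start = per_rank * rank + min(rank, rem)
        end = start + per_rank + (1 if rank < rem else 0)
        slices.append((start * grain, end * grain))
    return slices

# lm_head weight [N, K]: split along K in grains of tp_grain_size columns.
print(grain_slices(K, tp_grain_size))
# [(0, 1408), (1408, 2752), (2752, 4096)]: 22 / 21 / 21 grains per rank

# AWQ scales/qzeros of a K-split layer have K // group_size rows (one per group),
# so they are split with grain size 1 along dim 0.
print(grain_slices(K // tp_grain_size, 1))
# [(0, 22), (22, 43), (43, 64)]
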
@@ -1381,6 +1495,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self._model = self._model.module
@@ -2146,6 +2263,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module
