@@ -223,6 +223,27 @@ def get_int_from_env(env_keys, default):
 
 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
 
+tp_grain_size = 64
+if args.ipex_weight_only_quantization and args.low_precision_checkpoint != "":
+    pathname = args.low_precision_checkpoint
+    assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
+    if os.path.isdir(pathname):
+        try:
+            with open(pathname + "/config.json") as f:
+                quant_model_config = json.load(f)
+            tp_grain_size = int(
+                quant_model_config["quantization_config"]["group_size"]
+            )
+        except Exception as e:
+            print("Failed to get group_size from config.json")
+    elif args.group_size > 0:
+        tp_grain_size = args.group_size
+    else:
+        print(
+            "Warning: cannot get group_size from config.json or --group-size, "
+            "using default value 64 for tp_grain_size"
+        )
+
 
 class HuggingFaceModel(BaseLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -399,6 +420,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module
@@ -537,10 +561,13 @@ def write_checkpoints_json():
             num_heads = model.config.num_attention_heads
             rank = local_rank
 
-            layers_split_by_N = [
+            mha_layers_split_by_N = [
                 "q_proj",
                 "k_proj",
                 "v_proj",
+            ]
+            # mlp is split with grain size = tp_grain_size
+            mlp_layers_split_by_N = [
                 "gate_proj",
                 "up_proj",
                 "fc_in",
@@ -549,23 +576,26 @@ def write_checkpoints_json():
                 "w1",
                 "w3",
             ]
-            layers_split_by_K = [
+            mha_layers_split_by_K = [
                 "o_proj",
+                "out_proj",
+            ]
+            # mlp is split with grain size = tp_grain_size
+            mlp_layers_split_by_K = [
                 "down_proj",
                 "fc_out",
                 "fc2",
-                "out_proj",
                 "dense",
                 "dense_4h_to_h",
                 "w2",
             ]
+            # lm_head is split with grain size = tp_grain_size
             lm_head_layers = ["lm_head"]  # split by K but not quantized
             quantization_method = quant_model_config["quantization_config"][
                 "quant_method"
             ]
             head_range = [0]
             head_per_rank = num_heads // world_size
-
             for i in range(0, world_size):
                 head_this_rank = head_per_rank
                 if i < num_heads % world_size:
@@ -578,7 +608,7 @@ def write_checkpoints_json():
                 )
                 if "bias" in key:
                     continue
-                if any(substring in key for substring in layers_split_by_N):
+                if any(substring in key for substring in mha_layers_split_by_N):
                     data = low_precision_checkpoint_dict[key]
                     if quantization_method == "awq":
                         # awq qweight: [K, N // 8]
@@ -592,7 +622,48 @@ def write_checkpoints_json():
                         raise AssertionError(
                             f"{quantization_method} is not supported yet."
                         )
-                if any(substring in key for substring in layers_split_by_K):
+                elif any(
+                    substring in key for substring in mlp_layers_split_by_N
+                ):
+                    data = low_precision_checkpoint_dict[key]
+                    if quantization_method == "awq":
+                        # awq qweight: [K, N // 8]
+                        # awq scales: [K // G, N]
+                        # awq qzeros: [K // G, N // 8]
+                        if "scales" in key:
+                            assert (
+                                data.shape[1] % tp_grain_size == 0
+                            ), "N must be divisible by tp_grain_size"
+                            grains = data.shape[1] // tp_grain_size
+                            dim = tp_grain_size
+                        else:
+                            assert (
+                                data.shape[1] * 8
+                            ) % tp_grain_size == 0, (
+                                "N must be divisible by tp_grain_size"
+                            )
+                            grains = data.shape[1] // (tp_grain_size // 8)
+                            dim = tp_grain_size // 8
+                        grains_per_rank = grains // world_size
+                        grains_rem = grains % world_size
+                        grains_start = grains_per_rank * local_rank + min(
+                            local_rank, grains_rem
+                        )
+                        grains_end = (
+                            grains_start
+                            + grains_per_rank
+                            + (1 if local_rank < grains_rem else 0)
+                        )
+                        low_precision_checkpoint_dict[key] = data[
+                            :, grains_start * dim : grains_end * dim
+                        ]
+                    else:
+                        raise AssertionError(
+                            f"{quantization_method} is not supported yet."
+                        )
+                elif any(
+                    substring in key for substring in mha_layers_split_by_K
+                ):
                     data = low_precision_checkpoint_dict[key]
                     if quantization_method == "awq":
                         # awq qweight: [K, N // 8]
@@ -612,18 +683,61 @@ def write_checkpoints_json():
                         raise AssertionError(
                             f"{quantization_method} is not supported yet."
                         )
-                if any(substring in key for substring in lm_head_layers):
+                elif any(
+                    substring in key for substring in mlp_layers_split_by_K
+                ):
+                    data = low_precision_checkpoint_dict[key]
+                    if quantization_method == "awq":
+                        # awq qweight: [K, N // 8]
+                        # awq scales: [K // G, N]
+                        # awq qzeros: [K // G, N // 8]
+                        if "qweight" in key:
+                            assert (
+                                data.shape[0] % tp_grain_size == 0
+                            ), "K must be divisible by tp_grain_size"
+                            grains = data.shape[0] // tp_grain_size
+                            dim = tp_grain_size
+                        else:
+                            grains = data.shape[0]
+                            dim = 1
+                        grains_per_rank = grains // world_size
+                        grains_rem = grains % world_size
+                        grains_start = grains_per_rank * local_rank + min(
+                            local_rank, grains_rem
+                        )
+                        grains_end = (
+                            grains_start
+                            + grains_per_rank
+                            + (1 if local_rank < grains_rem else 0)
+                        )
+                        low_precision_checkpoint_dict[key] = data[
+                            grains_start * dim : grains_end * dim
+                        ]
+                    else:
+                        raise AssertionError(
+                            f"{quantization_method} is not supported yet."
+                        )
+                elif any(substring in key for substring in lm_head_layers):
                     # lm_head: [N, K] (not quantized)
                     # Same for both AWQ and GPTQ
                     data = low_precision_checkpoint_dict[key]
-                    if data.shape[-1] % head_range[-1] == 0:
-                        dim = data.shape[-1] // head_range[-1]
-                    else:
-                        dim = data.shape[-1] // world_size
-                    q_head_start = local_rank
-                    q_head_end = local_rank + 1
+                    assert (
+                        data.shape[1] % tp_grain_size == 0
+                    ), "K must be divisible by tp_grain_size"
+                    grains = data.shape[1] // tp_grain_size
+                    dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
                     low_precision_checkpoint_dict[key] = data[
-                        :, q_head_start * dim : q_head_end * dim
+                        :, grains_start * dim : grains_end * dim
                     ]
             low_precision_checkpoint = (
                 low_precision_checkpoint_dict,
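Note: every branch above repeats the same grain-based partitioning pattern. For reference, a minimal standalone sketch of that pattern follows; the helper name and its arguments are illustrative and are not part of the patch. The split dimension is viewed as a sequence of blocks of tp_grain_size elements (tp_grain_size // 8 for int4-packed tensors), each rank takes a contiguous run of blocks, and the first grains % world_size ranks take one extra block so the remainder is absorbed evenly.

# Illustrative sketch only (hypothetical helper), mirroring the slicing logic in the diff.
def grain_slice(dim_size, tp_grain_size, local_rank, world_size):
    assert dim_size % tp_grain_size == 0, "dimension must be divisible by tp_grain_size"
    grains = dim_size // tp_grain_size
    grains_per_rank = grains // world_size
    grains_rem = grains % world_size
    # Ranks below grains_rem get one extra grain; the split stays contiguous.
    grains_start = grains_per_rank * local_rank + min(local_rank, grains_rem)
    grains_end = grains_start + grains_per_rank + (1 if local_rank < grains_rem else 0)
    return grains_start * tp_grain_size, grains_end * tp_grain_size

# Example: a 4096-wide dimension with tp_grain_size=64 across 3 ranks has 64 grains,
# so ranks 0, 1, 2 receive 22, 21, 21 grains respectively, covering the dimension exactly.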
@@ -1381,6 +1495,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self._model = self._model.module
@@ -2146,6 +2263,9 @@ def write_checkpoints_json():
             base_dir=repo_root,
             dtype=infer_dtype,
             checkpoint=checkpoints_json,
+            tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
+                tp_grain_size=tp_grain_size
+            ),
         )
 
         self.model = self.model.module
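All three deepspeed.init_inference call sites now pass the resolved grain size through a DeepSpeedTPConfig. A rough usage sketch is below; variable names such as infer_dtype, checkpoints_json, and tp_grain_size come from the surrounding script, and the tp_grain_size field assumes a DeepSpeed build whose DeepSpeedTPConfig exposes it, as this patch relies on.

# Sketch only: how the patch threads tp_grain_size into DeepSpeed inference initialization.
import deepspeed

model = deepspeed.init_inference(
    model,                         # already-loaded HF model (assumed in scope)
    dtype=infer_dtype,             # inference dtype chosen by the script
    checkpoint=checkpoints_json,   # checkpoint manifest written by write_checkpoints_json()
    tensor_parallel=deepspeed.inference.config.DeepSpeedTPConfig(
        tp_grain_size=tp_grain_size  # assumes a DeepSpeed build exposing tp_grain_size
    ),
)
model = model.module  # unwrap the InferenceEngine to the underlying module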