
Commit 99aa54f

TPP LM head slice for generation and kernel parallel (#2253)
* lm head parallel
* lm head slice for generation
1 parent 1d5e83d commit 99aa54f

File tree: 4 files changed, +52 −4 lines
csrc/cpu/tpp/kernels/TPPGEMMKrnl.h

Lines changed: 13 additions & 2 deletions
@@ -92,6 +92,14 @@ inline void tpp_linear_bias(
   auto in_sizes = t_in.sizes();
   auto wt_sizes = t_wt.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  if (BS > FT_OPT_SIZE) { // first token compute
+    if (wt_sizes[3] != 100) {
+      t_wt = wt_tensor_for_first_token<T>(t_wt);
+      wt_sizes = t_wt.sizes();
+    }
+    large_cache_opt = true;
+  }
+
   auto C = in_sizes[2];

   auto Nc = wt_sizes[1];

@@ -169,11 +177,14 @@ inline void tpp_linear_no_bias(
     at::Tensor& t_out) {
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  auto wt_sizes = t_wt.sizes();
   if (BS > FT_OPT_SIZE) { // first token compute
-    t_wt = wt_tensor_for_first_token<T>(t_wt);
+    if (wt_sizes[3] != 100) {
+      t_wt = wt_tensor_for_first_token<T>(t_wt);
+      wt_sizes = t_wt.sizes();
+    }
     large_cache_opt = true;
   }
-  auto wt_sizes = t_wt.sizes();
   auto C = in_sizes[2];

   auto Nc = wt_sizes[1];
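In both tpp_linear_bias and tpp_linear_no_bias, the first-token (prefill) path now repacks the weight with wt_tensor_for_first_token<T> only when the innermost block dimension is not already 100, so a weight that was converted on an earlier call is reused instead of being repacked again. A rough Python restatement of that control flow, for orientation only (the FT_OPT_SIZE value and the repack helper are assumptions; the actual logic lives in the C++ header above):

# Conceptual sketch of the guard above; not the kernel itself.
FT_OPT_SIZE = 256  # assumed threshold; the real constant is defined in the C++ sources

def maybe_repack_for_first_token(t_in, t_wt, repack_fn):
    # t_wt is assumed blocked as [Nk, Nc, Hc, Hk]; Hk == 100 marks the
    # first-token layout produced by repack_fn (wt_tensor_for_first_token).
    bs = t_in.shape[0] * t_in.shape[1]   # batch * sequence length
    large_cache_opt = False
    if bs > FT_OPT_SIZE:                 # first-token / prefill compute
        if t_wt.shape[3] != 100:         # skip the repack if already converted
            t_wt = repack_fn(t_wt)
        large_cache_opt = True
    return t_wt, large_cache_opt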

examples/cpu/inference/python/llm/distributed/run_generation_with_deepspeed.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,9 @@ def get_checkpoint_files(model_name_or_path):
 if not hasattr(config, "text_max_length") and args.prompt is None:
     config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)

+if not hasattr(config, "lm_head_generation"):
+    config.lm_head_generation = True
+
 # XXX: can't automatically derive dtype via config's `from_pretrained`
 # dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16
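This script (and run_generation.py below) defaults lm_head_generation to True on the Hugging Face config before the model is built, which is what the patched forward functions later check. A minimal standalone sketch of the same pattern (the model id is a placeholder):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # placeholder model id

# Mirrors the script change: enable LM-head slicing for generation unless the
# config already carries an explicit lm_head_generation setting.
if not hasattr(config, "lm_head_generation"):
    config.lm_head_generation = True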

examples/cpu/inference/python/llm/single_instance/run_generation.py

Lines changed: 4 additions & 0 deletions
@@ -98,6 +98,10 @@
 )
 if not hasattr(config, "text_max_length") and args.prompt is None:
     config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)
+
+if not hasattr(config, "lm_head_generation"):
+    config.lm_head_generation = True
+
 model = model_class[0].from_pretrained(
     args.model_id,
     torch_dtype=amp_dtype,

intel_extension_for_pytorch/transformers/models/reference/models.py

Lines changed: 32 additions & 2 deletions
@@ -47,6 +47,14 @@ def GPTJForCausalLM_forward(
         torch.cuda.set_device(self.transformer.first_device)
         hidden_states = hidden_states.to(self.lm_head.weight.device)

+    if (
+        hasattr(self, "config")
+        and hasattr(self.config, "lm_head_generation")
+        and self.config.lm_head_generation
+        and hidden_states.size(1) != 1
+    ):
+        hidden_states = hidden_states[:, -1:, :]
+
     # make sure sampling in fp16 works correctly and
     # compute loss in fp32 to match with mesh-tf version
     # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179

@@ -119,6 +127,14 @@ def LlamaForCausalLM_forward(
     )

     hidden_states = outputs[0]
+    if (
+        hasattr(self, "config")
+        and hasattr(self.config, "lm_head_generation")
+        and self.config.lm_head_generation
+        and hidden_states.size(1) != 1
+    ):
+        hidden_states = hidden_states[:, -1:, :]
+
     logits = self.lm_head(hidden_states)

     loss = None

@@ -178,6 +194,13 @@ def GPTNeoXForCausalLM_forward(
     )

     hidden_states = outputs[0]
+    if (
+        hasattr(self, "config")
+        and hasattr(self.config, "lm_head_generation")
+        and self.config.lm_head_generation
+        and hidden_states.size(1) != 1
+    ):
+        hidden_states = hidden_states[:, -1:, :]
     lm_logits = self.embed_out(hidden_states)

     lm_loss = None

@@ -244,8 +267,15 @@ def OPTForCausalLM_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
     )
-
-    logits = self.lm_head(outputs[0]).contiguous()
+    hidden_states = outputs[0]
+    if (
+        hasattr(self, "config")
+        and hasattr(self.config, "lm_head_generation")
+        and self.config.lm_head_generation
+        and hidden_states.size(1) != 1
+    ):
+        hidden_states = hidden_states[:, -1:, :]
+    logits = self.lm_head(hidden_states).contiguous()

     loss = None
     if labels is not None:
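During the first-token (prefill) pass only the last position's logits are needed to pick the next token, so each patched forward slices hidden_states to its final position before the LM-head projection; the hidden_states.size(1) != 1 guard leaves single-token decode steps untouched. A self-contained sketch of the effect (shapes are illustrative):

import torch

batch, seq, hidden, vocab = 1, 1024, 4096, 32000
hidden_states = torch.randn(batch, seq, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

# Without slicing, every prompt position goes through the vocab projection.
full_logits = lm_head(hidden_states)               # shape (1, 1024, 32000)

# With lm_head_generation, only the last position is projected during prefill;
# sampling the next token only ever reads the logits at that position.
sliced_logits = lm_head(hidden_states[:, -1:, :])  # shape (1, 1, 32000)

assert torch.allclose(full_logits[:, -1:, :], sliced_logits, atol=1e-5)

For a prompt of length S, this shrinks the prefill LM-head matmul by roughly a factor of S while leaving the generated tokens unchanged.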
