Motivating Example
Below is a case in MobileBertForMaskedLM where the forward pass concatenates two model parameters: hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)). This concat accounts for >20% of the single-threaded inference time on CPU, but since both operands are parameters the cost could be eliminated by constant folding (a manual folding sketch follows the snippet below).
import torch
from torch import nn

class MobileBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = MobileBertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform(hidden_states)
        hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
        hidden_states += self.decoder.bias
        return hidden_states
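Since both operands of the torch.cat are parameters, the concatenated matrix is itself a constant and only needs to be computed once. The sketch below folds it by hand in eager mode; the wrapper class FoldedLMPredictionHead and the buffer name folded_weight are hypothetical, for illustration only, not part of the MobileBERT code. Ideally the same fold would happen automatically at compile time (e.g. via Inductor's freezing/constant-folding pass), which is the kind of saving described above.

import torch
from torch import nn

class FoldedLMPredictionHead(nn.Module):
    # Hypothetical inference-only workaround: precompute the constant concat.
    def __init__(self, head):
        super().__init__()
        self.transform = head.transform
        # Both cat operands are parameters, so the concatenated matrix never
        # changes at inference time and can be materialized once up front.
        # detach() because this buffer is for inference, not training.
        self.register_buffer(
            "folded_weight",
            torch.cat([head.decoder.weight.t(), head.dense.weight], dim=0).detach(),
        )
        self.bias = head.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform(hidden_states)
        # Single matmul against the pre-folded weight; no per-call torch.cat.
        hidden_states = hidden_states.matmul(self.folded_weight)
        hidden_states += self.bias
        return hidden_states

This trades a small amount of extra memory (the folded buffer duplicates the two parameters) for removing the per-call concat from the hot path.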
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @ezyang @msaroufim @wconstab @ngimel @bdhirsh