
Commit 3094f34

Add ipex::einsum_add (#674)
Enable einsum+add fusion via oneDNN binary post-ops.
1 parent c1fcf7a commit 3094f34
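In user terms, the commit fuses a two-operand torch.einsum whose result feeds an add. A minimal sketch of triggering and checking the fusion, following the flow of the new tests (ProjectionWithBias is an illustrative name, not part of the extension):

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

# Two-operand einsum followed by an add: the graph shape this commit
# rewrites into a single ipex::einsum_binary node.
class ProjectionWithBias(nn.Module):
    def forward(self, x, weight, bias):
        return torch.einsum("bsh,ho->bso", x, weight) + bias

model = ipex.optimize(ProjectionWithBias().eval(), dtype=torch.float32)
x, w, b = torch.randn(2, 3, 768), torch.randn(768, 2304), torch.randn(2304)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(model, (x, w, b)))
    traced(x, w, b)  # warm-up runs let the profiling executor apply fusions
    traced(x, w, b)
    graph = traced.graph_for(x, w, b)
print(any(n.kind() == "ipex::einsum_binary" for n in graph.nodes()))  # expect True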

8 files changed: +912 -0 lines changed

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Einsum.cpp

Lines changed: 685 additions & 0 deletions
Large diffs are not rendered by default.
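The 685-line kernel itself is not rendered here. Per the commit message, it lowers a two-operand einsum to oneDNN matmul/batched-matmul calls and folds the trailing add in as a binary post-op. As background, the kind of reduction presumably involved (pure PyTorch, for orientation only):

import torch

# A two-operand einsum such as "bnqd,bnkd->bnqk" is expressible as a
# batched matmul, which is what makes a oneDNN matmul + binary post-op
# lowering possible.
q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
via_einsum = torch.einsum("bnqd,bnkd->bnqk", q, k)
via_matmul = torch.matmul(q, k.transpose(-1, -2))
print(torch.allclose(via_einsum, via_matmul, atol=1e-6))  # True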
intel_extension_for_pytorch/csrc/jit/cpu/kernels/Einsum.h

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+
+#include <c10/core/Scalar.h>
+#include <torch/csrc/jit/runtime/custom_operator.h>
+
+#include "csrc/cpu/ideep/ideep.hpp"
+
+namespace torch {
+namespace jit {
+
+// XXX: PyTorch does not support nested namespaces,
+// and alias analysis does not work for namespaces other than aten,
+// so we fake some op namespaces to work around that.
+namespace ipex {
+static auto einsum_binary = Symbol::fromQualString("ipex::einsum_binary");
+
+} // namespace ipex
+
+} // namespace jit
+} // namespace torch
+
+namespace torch_ipex {
+namespace cpu {
+
+at::Tensor einsum_binary(
+    c10::string_view,
+    const c10::List<at::Tensor>& operands,
+    const at::Tensor& input,
+    const c10::Scalar& alpha);
+
+} // namespace cpu
+} // namespace torch_ipex
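Going by this declaration and the schema registered in register_dnnl_jit_ops.cpp below, the fused op takes the equation, the two einsum operands, the tensor to add, and an alpha that scales it with aten::add semantics (x + alpha * add_arg). A reference sketch of the expected numerics (einsum_binary_reference is an illustrative name):

import torch

# Expected numerics of the fused op, expressed with stock ops.
def einsum_binary_reference(equation, operands, add_arg, alpha=1.0):
    return torch.einsum(equation, *operands) + alpha * add_arg

out = einsum_binary_reference(
    "bsh,ho->bso",
    [torch.randn(2, 3, 768), torch.randn(768, 2304)],
    torch.randn(2304))
print(out.shape)  # torch.Size([2, 3, 2304])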

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,14 @@ static auto bmm_add = Symbol::fromQualString("ipex::bmm_add");
 namespace torch_ipex {
 namespace cpu {
 
+at::Tensor bmm_impl(
+    const at::Tensor& tensor1,
+    const at::Tensor& tensor2,
+    at::Tensor out,
+    const ideep::attr_t& attr,
+    const std::vector<ideep::tensor>& postop_tensors,
+    const float dst_coeff);
+
 at::Tensor dil_matmul_div(
     const at::Tensor& left,
     const at::Tensor& right,
intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph);
 
 void insertPrePackedConvTranspose2dOp(std::shared_ptr<Graph>& graph);
 
+void FusedEinsumPost(std::shared_ptr<Graph>& graph);
 } // namespace graph_rewrite
 } // namespace jit
 } // namespace torch
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+#include "graph_rewrite.h"
+#include "graph_rewrite_utils.h"
+
+#include <ATen/code_template.h>
+
+namespace torch {
+namespace jit {
+namespace graph_rewrite {
+
+using namespace at::jit;
+
+auto ipex_einsum_filter =
+    [](const Match& match,
+       const std::unordered_map<std::string, Value*>& vmap) {
+      const auto& match_vmap = match.values_map;
+      auto equation =
+          getIValue("equation", match_vmap, vmap).value().toStringView();
+      int num_ops = std::count(equation.begin(), equation.end(), ',') + 1;
+      if (num_ops != 2)
+        return false; // only handle equations with exactly two operands
+      return true;
+    };
+
+void FusedEinsumPost(std::shared_ptr<Graph>& graph) {
+  SubgraphRewriter rewriter_einsum_binary;
+  std::array<std::string, 2> binarys = {"add", "add_"};
+  auto aten_einsum_binary = CodeTemplate(R"(
+     graph(%equation, %inputs, %add_arg, %alpha):
+        %x = aten::einsum(%equation, %inputs)
+        %res = aten::${binary}(%x, %add_arg, %alpha)
+        return (%res))");
+  std::string fused_einsum_binary = R"(
+    graph(%equation, %inputs, %add_arg, %alpha):
+        %res = ipex::einsum_binary(%equation, %inputs, %add_arg, %alpha)
+        return (%res))";
+
+  for (const auto& binary : binarys) {
+    TemplateEnv env;
+    env.s("binary", binary);
+    rewriter_einsum_binary.RegisterRewritePattern(
+        aten_einsum_binary.format(env), fused_einsum_binary);
+  }
+  rewriter_einsum_binary.runOnGraph(graph, ipex_einsum_filter);
+}
+
+} // namespace graph_rewrite
+} // namespace jit
+} // namespace torch
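The filter is the gatekeeper here: the rewriter registers patterns for both aten::add and aten::add_, and the filter then restricts matches to equations with exactly two operands by counting commas. The same check, transcribed to Python:

# An einsum equation has one more operand than it has commas.
def is_two_operand(equation: str) -> bool:
    return equation.count(",") + 1 == 2

print(is_two_operand("bsh,ho->bso"))  # True  -> eligible for fusion
print(is_two_operand("ii->i"))        # False -> one operand, skipped
print(is_two_operand("a,b,c->abc"))   # False -> three operands, skipped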

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 17 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include "csrc/aten/cpu/ConcatBnRelu.h"
 #include "csrc/jit/cpu/kernels/ConvPacked.h"
 #include "csrc/jit/cpu/kernels/ConvTransposePacked.h"
+#include "csrc/jit/cpu/kernels/Einsum.h"
 #include "csrc/jit/cpu/kernels/Embeddingbag.h"
 #include "csrc/jit/cpu/kernels/Interaction.h"
 #include "csrc/jit/cpu/kernels/LinearPacked.h"
@@ -832,6 +833,22 @@ RegisterOperators op({
           };
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "ipex::einsum_binary(str equation, Tensor[] tensors, Tensor add_arg, Scalar alpha) -> Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = einsum_binary(
+                (std::move(peek(stack, 0, 4))).toStringView(),
+                (std::move(peek(stack, 1, 4))).toTensorList(),
+                (std::move(peek(stack, 2, 4))).toTensor(),
+                (std::move(peek(stack, 3, 4))).toScalar());
+
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
 
 });
 } // namespace jit
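The registered operation follows the TorchScript stack calling convention: peek(stack, i, 4) reads argument i of the top four values, drop pops all four, and pack pushes the result. A toy Python model of that protocol (run_einsum_binary is illustrative, not a real API):

import torch

def run_einsum_binary(stack):
    equation, tensors, add_arg, alpha = stack[-4:]  # peek(stack, i, 4)
    result = torch.einsum(equation, *tensors) + alpha * add_arg
    del stack[-4:]                                  # drop(stack, 4)
    stack.append(result)                            # pack(stack, result)

stack = ["mc,cn->mn", [torch.randn(2, 3), torch.randn(3, 4)], torch.randn(4), 1.0]
run_einsum_binary(stack)
print(stack[0].shape)  # torch.Size([2, 4])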

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 3 additions & 0 deletions
@@ -352,6 +352,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   // concat multi-linear with same input
   FrozenConcatLinear(graph);
 
+  // ipex einsum
+  graph_rewrite::FusedEinsumPost(graph);
+
   // Fuse the scores calculation (dim + matmul + (add)? + softmax) for
   // Multi-Head-Attention
   graph_rewrite::FuseMHAScoreCalc(graph);

tests/cpu/test_jit.py

Lines changed: 116 additions & 0 deletions
@@ -786,6 +786,27 @@ def forward(self, x):
         y3 += x
         return y3.relu_()
 
+class EinsumAdd(nn.Module):
+    def __init__(self, equation):
+        super(EinsumAdd, self).__init__()
+        self.equation = equation
+    def forward(self, input1, input2, bias):
+        return torch.einsum(self.equation, input1, input2) + bias
+
+class EinsumAddInplace(nn.Module):
+    def __init__(self, equation):
+        super(EinsumAddInplace, self).__init__()
+        self.equation = equation
+    def forward(self, input1, input2, bias):
+        return torch.einsum(self.equation, input1, input2).add_(bias)
+
+class EinsumAddInplaceV1(nn.Module):
+    def __init__(self, equation):
+        super(EinsumAddInplaceV1, self).__init__()
+        self.equation = equation
+    def forward(self, input1, input2, bias):
+        return bias.add_(torch.einsum(self.equation, input1, input2))
+
 class Tester(TestCase):
 
     def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None, levels=['O0','O1'], use_channels_last=[True, False]):
@@ -2466,6 +2487,101 @@ def test_bmm_add(self):
         expected = torch.baddbmm(M, batch1, batch2)
         self.assertTrue(torch.allclose(out, expected))
 
+    def test_einsum_add(self):
+        def _test_fp32(model_test, input1, input2, bias, kind_in_graph='ipex::einsum_binary', prec=1e-3):
+            model = copy.deepcopy(model_test)
+            model = model.eval()
+            model = ipex.optimize(model, dtype=torch.float32)
+            with torch.no_grad():
+                res_ref = model(input1, input2, bias)
+                tr_model = torch.jit.trace(model, (input1, input2, bias))
+                tr_model = torch.jit.freeze(tr_model)
+                tr_model(input1, input2, bias)
+                tr_model(input1, input2, bias)
+                trace_graph = tr_model.graph_for(input1, input2, bias)
+                res_jit = tr_model(input1, input2, bias)
+                self.assertEqual(res_ref, res_jit, prec)
+                self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
+
+        bias = torch.randn(3, 2304)
+        input1 = torch.randn(2, 3, 768)
+        input2 = torch.randn(768, 2304)
+        model_v1 = EinsumAdd('bsh,ho->bso')
+        _test_fp32(model_v1, input1, input2, bias)
+
+        bias = torch.randn(2304)
+        input1 = torch.randn(4, 3, 768)
+        input2 = torch.randn(768, 2304)
+        model_v1 = EinsumAddInplace('bsh,ho->bso')
+        _test_fp32(model_v1, input1, input2, bias)
+
+        bias = torch.randn(4, 3, 2304)
+        input1 = torch.randn(4, 3, 768)
+        input2 = torch.randn(768, 2304)
+        model_v1 = EinsumAddInplaceV1('bsh,ho->bso')
+        _test_fp32(model_v1, input1, input2, bias, kind_in_graph='aten::einsum')
+
+        bias1 = torch.randn(2, 1, 128, 128)
+        input3 = torch.randn(2, 4, 128, 768)
+        input4 = torch.randn(2, 4, 128, 768)
+        model_v2 = EinsumAdd("bnqd,bnkd->bnqk")
+        _test_fp32(model_v2, input3, input4, bias1)
+
+        bias1 = torch.randn(8, 1, 1, 128)
+        input3 = torch.randn(8, 4, 128, 768)
+        input4 = torch.randn(8, 4, 128, 768)
+        model_v2 = EinsumAdd("bnqd,bnkd->bnqk")
+        _test_fp32(model_v2, input3, input4, bias1)
+
+        bias1 = torch.randn(2, 4, 128, 768)
+        input1 = torch.randn(2, 4, 128, 768)
+        input2 = torch.randn(4, 768, 768)
+        model_v2 = EinsumAdd("balh,ahr->balr")
+        _test_fp32(model_v2, input1, input2, bias1)
+
+        bias1 = torch.randn(768)
+        input1 = torch.randn(128, 1024)
+        input2 = torch.randn(768, 1024)
+        model_v2 = EinsumAdd("mc,nc->mn")
+        _test_fp32(model_v2, input1, input2, bias1)
+
+        bias1 = torch.randn(768)
+        input1 = torch.randn(128, 1024)
+        input2 = torch.randn(1024, 768)
+        model_v2 = EinsumAdd("mc,cn->mn")
+        _test_fp32(model_v2, input1, input2, bias1)
+
+        bias1 = torch.randn(1024)
+        input1 = torch.randn(1024, 1024)
+        input2 = torch.randn(1024, 1024)
+        model_v2 = EinsumAdd("mc,cn->nm")
+        _test_fp32(model_v2, input1, input2, bias1)
+
+        bias1 = torch.randn(768)
+        input1 = torch.randn(2, 128, 1024)
+        input2 = torch.randn(1024, 23, 768)
+        model_v2 = EinsumAdd("bqc,chv->bqhv")
+        _test_fp32(model_v2, input1, input2, bias1)
+
+        bias = torch.randn(768)
+        input1 = torch.randn(2, 128, 16, 64)
+        input2 = torch.randn(16, 64, 768)
+        model = EinsumAdd("bqhc,hco->bqo")
+        _test_fp32(model, input1, input2, bias)
+
+        bias = torch.randn(8)
+        input1 = torch.randn(8)
+        input2 = torch.randn(8)
+        model = EinsumAdd("i,i->")
+        _test_fp32(model, input1, input2, bias)
+
+        # with a zero-sized first dim, the output of torch.einsum("ij,j") is tensor([])
+        bias = torch.randn(1)
+        input1 = torch.randn(0, 3)
+        input2 = torch.randn(3)
+        model = EinsumAdd("ij,j")
+        _test_fp32(model, input1, input2, bias)
+
     def test_ipex_softmax(self):
         self._test_output(
             AtenSoftmaxRepalce(),
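The three module variants pin down when the rewrite may fire: EinsumAdd and EinsumAddInplace both match the registered patterns because the einsum result is the self argument of the add, while EinsumAddInplaceV1 must not fuse, since bias.add_(...) mutates a graph input and the einsum output appears only as the other operand. Restating the three cases (shapes as in the tests above):

import torch

eq = "bsh,ho->bso"
a, b = torch.randn(4, 3, 768), torch.randn(768, 2304)
bias = torch.randn(4, 3, 2304)

out1 = torch.einsum(eq, a, b) + bias      # fuses: einsum feeds aten::add
out2 = torch.einsum(eq, a, b).add_(bias)  # fuses: einsum result is add_'s self
out3 = bias.add_(torch.einsum(eq, a, b))  # stays aten::einsum: bias, a graph
                                          # input, is mutated in place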
