Skip to content

Commit 3d0b366

Browse files
xadupre authored and ankitm3k committed
Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (microsoft#22762)
### Description Fixes microsoft#22512. MatMul and Add were being fused into a single Gemm even when the tensor dimensions are > 2; this PR excludes those cases. ### Motivation and Context ORT crashes on valid models due to that unexpected fusion.
1 parent 2006a22 commit 3d0b366

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

onnxruntime/core/optimizer/matmul_bn_fusion.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
107107
return false;
108108
}
109109

110+
// Checks the first input of MatMul has 2 dimensions.
111+
// The test for the second input is done in method Apply as it accesses the constant.
112+
if (node.InputDefs()[0] == nullptr) {
113+
// This should never happen but just in case.
114+
return false;
115+
}
116+
auto shape_a = node.InputDefs()[0]->Shape();
117+
if (shape_a == nullptr) {
118+
// We cannot shape the rank. It is better to avoid fusing.
119+
return false;
120+
}
121+
if (shape_a->dim_size() != 2) {
122+
// Gemm only supports 2D tensors.
123+
return false;
124+
}
125+
110126
// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
111127
const auto& output_defs = batch_norm_node->OutputDefs();
112128
if (output_defs.size() > 1) {
@@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
165181
bias_tensor->dims_size() != 1 ||
166182
mean_tensor->dims_size() != 1 ||
167183
var_tensor->dims_size() != 1 ||
184+
matmul_b_tensor->dims_size() != 2 ||
168185
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
169186
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
170187
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
17641764
}
17651765
}
17661766

1767+
// Regression test for microsoft#22512: MatMul + BatchNormalization must NOT be
// fused into Gemm when the MatMul inputs are not 2-D tensors (Gemm only
// supports rank-2 inputs). The model below is built so that the fusion
// precondition fails; after running the Level1 rule transformer the graph must
// be unchanged: BatchNormalization and MatMul still present, no Gemm created.
TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";

  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Register only the MatmulBNFusion rule so the outcome is attributable to it.
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The fusion must have been rejected: original ops intact, no Gemm emitted.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
  ASSERT_EQ(op_to_count["MatMul"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 0);
}
1795+
17671796
TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
17681797
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
17691798

517 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)