Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions onnxruntime/core/optimizer/matmul_bn_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
return false;
}

// Checks the first input of MatMul has 2 dimensions.
// The test for the second input is done in method Apply as it accesses the constant.
if (node.InputDefs()[0] == nullptr) {
// This should never happen but just in case.
return false;
}
auto shape_a = node.InputDefs()[0]->Shape();
if (shape_a == nullptr) {
// We cannot determine the rank without a shape. It is better to avoid fusing.
return false;
}
if (shape_a->dim_size() != 2) {
// Gemm only supports 2D tensors.
return false;
}

// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
const auto& output_defs = batch_norm_node->OutputDefs();
if (output_defs.size() > 1) {
Expand Down Expand Up @@ -165,6 +181,8 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
bias_tensor->dims_size() != 1 ||
mean_tensor->dims_size() != 1 ||
var_tensor->dims_size() != 1 ||
// matmul_a_tensor->dims_size() != 2 ||
matmul_b_tensor->dims_size() != 2 ||
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
}
}

// Verifies that MatmulBNFusion does NOT fire on this model: after running the
// Level1 rule transformer, the MatMul + BatchNormalization pair must survive
// unchanged and no Gemm node may be introduced.
// NOTE(review): the model name suggests the fusion is rejected because the
// MatMul's first input is not rank-2 — confirm against the model file.
TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";

  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Sanity-check the test model: it must actually contain a
  // BatchNormalization node, otherwise the "do not fuse" assertions below
  // would pass vacuously against a wrong/empty model file.
  std::string expected_output_name;
  GraphViewer graph_viewer(graph);
  for (auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    auto& node = *graph.GetNode(node_index);
    if (node.OpType() == "BatchNormalization") {
      expected_output_name = node.OutputDefs()[0]->Name();
    }
  }
  ASSERT_FALSE(expected_output_name.empty())
      << "Test model is expected to contain a BatchNormalization node.";

  // Register only the MatmulBNFusion rewrite rule at Level1.
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  // The fusion must not have been applied: both original ops remain, no Gemm.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["BatchNormalization"], 1);
  ASSERT_EQ(op_to_count["MatMul"], 1);
  ASSERT_EQ(op_to_count["Gemm"], 0);
}

TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";

Expand Down
Binary file not shown.
Loading