diff --git a/test/cpp/tensorexpr/test_ops.cpp b/test/cpp/tensorexpr/test_ops.cpp
index 379c901968d5..8121d0a8fa39 100644
--- a/test/cpp/tensorexpr/test_ops.cpp
+++ b/test/cpp/tensorexpr/test_ops.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include

 using namespace torch::jit::tensorexpr;
@@ -76,3 +77,64 @@ TEST(Ops, ChannelsLastSum) {
     ASSERT_TRUE(at::allclose(bt, ref));
   }
 }
+
+TEST(Ops, LinearWithBias) {
+  const auto graph_string = R"IR(
+    graph(%x : Float(1, 16, strides=[16, 1], device=cpu),
+          %w : Float(8, 16, strides=[16, 1], device=cpu),
+          %b : Float(8, strides=[1], device=cpu)):
+      %1 : Float(1, 8, strides=[8, 1], device=cpu) = aten::linear(%x, %w, %b)
+      return (%1))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto x = at::rand({1, 16}, at::TensorOptions(c10::kCPU).dtype(at::kFloat));
+  auto w = at::rand({8, 16}, at::TensorOptions(c10::kCPU).dtype(at::kFloat));
+  auto b = at::rand({8}, at::TensorOptions(c10::kCPU).dtype(at::kFloat));
+  auto y_expected = at::linear(x, w, b);
+
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {x, w, b};
+
+  std::vector<IValue> stack = at::fmap<IValue>(inputs);
+  k.run(stack);
+  auto y = stack[0].toTensor();
+
+  bool check = at::allclose(y_expected, y);
+  if (!check) {
+    std::cout << "x:\n" << x << std::endl;
+    std::cout << "y_expected:\n" << y_expected << std::endl;
+    std::cout << "y:\n" << y << std::endl;
+  }
+  TORCH_CHECK_EQ(check, 1);
+}
+
+TEST(Ops, LinearWithoutBias) {
+  const auto graph_string = R"IR(
+    graph(%x : Float(1, 16, strides=[16, 1], device=cpu),
+          %w : Float(8, 16, strides=[16, 1], device=cpu)):
+      %1 : NoneType = prim::Constant()
+      %2 : Float(1, 8, strides=[8, 1], device=cpu) = aten::linear(%x, %w, %1)
+      return (%2))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto x = at::rand({1, 16}, at::TensorOptions(c10::kCPU).dtype(at::kFloat));
+  auto w = at::rand({8, 16}, at::TensorOptions(c10::kCPU).dtype(at::kFloat));
+  auto y_expected = at::linear(x, w);
+
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {x, w};
+
+  std::vector<IValue> stack = at::fmap<IValue>(inputs);
+  k.run(stack);
+  auto y = stack[0].toTensor();
+
+  bool check = at::allclose(y_expected, y);
+  if (!check) {
+    std::cout << "x:\n" << x << std::endl;
+    std::cout << "y_expected:\n" << y_expected << std::endl;
+    std::cout << "y:\n" << y << std::endl;
+  }
+  TORCH_CHECK_EQ(check, 1);
+}
\ No newline at end of file
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 9736ce1b11f8..ada1e917e3fd 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -81,6 +81,7 @@ static const OperatorSet& supported_non_eltwise_set() {
       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
       "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
      "aten::matmul(Tensor self, Tensor other) -> Tensor",
+      "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor",
   };
   // clang-format on
   return supported_non_eltwise_set;
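Both tests above, and the decomposition added further down, rest on the identity linear(x, w, b) == matmul(x, w.t()) + b. A minimal eager-mode sketch of that identity, outside the patch (the helper name is made up; only ATen is assumed):

```cpp
#include <ATen/ATen.h>

// Illustrative check of the identity the decomposition relies on:
// linear(x, w, b) == matmul(x, w.t()) + b
void checkLinearIdentity() {
  auto x = at::rand({1, 16});
  auto w = at::rand({8, 16});
  auto b = at::rand({8});
  auto lhs = at::linear(x, w, b);              // fused op
  auto rhs = at::add(at::matmul(x, w.t()), b); // decomposed form
  TORCH_CHECK(at::allclose(lhs, rhs));
}
```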
bias=None) -> Tensor", }; // clang-format on return supported_non_eltwise_set; diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index ee18ec7b5d1f..9677a35385b6 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -1,10 +1,15 @@ #include +#include +#include #include +#include +#include #include #include #include #include +#include namespace torch { namespace jit { @@ -181,6 +186,169 @@ bool OptimizeCat(const std::shared_ptr& graph) { return false; } +void DecomposeCompositeOp( + std::shared_ptr graph, + std::string& pattern_op, + std::vector& target_ops, + const std::function&)>& op_map, + const std::function< + void(const std::vector&, const std::vector&)>& + op_shape_reform) { + Graph pattern_graph; + parseIR(pattern_op, &pattern_graph); + + std::vector values_to_rewrite; + std::unordered_map rewrite_map; + std::unordered_set nodes_to_delete_; + + // Locate the pattern graph and replace with target graph + const auto& matches = findPatternMatches(pattern_graph, *graph); + for (const Match& match : matches) { + // Collect original inputs and insertion point for target graph + Node* ins_point = nullptr; + std::vector inputs, outputs; + Graph target_graph; + for (const auto idx : c10::irange(pattern_graph.inputs().size())) { + Value* input = match.values_map.at(pattern_graph.inputs().at(idx)); + inputs.push_back(input); + if (!ins_point || ins_point->isBefore(input->node())) { + ins_point = input->node(); + } + } + AT_ASSERT(ins_point); + ins_point = ins_point->next(); + WithInsertPoint insert_point(ins_point); + + // Prepare the outputs for target graph + for (Value* v : pattern_graph.outputs()) { + Value* output = match.values_map.at(v); + outputs.push_back(output); + } + + // Select the right target graph based on inputs + size_t target_op_idx = op_map(inputs); + parseIR(target_ops[target_op_idx], &target_graph); + + // Insert the target graph + std::vector new_outputs = insertGraph(*graph, target_graph, inputs); + AT_ASSERT(outputs.size() == new_outputs.size()); + + // Reform shapes for the new Values + for (const auto idx : c10::irange(outputs.size())) { + new_outputs[idx]->setType(outputs[idx]->type()); + } + op_shape_reform(inputs, new_outputs); + + // Record all planned rewritings + for (const auto idx : c10::irange(outputs.size())) { + values_to_rewrite.push_back(outputs[idx]); + rewrite_map[outputs[idx]] = new_outputs[idx]; + } + + // Record all planned deletions + for (Node* pattern_n : pattern_graph.nodes()) { + if (match.nodes_map.count(pattern_n)) { + Node* n = match.nodes_map.at(pattern_n); + nodes_to_delete_.insert(n); + } + } + } + + // Perform planned rewritings + for (auto v : values_to_rewrite) { + v->replaceAllUsesWith(rewrite_map.at(v)); + } + + // Perform planned deletions + for (auto n : nodes_to_delete_) { + n->removeAllInputs(); + } + for (auto n : nodes_to_delete_) { + n->destroy(); + } + nodes_to_delete_.clear(); +} + +// Optimization pass to decompose an aten Op into a series of other Ops +// providing equivalent functionality. This helps TE to support more Ops like +// aten::linear which can be constructed with other Ops by nature, and other +// scenarios with performance benefitial or ease the lowering/optimization +// inside TE. +void DecomposeCompositeOps(std::shared_ptr graph) { + // The list of graph of the OP needs to be decompose. + std::vector pattern_op_list; + // The list of graph list that will be used as target. 
+
+// Optimization pass to decompose an aten Op into a series of other Ops that
+// provide equivalent functionality. This helps TE support more Ops, such as
+// aten::linear, which is by nature a composition of other Ops, as well as
+// other scenarios where decomposition benefits performance or eases the
+// lowering/optimization inside TE.
+void DecomposeCompositeOps(std::shared_ptr<Graph> graph) {
+  // The list of pattern graphs for the Ops to be decomposed.
+  std::vector<std::string> pattern_op_list;
+  // The list of target graph lists. Several target graphs can be provided
+  // for one pattern Op if needed.
+  std::vector<std::vector<std::string>> target_op_list;
+  // The list of functions that select the target graph based on the inputs
+  // when several targets are provided.
+  std::vector<std::function<size_t(const std::vector<Value*>&)>> op_map_list;
+  // The list of functions that fix up the value shapes inside the target
+  // graph. The newly added values (neither graph inputs nor outputs) lack
+  // shape information because they do not go through the profiling runs,
+  // while shape information is required for TE lowering.
+  std::vector<std::function<
+      void(const std::vector<Value*>&, const std::vector<Value*>&)>>
+      op_shape_reform_list;
+
+  pattern_op_list.push_back(
+      R"(graph(%input, %weight, %bias):
+          %linear_out = aten::linear(%input, %weight, %bias)
+          return (%linear_out) )");
+  target_op_list.push_back({
+      R"(graph(%input, %weight, %bias):
+          %1 : int = prim::Constant[value=1]()
+          %weight_t = aten::t(%weight)
+          %mm_out = aten::matmul(%input, %weight_t)
+          %add_out = aten::add(%mm_out, %bias, %1)
+          return (%add_out) )",
+      R"(graph(%input, %weight, %bias):
+          %weight_t = aten::t(%weight)
+          %mm_out = aten::matmul(%input, %weight_t)
+          return (%mm_out) )"});
+  op_map_list.push_back([&](const std::vector<Value*>& inputs) {
+    return inputs[2]->mustBeNone() ? 1 : 0;
+  });
+  op_shape_reform_list.push_back([&](const std::vector<Value*>& inputs,
+                                     const std::vector<Value*>& outputs) {
+    Node* mm = outputs[0]->node();
+    if (!inputs[2]->mustBeNone()) { // with Bias
+      // Fix up the output value of aten::matmul
+      auto add = outputs[0]->node();
+      auto mm_out = add->inputs().at(0);
+      mm_out->setType(outputs[0]->type()->cast<TensorType>());
+      mm = mm_out->node();
+    }
+    // Fix up the output value of aten::t by swapping the weight sizes
+    auto weight_t = mm->inputs().at(1);
+    auto weight_tensortype =
+        weight_t->node()->input()->type()->cast<TensorType>();
+    auto ori_sizes = weight_tensortype->sizes();
+    weight_t->setType(weight_tensortype->withSizes(
+        at::makeArrayRef({ori_sizes[1].value(), ori_sizes[0].value()})));
+  });
+
+  for (const auto idx : c10::irange(pattern_op_list.size())) {
+    DecomposeCompositeOp(
+        graph,
+        pattern_op_list[idx],
+        target_op_list[idx],
+        op_map_list[idx],
+        op_shape_reform_list[idx]);
+  }
+}
+
+bool DecomposeOps(const std::shared_ptr<Graph>& graph) {
+  DecomposeCompositeOps(graph);
+  GRAPH_DUMP("After DecomposeCompositeOps:", graph);
+  EliminateCommonSubexpression(graph);
+  GRAPH_DUMP("After EliminateCommonSubexpression:", graph);
+  ConstantPooling(graph);
+  GRAPH_DUMP("After ConstantPooling:", graph);
+  return true;
+}
+
 void annotateInputShapes(
     const std::shared_ptr<Graph>& graph,
     const std::vector<c10::optional<at::Tensor>>& example_inputs) {
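A minimal sketch of driving the new pass directly on a shape-annotated graph (outside the patch; the function name is made up, and since DecomposeOps is not exported with TORCH_API this would only link inside the library):

```cpp
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>

using namespace torch::jit;

// Parse a graph containing aten::linear, run the pass, and confirm
// no aten::linear node survives the decomposition.
void decomposeLinearExample() {
  const auto src = R"IR(
    graph(%x : Float(1, 16, strides=[16, 1], device=cpu),
          %w : Float(8, 16, strides=[16, 1], device=cpu),
          %b : Float(8, strides=[1], device=cpu)):
      %y : Float(1, 8, strides=[8, 1], device=cpu) = aten::linear(%x, %w, %b)
      return (%y))IR";
  auto g = std::make_shared<Graph>();
  parseIR(src, g.get());

  tensorexpr::DecomposeOps(g);

  // The match should have been replaced by aten::t/aten::matmul/aten::add.
  for (Node* n : g->nodes()) {
    TORCH_CHECK(n->kind() != Symbol::fromQualString("aten::linear"));
  }
}
```

Note that the weight input carries complete sizes in the IR above; the op_shape_reform callback needs them to set the transposed type on the aten::t output.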
diff --git a/torch/csrc/jit/tensorexpr/graph_opt.h b/torch/csrc/jit/tensorexpr/graph_opt.h
index 1180d0ac438b..cec1ee4a6daf 100644
--- a/torch/csrc/jit/tensorexpr/graph_opt.h
+++ b/torch/csrc/jit/tensorexpr/graph_opt.h
@@ -57,6 +57,7 @@ namespace tensorexpr {
 // aten_cat is the output buffer here.

 bool OptimizeCat(const std::shared_ptr<Graph>& graph);
+bool DecomposeOps(const std::shared_ptr<Graph>& graph);

 TORCH_API void annotateInputShapes(
     const std::shared_ptr<Graph>& graph,
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 5c7c783e1b78..0309ac29e7e1 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1611,6 +1611,10 @@ void TensorExprKernel::optimizeOwningGraph() {
   // Optimize the concatenation
   OptimizeCat(graph_);

+  GRAPH_DUMP("Before DecomposeOps:", graph_);
+  DecomposeOps(graph_);
+  GRAPH_DUMP("After DecomposeOps:", graph_);
+
   // Synchronize the symbolic strides information
   auto graph_outputs = graph_->outputs();
   TORCH_INTERNAL_ASSERT(graph_outputs.size() == _orignal_graph_outputs.size());
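The GRAPH_DUMP calls added above and in graph_opt.cpp go through the JIT logger, so, assuming the usual jit_log behavior, the intermediate graphs can be inspected by naming those source files in the PYTORCH_JIT_LOG_LEVEL environment variable when running the new tests (the binary path and gtest filter below are illustrative):

```
PYTORCH_JIT_LOG_LEVEL="graph_opt:kernel" build/bin/test_tensorexpr --gtest_filter='Ops.Linear*'
```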