Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions test/cpp/tensorexpr/test_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/torch.h>

using namespace torch::jit::tensorexpr;
Expand Down Expand Up @@ -76,3 +77,64 @@ TEST(Ops, ChannelsLastSum) {
ASSERT_TRUE(at::allclose(bt, ref));
}
}

// Builds a TensorExprKernel from a graph that contains aten::linear with a
// bias input and checks that its result matches eager at::linear on random
// inputs.
TEST(Ops, LinearWithBias) {
  const auto graph_string = R"IR(
graph(%x : Float(1, 16, strides=[16, 1], device=cpu),
%w : Float(8, 16, strides=[16, 1], device=cpu),
%b : Float(8, strides=[1], device=cpu)):
%1 : Float(1, 8, strides=[8, 1], device=cpu) = aten::linear(%x, %w, %b)
return (%1))IR";
  auto graph = std::make_shared<torch::jit::Graph>();
  parseIR(graph_string, graph.get());

  // Inputs are created in the same order as the graph parameters.
  const auto options = at::TensorOptions(c10::kCPU).dtype(at::kFloat);
  auto x = at::rand({1, 16}, options);
  auto w = at::rand({8, 16}, options);
  auto b = at::rand({8}, options);
  auto y_expected = at::linear(x, w, b);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, w, b};

  // The kernel is run on an IValue stack; the result is read back from
  // slot 0 after the call.
  std::vector<c10::IValue> stack = at::fmap<c10::IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  const bool check = at::allclose(y_expected, y);
  if (!check) {
    // Dump the operands to make a mismatch easier to diagnose.
    std::cout << "x:\n" << x << '\n';
    std::cout << "y_expected:\n" << y_expected << '\n';
    std::cout << "y:\n" << y << '\n';
  }
  // Use a gtest assertion (as the other tests in this file do) so that a
  // mismatch is reported as a test failure.
  ASSERT_TRUE(check);
}

// Builds a TensorExprKernel from a graph that contains aten::linear whose
// bias argument is a None constant and checks that its result matches eager
// at::linear (called without a bias) on random inputs.
TEST(Ops, LinearWithoutBias) {
  const auto graph_string = R"IR(
graph(%x : Float(1, 16, strides=[16, 1], device=cpu),
%w : Float(8, 16, strides=[16, 1], device=cpu)):
%1 : NoneType = prim::Constant()
%2 : Float(1, 8, strides=[8, 1], device=cpu) = aten::linear(%x, %w, %1)
return (%2))IR";
  auto graph = std::make_shared<torch::jit::Graph>();
  parseIR(graph_string, graph.get());

  // Inputs are created in the same order as the graph parameters.
  const auto options = at::TensorOptions(c10::kCPU).dtype(at::kFloat);
  auto x = at::rand({1, 16}, options);
  auto w = at::rand({8, 16}, options);
  auto y_expected = at::linear(x, w);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, w};

  // The kernel is run on an IValue stack; the result is read back from
  // slot 0 after the call.
  std::vector<c10::IValue> stack = at::fmap<c10::IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  const bool check = at::allclose(y_expected, y);
  if (!check) {
    // Dump the operands to make a mismatch easier to diagnose.
    std::cout << "x:\n" << x << '\n';
    std::cout << "y_expected:\n" << y_expected << '\n';
    std::cout << "y:\n" << y << '\n';
  }
  // Use a gtest assertion (as the other tests in this file do) so that a
  // mismatch is reported as a test failure.
  ASSERT_TRUE(check);
}
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ static const OperatorSet& supported_non_eltwise_set() {
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
"aten::matmul(Tensor self, Tensor other) -> Tensor",
"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor",
};
// clang-format on
return supported_non_eltwise_set;
Expand Down
168 changes: 168 additions & 0 deletions torch/csrc/jit/tensorexpr/graph_opt.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
#include <torch/csrc/jit/tensorexpr/graph_opt.h>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <functional>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -181,6 +186,169 @@ bool OptimizeCat(const std::shared_ptr<Graph>& graph) {
return false;
}

// Replaces every match of `pattern_op` inside `graph` with one of the
// replacement graphs listed in `target_ops`.
//
//   graph            Graph that is searched and rewritten in place.
//   pattern_op       IR text of the pattern graph to look for.
//   target_ops       IR texts of the candidate replacement graphs.
//   op_map           Receives the matched input values and returns the index
//                    into `target_ops` of the replacement to insert.
//   op_shape_reform  Receives the matched input values and the outputs of the
//                    freshly inserted replacement, after those outputs have
//                    been given the types of the values they replace.
//
// Use-rewriting and node deletion are deferred until all matches have been
// processed.
void DecomposeCompositeOp(
    std::shared_ptr<Graph> graph,
    std::string& pattern_op,
    std::vector<std::string>& target_ops,
    const std::function<size_t(const std::vector<Value*>&)>& op_map,
    const std::function<
        void(const std::vector<Value*>&, const std::vector<Value*>&)>&
        op_shape_reform) {
  Graph pattern_graph;
  parseIR(pattern_op, &pattern_graph);

  std::vector<Value*> values_to_rewrite;
  std::unordered_map<Value*, Value*> rewrite_map;
  std::unordered_set<Node*> nodes_to_delete;

  // Locate the pattern graph and replace with target graph
  const auto& matches = findPatternMatches(pattern_graph, *graph);
  for (const Match& match : matches) {
    // Collect the matched inputs and remember the latest node that produces
    // one of them; the replacement is inserted right after that node.
    Node* ins_point = nullptr;
    std::vector<Value*> inputs;
    for (const auto idx : c10::irange(pattern_graph.inputs().size())) {
      Value* input = match.values_map.at(pattern_graph.inputs().at(idx));
      inputs.push_back(input);
      if (!ins_point || ins_point->isBefore(input->node())) {
        ins_point = input->node();
      }
    }
    AT_ASSERT(ins_point);
    ins_point = ins_point->next();
    WithInsertPoint insert_point(ins_point);

    // Values in `graph` that correspond to the pattern outputs.
    std::vector<Value*> outputs;
    for (Value* v : pattern_graph.outputs()) {
      outputs.push_back(match.values_map.at(v));
    }

    // Select the right target graph based on inputs. The index comes from a
    // caller-supplied callback, so validate it before indexing.
    const size_t target_op_idx = op_map(inputs);
    AT_ASSERT(target_op_idx < target_ops.size());
    Graph target_graph;
    parseIR(target_ops[target_op_idx], &target_graph);

    // Insert the target graph
    std::vector<Value*> new_outputs = insertGraph(*graph, target_graph, inputs);
    AT_ASSERT(outputs.size() == new_outputs.size());

    // The new outputs take over the types of the values they replace; the
    // callback then gets a chance to set types on the other new values.
    for (const auto idx : c10::irange(outputs.size())) {
      new_outputs[idx]->setType(outputs[idx]->type());
    }
    op_shape_reform(inputs, new_outputs);

    // Record all planned rewritings
    for (const auto idx : c10::irange(outputs.size())) {
      values_to_rewrite.push_back(outputs[idx]);
      rewrite_map[outputs[idx]] = new_outputs[idx];
    }

    // Record all planned deletions
    for (Node* pattern_n : pattern_graph.nodes()) {
      if (match.nodes_map.count(pattern_n)) {
        nodes_to_delete.insert(match.nodes_map.at(pattern_n));
      }
    }
  }

  // Perform planned rewritings
  for (Value* v : values_to_rewrite) {
    v->replaceAllUsesWith(rewrite_map.at(v));
  }

  // Perform planned deletions: first detach the inputs of every doomed node,
  // then destroy the nodes.
  for (Node* n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (Node* n : nodes_to_delete) {
    n->destroy();
  }
}

// Optimization pass to decompose an aten Op into a series of other Ops
// providing equivalent functionality. This helps TE to support more Ops like
// aten::linear which can be constructed with other Ops by nature, and other
// scenarios with performance benefitial or ease the lowering/optimization
// inside TE.
void DecomposeCompositeOps(std::shared_ptr<Graph> graph) {
// The list of graph of the OP needs to be decompose.
std::vector<std::string> pattern_op_list;
// The list of graph list that will be used as target. You can add several
// target graphs for a pattern OP if needed.
std::vector<std::vector<std::string>> target_op_list;
// The list of functions to select target graph based on inputs if several
// ones provided.
std::vector<std::function<size_t(const std::vector<Value*>&)>> op_map_list;
// The list of functions to reformalize the value shapes inside target graph.
// This is due to the newly added values (neither graph input nor output) are
// lacks of shape information as not go through the profiling run phase, while
// shape information is required for TE lowering.
std::vector<std::function<void(
const std::vector<Value*>&, const std::vector<Value*>&)>>
op_shape_reform_list;

pattern_op_list.push_back(
R"(graph(%input, %weight, %bias):
%linear_out = aten::linear(%input, %weight, %bias)
return (%linear_out) )");
target_op_list.push_back({
R"(graph(%input, %weight, %bias):
%1 : int = prim::Constant[value=1]()
%weight_t = aten::t(%weight)
%mm_out = aten::matmul(%input, %weight_t)
%add_out = aten::add(%mm_out, %bias, %1)
return (%add_out) )",
R"(graph(%input, %weight, %bias):
%weight_t = aten::t(%weight)
%mm_out = aten::matmul(%input, %weight_t)
return (%mm_out) )"});
op_map_list.push_back([&](const std::vector<Value*>& inputs) {
return inputs[2]->mustBeNone() ? 1 : 0;
});
op_shape_reform_list.push_back([&](const std::vector<Value*>& inputs,
const std::vector<Value*>& outputs) {
Node* mm = outputs[0]->node();
if (!inputs[2]->mustBeNone()) { // with Bias
// Reformalize the output value of aten::matmul
auto add = outputs[0]->node();
auto mm_out = add->inputs().at(0);
mm_out->setType(outputs[0]->type()->cast<TensorType>());
mm = mm_out->node();
}
// Reformalize the output value of aten::t
auto weight_t = mm->inputs().at(1);
auto weight_tensortype =
weight_t->node()->input()->type()->cast<TensorType>();
auto ori_sizes = weight_tensortype->sizes();
at::makeArrayRef({ori_sizes[1].value(), ori_sizes[0].value()});
weight_t->setType(weight_tensortype->withSizes(
at::makeArrayRef({ori_sizes[1].value(), ori_sizes[0].value()})));
});

for (const auto idx : c10::irange(pattern_op_list.size())) {
DecomposeCompositeOp(
graph,
pattern_op_list[idx],
target_op_list[idx],
op_map_list[idx],
op_shape_reform_list[idx]);
}
}

// Entry point of the decomposition step: runs DecomposeCompositeOps on
// `graph`, then common subexpression elimination and constant pooling,
// dumping the graph after each stage. Always returns true.
bool DecomposeOps(const std::shared_ptr<Graph>& graph) {
  // Stage 1: rewrite composite ops into their constituent ops.
  DecomposeCompositeOps(graph);
  GRAPH_DUMP("After DecomposeCompositeOps:", graph);

  // Stage 2: common subexpression elimination.
  EliminateCommonSubexpression(graph);
  GRAPH_DUMP("After EliminateCommonSubexpression:", graph);

  // Stage 3: constant pooling.
  ConstantPooling(graph);
  GRAPH_DUMP("After ConstantPooling:", graph);

  return true;
}

void annotateInputShapes(
const std::shared_ptr<Graph>& graph,
const std::vector<c10::optional<at::Tensor>>& example_inputs) {
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/tensorexpr/graph_opt.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace tensorexpr {
// aten_cat is the output buffer here.

bool OptimizeCat(const std::shared_ptr<Graph>& graph);
bool DecomposeOps(const std::shared_ptr<Graph>& graph);

TORCH_API void annotateInputShapes(
const std::shared_ptr<Graph>& graph,
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/tensorexpr/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,10 @@ void TensorExprKernel::optimizeOwningGraph() {
// Optimize the concatenation
OptimizeCat(graph_);

GRAPH_DUMP("Before DecomposeOps:", graph_);
DecomposeOps(graph_);
GRAPH_DUMP("After DecomposeOps:", graph_);

// Synchronize the symbolic strides information
auto graph_outputs = graph_->outputs();
TORCH_INTERNAL_ASSERT(graph_outputs.size() == _orignal_graph_outputs.size());
Expand Down