105 changes: 105 additions & 0 deletions test/cpp/tensorexpr/test_external_calls.cpp
@@ -1061,5 +1061,110 @@ TEST(ExternalCall, JitCustomFusionOp) {
#endif
}

TEST(ExternalCall, JitCustomQuantOp) {
const char* custom_op_schema_literal =
"nnc_custom::special_mul(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc";
const char* external_func_name = "nnc_special_mul";

auto special_mul_lowering_func =
[external_func_name](
const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
at::Device device) {
auto output_dtype = Dtype(*output_type);
const torch::jit::tensorexpr::BufHandle& qa =
c10::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
const torch::jit::tensorexpr::BufHandle& qb =
c10::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
const auto out_qscale = DoubleImm::make(c10::get<double>(inputs[2])).node();
const auto out_qzero = LongImm::make(c10::get<int64_t>(inputs[3])).node();

BufHandle result_buf("nnc_special_mul_res_buf",
output_shape,
output_dtype);
result_buf.node()->set_qscale(out_qscale);
result_buf.node()->set_qzero(out_qzero);

torch::jit::tensorexpr::StmtPtr s =
torch::jit::tensorexpr::ExternalCall::make(
result_buf,
external_func_name,
{qa, qb},
{ExprHandle(qa.node()->qscale()),
ExprHandle(qa.node()->qzero()),
(int64_t)qa.dtype().scalar_type(),
ExprHandle(qb.node()->qscale()),
ExprHandle(qb.node()->qzero()),
(int64_t)qb.dtype().scalar_type(),
ExprHandle(out_qscale),
ExprHandle(out_qzero)});
return Tensor(result_buf.node(), s);
};

// The external function body is intentionally empty: this test only checks
// that the custom quantized op is pulled into a fusion group and never runs
// the generated kernel.
auto special_mul_external_func = [](int64_t bufs_num,
void** buf_data,
int64_t* buf_ranks,
int64_t* buf_dims,
int64_t* buf_strides,
int8_t* buf_dtypes,
int64_t args_num,
int64_t* extra_args) {};

torch::jit::RegisterOperators reg({Operator(
custom_op_schema_literal,
[](const Node* node) -> Operation {
return [](Stack& _stack) {
auto a = std::move(peek(_stack, 0, 2)).toTensor();
auto b = std::move(peek(_stack, 1, 2)).toTensor();
drop(_stack, 2);
auto result = a * b;
pack(_stack, std::move(result));
return 0;
};
},
c10::AliasAnalysisKind::FROM_SCHEMA)});

auto& custom_operator_set = torch::jit::tensorexpr::getQuantizationOperatorSet();
custom_operator_set.insert({custom_op_schema_literal});

auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();
te_lowering_registry.insert(
parseSchema(custom_op_schema_literal), special_mul_lowering_func);

auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
te_nnc_func_registry[external_func_name] = special_mul_external_func;

const auto graph_string = R"IR(
graph(%x1 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu),
%x2 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu)):
%2 : int = prim::Constant[value=12]()
%qz1 : int = prim::Constant[value=13]()
%qs1 : float = prim::Constant[value=0.1]()
%qz2 : int = prim::Constant[value=13]()
%qs2 : float = prim::Constant[value=0.1]()
%qza : int = prim::Constant[value=13]()
%qsa : float = prim::Constant[value=0.1]()
%q1 : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
%q2 : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
%qa : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = nnc_custom::special_mul(%q1, %q2, %qsa, %qza)
%6 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::dequantize(%qa)
return (%6))IR";

auto graph = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, graph.get());

#ifdef TORCH_ENABLE_LLVM
setTexprQuantEnabled(true);
FuseTensorExprs(graph, 1);
torch::jit::testing::FileCheck()
.check("prim::TensorExprGroup_")
->check("nnc_custom::special_mul")
->run(*graph);
#else
torch::jit::testing::FileCheck().check("nnc_custom::special_mul")->run(*graph);
#endif
}
} // namespace jit
} // namespace torch
83 changes: 83 additions & 0 deletions test/cpp/tensorexpr/test_graph_opt.cpp
@@ -316,5 +316,88 @@ TEST_F(GraphOpt, AOTGraphPrepPasses) {
testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
}

TEST_F(GraphOpt, DecomposeQuantizationAddandMul) {
#ifdef TORCH_ENABLE_LLVM
const auto graph_string = R"IR(
graph(%x : Float(8, 16, strides=[16, 1], device=cpu)):
%1 : int = prim::Constant[value=0]()
%2 : float = prim::Constant[value=0.86077326536178589]()
%3 : int = prim::Constant[value=63]()
%4 : float = prim::Constant[value=0.16712260246276855]()
%5 : int = prim::Constant[value=60]()
%6 : float = prim::Constant[value=0.08735954761505127]()
%7 : int = prim::Constant[value=13]()
%8 : Long(device=cpu) = prim::Constant[value={59}]()
%9 : Float(device=cpu) = prim::Constant[value={0.0421975}]()
%10 : QUInt8(8, 16, strides=[16, 1], device=cpu) = aten::quantize_per_tensor(%x, %9, %8, %7)
%11 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add(%10, %10, %6, %5)
%12 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add(%11, %11, %4, %3)
%mul_1.1 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::mul(%12, %12, %2, %1)
%14 : Float(8, 16, strides=[16, 1], device=cpu) = aten::dequantize(%mul_1.1)
return (%14))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get(), true);
g->lint();

TensorExprKernel kernel(g);

// quantized::add should be decomposed into
// aten::dequantize/aten::add/aten::quantize_per_tensor.
// The same applies to quantized::mul.
testing::FileCheck()
.check("aten::quantize_per_tensor")
->check("aten::dequantize")
->check("aten::add")
->check("aten::quantize_per_tensor")
->check("aten::dequantize")
->check("aten::add")
->check("aten::quantize_per_tensor")
->check("aten::dequantize")
->check("aten::mul")
->check("aten::quantize_per_tensor")
->check("aten::dequantize")
->run(*kernel.graph());
#endif
}

TEST_F(GraphOpt, DecomposeQuantizationAddRelu) {
#ifdef TORCH_ENABLE_LLVM
const auto graph_string = R"IR(
graph(%x : Float(8, 16, strides=[16, 1], device=cpu)):
%1 : float = prim::Constant[value=0.20982688665390015]()
%2 : int = prim::Constant[value=0]()
%3 : float = prim::Constant[value=0.045946117490530014]()
%4 : int = prim::Constant[value=13]()
%5 : Long(device=cpu) = prim::Constant[value={59}]()
%6 : Float(device=cpu) = prim::Constant[value={0.0421975}]()
%7 : QUInt8(8, 16, strides=[16, 1], device=cpu) = aten::quantize_per_tensor(%x, %6, %5, %4)
%8 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add_relu(%7, %7, %3, %2)
%mul_1.1 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::mul(%8, %8, %1, %2)
%10 : Float(8, 16, strides=[16, 1], device=cpu) = aten::dequantize(%mul_1.1)
return (%10))IR";
auto g = std::make_shared<Graph>();
torch::jit::parseIR(graph_string, g.get(), true);
g->lint();

TensorExprKernel kernel(g);

// quantized::add_relu should be decomposed into
// aten::dequantize/aten::add/aten::relu/aten::quantize_per_tensor.
// quantized::mul should be decomposed into
// aten::dequantize/aten::mul/aten::quantize_per_tensor.
testing::FileCheck()
.check("aten::quantize_per_tensor")
->check("aten::dequantize")
->check("aten::add")
->check("aten::relu")
->check("aten::quantize_per_tensor")
->check("aten::dequantize")
->check("aten::mul")
->check("aten::quantize_per_tensor")
->check("aten::dequantize")
->run(*kernel.graph());
#endif
}

} // namespace jit
} // namespace torch
99 changes: 99 additions & 0 deletions test/cpp/tensorexpr/test_quantization.cpp
@@ -221,6 +221,105 @@ TEST_F(Quantization, QuantAddDequantUInt8) {
CHECK_EQ(check, 1);
}

// Calls quantized::add_relu through the dispatcher to produce the eager
// reference result that the NNC kernel output is compared against below.
at::Tensor quantized_add_relu(
at::Tensor x1,
at::Tensor x2,
double scale,
int64_t zero) {
const auto qadd_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("quantized::add_relu", "")
.typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
return qadd_op.call(x1, x2, scale, zero);
}

TEST_F(Quantization, QuantAddReluDequantInt8) {
const auto graph_string = R"IR(
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=12]()
%qz1 : int = prim::Constant[value=13]()
%qs1 : float = prim::Constant[value=0.1]()
%qz2 : int = prim::Constant[value=13]()
%qs2 : float = prim::Constant[value=0.1]()
%qza : int = prim::Constant[value=13]()
%qsa : float = prim::Constant[value=0.1]()
%q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
%q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
%qa : QInt8(2, 2) = quantized::add_relu(%q1, %q2, %qsa, %qza)
%6 : Float(2, 2) = aten::dequantize(%qa)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
auto qa = quantized_add_relu(q1, q2, 0.1f, 13);
auto y_expected = at::dequantize(qa);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x1, x2};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "x1:\n" << x1 << std::endl;
std::cout << "q1:\n" << q1 << std::endl;
std::cout << "x2:\n" << x2 << std::endl;
std::cout << "q2:\n" << q2 << std::endl;
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantAddReluDequantUInt8) {
const auto graph_string = R"IR(
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
%2 : int = prim::Constant[value=13]()
%qz1 : int = prim::Constant[value=13]()
%qs1 : float = prim::Constant[value=0.1]()
%qz2 : int = prim::Constant[value=13]()
%qs2 : float = prim::Constant[value=0.1]()
%qza : int = prim::Constant[value=13]()
%qsa : float = prim::Constant[value=0.1]()
%q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
%q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
%qa : QUInt8(2, 2) = quantized::add_relu(%q1, %q2, %qsa, %qza)
%6 : Float(2, 2) = aten::dequantize(%qa)
return (%6))IR";
auto graph = std::make_shared<Graph>();
parseIR(graph_string, &*graph);

auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
auto qa = quantized_add_relu(q1, q2, 0.1f, 13);
auto y_expected = at::dequantize(qa);

TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {x1, x2};
StmtPtr s = k.getCodeGenStmt();

std::vector<IValue> stack = fmap<IValue>(inputs);
k.run(stack);
auto y = stack[0].toTensor();
bool check = at::allclose(y_expected, y);
if (!check) {
std::cout << "x1:\n" << x1 << std::endl;
std::cout << "q1:\n" << q1 << std::endl;
std::cout << "x2:\n" << x2 << std::endl;
std::cout << "q2:\n" << q2 << std::endl;
std::cout << "y_expected:\n" << y_expected << std::endl;
std::cout << "y:\n" << y << std::endl;
}
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantSigmoidDequantUInt8) {
const auto graph_string = R"IR(
graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)):
37 changes: 35 additions & 2 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -40,6 +40,7 @@ namespace torch {
namespace jit {

static bool texpr_reductions_enabled = false;
static bool texpr_quant_enabled = false;

bool isSupportedForBlock(Node* node) {
switch (node->kind()) {
@@ -75,6 +76,25 @@ OperatorSet& getCustomOperatorSet() {
return _g_custom_operator_set;
}

OperatorSet& getQuantizationOperatorSet() {
// clang-format off
static OperatorSet _g_supported_quantization_set{
"aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
"aten::dequantize.self(Tensor self) -> Tensor",
"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
"quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc",
"quantized::matmul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc",
"quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
"quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)",
"quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)",
"quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)",
"quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)",
Comment on lines +85 to +92 (Collaborator):

These operators serve the FX front-end, so another option would be to exclude them from the quantization set and enable texpr_quant_enabled by default.
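
For illustration, a rough sketch of that alternative (not part of this PR; the trimmed-down set below is only a placeholder):

// Hypothetical variant: quantization fusion on by default, with the
// packed-weight ops (quantized::conv2d*/linear*) left to the FX front-end.
static bool texpr_quant_enabled = true;

static OperatorSet _g_supported_quantization_set{
    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    "aten::dequantize.self(Tensor self) -> Tensor",
    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    "quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    // quantized::conv2d/conv2d_relu/linear/linear_relu intentionally omitted.
};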

Collaborator:

Then we can combine the IPEX front-end and NNC quantization into a mature solution and further improve IPEX INT8 out-of-the-box performance.

Collaborator:

Meanwhile, I'd recommend exposing an interface that returns this quantization operator set so it can be extended, just like getCustomOperatorSet.
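
For illustration, a minimal sketch of how an out-of-tree front-end could use such a getter, mirroring the JitCustomQuantOp test above. The my_ext::qmul schema and the my_qmul_* helpers are placeholders, and the sketch assumes getQuantizationOperatorSet()/setTexprQuantEnabled() end up declared in tensorexpr_fuser.h alongside getCustomOperatorSet():

#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>

static void registerMyQuantizedMul() {
  const char* schema =
      "my_ext::qmul(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc";

  // Let the TE fuser treat the op as a supported quantized op.
  torch::jit::tensorexpr::getQuantizationOperatorSet().insert({schema});

  // An NNC lowering that emits an ExternalCall would be registered here,
  // e.g. like special_mul_lowering_func in test_external_calls.cpp above:
  //   torch::jit::tensorexpr::getNNCLoweringRegistry().insert(
  //       torch::jit::parseSchema(schema), my_qmul_lowering);
  // plus the external kernel the generated code calls into:
  //   torch::jit::tensorexpr::getNNCFunctionRegistry()["nnc_my_qmul"] = my_qmul_kernel;

  // Finally, turn on quantization fusion in the TE fuser.
  torch::jit::setTexprQuantEnabled(true);
}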

};
// clang-format on
return _g_supported_quantization_set;
}

static const OperatorSet& supported_non_eltwise_set() {
// clang-format off
static const OperatorSet supported_non_eltwise_set{
@@ -102,12 +122,14 @@ bool isSupported(Node* node) {
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
};

// clang-format on

if (get_tensorexpr_elementwise_set().contains(node) ||
node->isMemberOf(supported_non_eltwise_set()) ||
node->isMemberOf(supported_misc_set) ||
node->isMemberOf(getCustomOperatorSet()) ||
(texpr_quant_enabled && node->isMemberOf(getQuantizationOperatorSet())) ||
(texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
// We only insert guards on Tensor types, so we rely on the output
// of a node being uniquely determined by its input types.
@@ -180,10 +202,20 @@ bool setTexprReductionsEnabled(bool value) {
return old_value;
}

bool setTexprQuantEnabled(bool value) {
bool old_value = texpr_quant_enabled;
texpr_quant_enabled = value;
return old_value;
}

bool texprReductionsEnabled() {
return texpr_reductions_enabled;
}

bool texprQuantEnabled() {
return texpr_quant_enabled;
}

void removeProfileNodesAndSpecializeTypes(Block* b) {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
if (it->kind() == prim::profile) {
@@ -489,7 +521,8 @@ class TensorExprFuser {
// non-elementwise like batch_norm, conv, matmul, and
// a few exceptions (e.g. prim::ConstantChunk, etc) listed above
if (!(get_tensorexpr_elementwise_set().contains(n)) &&
!n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
!n->isMemberOf(tensorexpr::supported_non_eltwise_set()) &&
!n->isMemberOf(tensorexpr::getQuantizationOperatorSet())) {
continue;
}

@@ -1079,7 +1112,7 @@ class TensorExprFuser {
// All tensor types should be known.
return false;
}
if (c10::isComplexType(*st) || c10::isQIntType(*st)) {
if (c10::isComplexType(*st)) {
return false;
}
}