diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp index 88b75667b654..f24b93ac71fb 100644 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ b/test/cpp/tensorexpr/test_external_calls.cpp @@ -1061,5 +1061,110 @@ TEST(ExternalCall, JitCustomFusionOp) { #endif } +/* Exercises fusion of a custom quantized op: registers nnc_custom::special_mul as a JIT operator, adds its schema to the quantization operator set, installs an NNC lowering plus an external function for it, and then runs FileCheck on a graph that uses the op. */TEST(ExternalCall, JitCustomQuantOp) { + const char* custom_op_schema_literal = + "nnc_custom::special_mul(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"; + const char* external_func_name = "nnc_special_mul"; +/* Lowering: emits an ExternalCall to external_func_name on (qa, qb). The extra-args list carries qscale, qzero and scalar type of each input, followed by the output qscale and qzero built from inputs[2] and inputs[3]. */ + auto special_mul_lowering_func = + [external_func_name]( + const std::vector& inputs, + const std::vector& output_shape, + const std::vector& output_strides, + const c10::optional& output_type, + at::Device device) { + auto output_dtype = Dtype(*output_type); + const torch::jit::tensorexpr::BufHandle& qa = + c10::get(inputs[0]); + const torch::jit::tensorexpr::BufHandle& qb = + c10::get(inputs[1]); + const auto out_qscale = DoubleImm::make(c10::get(inputs[2])).node(); + const auto out_qzero = LongImm::make(c10::get(inputs[3])).node(); + + BufHandle result_buf("nnc_special_mul_res_buf", + output_shape, + output_dtype); +/* Attach the output quantization params to the result buffer. */ result_buf.node()->set_qscale(out_qscale); + result_buf.node()->set_qzero(out_qzero); + + torch::jit::tensorexpr::StmtPtr s = + torch::jit::tensorexpr::ExternalCall::make( + result_buf, + external_func_name, + {qa, qb}, + {ExprHandle(qa.node()->qscale()), + ExprHandle(qa.node()->qzero()), + (int64_t)qa.dtype().scalar_type(), + ExprHandle(qb.node()->qscale()), + ExprHandle(qb.node()->qzero()), + (int64_t)qb.dtype().scalar_type(), + ExprHandle(out_qscale), + ExprHandle(out_qzero)}); + return Tensor(result_buf.node(), s); + }; +/* The external function registered below has an empty body, so it performs no computation. */ + auto special_mul_external_func = [](int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t args_num, + int64_t* extra_args) {}; +/* Eager implementation registered under the same schema. */ + torch::jit::RegisterOperators reg({Operator( + custom_op_schema_literal, 
[](const Node* node) -> Operation { + return [](Stack& _stack) { +/* NOTE(review): the schema declares four inputs (qa, qb, scale, zero_point) but only two stack entries are peeked and dropped here - confirm this is intended. */ auto a = std::move(peek(_stack, 0, 2)).toTensor(); + auto b = std::move(peek(_stack, 1, 2)).toTensor(); + drop(_stack, 2); + auto result = a * b; + pack(_stack, std::move(result)); + return 0; + }; + }, + c10::AliasAnalysisKind::FROM_SCHEMA)}); +/* Add the custom schema to the (mutable) quantization operator set. */ + auto& custom_operator_set = torch::jit::tensorexpr::getQuantizationOperatorSet(); + custom_operator_set.insert({custom_op_schema_literal}); +/* Register the lowering under the parsed schema. */ + auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); + te_lowering_registry.insert( + parseSchema(custom_op_schema_literal), special_mul_lowering_func); +/* Register the external function under the name the lowering passes to ExternalCall::make. */ + auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); + te_nnc_func_registry[external_func_name] = special_mul_external_func; + + const auto graph_string = R"IR( + graph(%x1 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu), + %x2 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu)): + %2 : int = prim::Constant[value=12]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = nnc_custom::special_mul(%q1, %q2, %qsa, %qza) + %6 : Float(1, 3, 224, 224, strides=[150528, 1, 672, 3], device=cpu) = aten::dequantize(%qa) + return (%6))IR"; + + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + +#ifdef TORCH_ENABLE_LLVM +/* With LLVM: enable quantization fusion, fuse, and expect a TensorExprGroup followed by the custom op. NOTE(review): the value returned by setTexprQuantEnabled is discarded, so the flag is not restored when the test ends. */ setTexprQuantEnabled(true); + FuseTensorExprs(graph, 1); + torch::jit::testing::FileCheck() 
.check("prim::TensorExprGroup_") + ->check("nnc_custom::special_mul") + ->run(*graph); +#else +/* Without LLVM no fusion is attempted; only the presence of the custom op is checked. */ torch::jit::testing::FileCheck().check("nnc_custom::special_mul")->run(*graph); +#endif +} } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp index 649bbb2d33ac..1ce19598bfb0 100644 --- a/test/cpp/tensorexpr/test_graph_opt.cpp +++ b/test/cpp/tensorexpr/test_graph_opt.cpp @@ -316,5 +316,88 @@ TEST_F(GraphOpt, AOTGraphPrepPasses) { testing::FileCheck().check("return (%x, %y, %z)")->run(*g); } +/* Builds a TensorExprKernel from a graph with two quantized::add nodes and one quantized::mul, then checks the order of ops in kernel.graph(). The body is compiled only when TORCH_ENABLE_LLVM is defined; otherwise the test is empty. */TEST_F(GraphOpt, DecomposeQuantizationAddandMul) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(8, 16, strides=[16, 1], device=cpu)): + %1 : int = prim::Constant[value=0]() + %2 : float = prim::Constant[value=0.86077326536178589]() + %3 : int = prim::Constant[value=63]() + %4 : float = prim::Constant[value=0.16712260246276855]() + %5 : int = prim::Constant[value=60]() + %6 : float = prim::Constant[value=0.08735954761505127]() + %7 : int = prim::Constant[value=13]() + %8 : Long(device=cpu) = prim::Constant[value={59}]() + %9 : Float(device=cpu) = prim::Constant[value={0.0421975}]() + %10 : QUInt8(8, 16, strides=[16, 1], device=cpu) = aten::quantize_per_tensor(%x, %9, %8, %7) + %11 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add(%10, %10, %6, %5) + %12 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add(%11, %11, %4, %3) + %mul_1.1 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::mul(%12, %12, %2, %1) + %14 : Float(8, 16, strides=[16, 1], device=cpu) = aten::dequantize(%mul_1.1) + return (%14))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get(), true); + g->lint(); + + TensorExprKernel kernel(g); + + // Quantized::add should be decomposed to be + // aten::dequantize/aten::add/aten::quantize_per_tensor + // Similar for quantized::mul + testing::FileCheck() + .check("aten::quantize_per_tensor") 
->check("aten::dequantize") + ->check("aten::add") + ->check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->check("aten::add") + ->check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->check("aten::mul") + ->check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->run(*kernel.graph()); +#endif +} +/* Same shape of test for quantized::add_relu followed by quantized::mul: an aten::relu is expected between aten::add and the re-quantize. Compiled only when TORCH_ENABLE_LLVM is defined. */ +TEST_F(GraphOpt, DecomposeQuantizationAddRelu) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(8, 16, strides=[16, 1], device=cpu)): + %1 : float = prim::Constant[value=0.20982688665390015]() + %2 : int = prim::Constant[value=0]() + %3 : float = prim::Constant[value=0.045946117490530014]() + %4 : int = prim::Constant[value=13]() + %5 : Long(device=cpu) = prim::Constant[value={59}]() + %6 : Float(device=cpu) = prim::Constant[value={0.0421975}]() + %7 : QUInt8(8, 16, strides=[16, 1], device=cpu) = aten::quantize_per_tensor(%x, %6, %5, %4) + %8 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::add_relu(%7, %7, %3, %2) + %mul_1.1 : QUInt8(8, 16, strides=[16, 1], device=cpu) = quantized::mul(%8, %8, %1, %2) + %10 : Float(8, 16, strides=[16, 1], device=cpu) = aten::dequantize(%mul_1.1) + return (%10))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get(), true); + g->lint(); + + TensorExprKernel kernel(g); + + // Quantized::add_relu should be decomposed to be + // aten::dequantize/aten::add/aten::relu/aten::quantize_per_tensor + // Quantized::mul should be decomposed to be + // aten::dequantize/aten::mul/aten::quantize_per_tensor + testing::FileCheck() + .check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->check("aten::add") + ->check("aten::relu") + ->check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->check("aten::mul") + ->check("aten::quantize_per_tensor") + ->check("aten::dequantize") + ->run(*kernel.graph()); +#endif +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_quantization.cpp 
b/test/cpp/tensorexpr/test_quantization.cpp index a6049c5ac0b6..598a7c0f1644 100644 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ b/test/cpp/tensorexpr/test_quantization.cpp @@ -221,6 +221,105 @@ TEST_F(Quantization, QuantAddDequantUInt8) { CHECK_EQ(check, 1); } +/* Eager reference helper: looks up the "quantized::add_relu" schema in the dispatcher (throws if it is missing) and calls it with (x1, x2, scale, zero). */at::Tensor quantized_add_relu( + at::Tensor x1, + at::Tensor x2, + double scale, + int64_t zero) { + const auto qadd_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::add_relu", "") + .typed(); + return qadd_op.call(x1, x2, scale, zero); +} +/* QInt8 variant: runs quantize_per_tensor -> quantized::add_relu -> dequantize through a TensorExprKernel and compares the result with the eager computation via at::allclose. On mismatch the tensors are printed before CHECK_EQ fails. */ +TEST_F(Quantization, QuantAddReluDequantInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=12]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QInt8(2, 2) = quantized::add_relu(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + +/* Eager reference path; scale 0.1 and zero point 13 mirror the graph constants. */ auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); + auto qa = quantized_add_relu(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; +/* NOTE(review): s is not referenced again in this test. */ StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + 
std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + CHECK_EQ(check, 1); +} +/* QUInt8 variant of the test above: same flow, with at::kQUInt8 for the eager quantization and QUInt8 tensor types in the graph. */ +TEST_F(Quantization, QuantAddReluDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(2, 2) = quantized::add_relu(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); + auto qa = quantized_add_relu(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; +/* NOTE(review): s is not referenced again in this test. */ StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + CHECK_EQ(check, 1); +} + 
TEST_F(Quantization, QuantSigmoidDequantUInt8) { const auto graph_string = R"IR( graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index fe63de1a04d4..9e1b5368a50f 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -40,6 +40,7 @@ namespace torch { namespace jit { static bool texpr_reductions_enabled = false; +/* Gate for fusing quantization ops; false by default and changed only through setTexprQuantEnabled. */static bool texpr_quant_enabled = false; bool isSupportedForBlock(Node* node) { switch (node->kind()) { @@ -75,6 +76,25 @@ OperatorSet& getCustomOperatorSet() { return _g_custom_operator_set; } +/* Returns a mutable reference to a function-local static set of quantization operator schemas; callers may insert additional schemas into it. */OperatorSet& getQuantizationOperatorSet() { + // clang-format off + static OperatorSet _g_supported_quantization_set{ + "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", + "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", + "aten::dequantize.self(Tensor self) -> Tensor", + "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", + "quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc", + "quantized::matmul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc", + "quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", + "quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)", + "quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)", + "quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)", + "quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float 
Y_scale_i, int Y_zero_point_i) -> (Tensor Y)", + }; + // clang-format on + return _g_supported_quantization_set; +}; + static const OperatorSet& supported_non_eltwise_set() { // clang-format off static const OperatorSet supported_non_eltwise_set{ @@ -102,12 +122,14 @@ bool isSupported(Node* node) { "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", }; + // clang-format on if (get_tensorexpr_elementwise_set().contains(node) || node->isMemberOf(supported_non_eltwise_set()) || node->isMemberOf(supported_misc_set) || node->isMemberOf(getCustomOperatorSet()) || +/* Quantization ops count as supported only while the quantization flag is set. */ (texpr_quant_enabled && node->isMemberOf(getQuantizationOperatorSet())) || (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) { // We only insert guards on Tensor types, so we rely on the output // of a node being uniquely determined by its input types. @@ -180,10 +202,20 @@ bool setTexprReductionsEnabled(bool value) { return old_value; } +/* Sets the quantization-fusion flag and returns the value it had before the call. */bool setTexprQuantEnabled(bool value) { + bool old_value = texpr_quant_enabled; + texpr_quant_enabled = value; + return old_value; +} + bool texprReductionsEnabled() { return texpr_reductions_enabled; } +/* Returns the current value of the quantization-fusion flag. */bool texprQuantEnabled() { + return texpr_quant_enabled; +} + void removeProfileNodesAndSpecializeTypes(Block* b) { for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { if (it->kind() == prim::profile) { @@ -489,7 +521,8 @@ class TensorExprFuser { // non-elementwise like batch_norm, conv, matmul, and // a few exceptions (e.g. prim::ConstantChunk, etc) listed above if (!(get_tensorexpr_elementwise_set().contains(n)) && - !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) { + !n->isMemberOf(tensorexpr::supported_non_eltwise_set()) && + !n->isMemberOf(tensorexpr::getQuantizationOperatorSet())) { continue; } @@ -1079,7 +1112,7 @@ class TensorExprFuser { // All tensor types should be known. 
return false; } - if (c10::isComplexType(*st) || c10::isQIntType(*st)) { +/* The isQIntType rejection is dropped here, so only complex types are refused at this point. */ if (c10::isComplexType(*st)) { return false; } } diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h index fc800dde0de7..5ce37ba86947 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.h +++ b/torch/csrc/jit/passes/tensorexpr_fuser.h @@ -25,6 +25,8 @@ TORCH_API void setTensorExprDynamicShapeFusionEnabled(bool val); TORCH_API bool tensorExprDynamicShapeFusionEnabled(); TORCH_API bool setTexprReductionsEnabled(bool value); TORCH_API bool texprReductionsEnabled(); +/* Setter returns the previous value of the flag; the getter returns the current one. */TORCH_API bool setTexprQuantEnabled(bool value); +TORCH_API bool texprQuantEnabled(); TORCH_API void RemoveProfileNodesAndSpecializeTypes( std::shared_ptr& graph); @@ -72,6 +74,11 @@ TORCH_API bool isSupported(Node* node); /// @return Reference of the custome operator set /// TORCH_API OperatorSet& getCustomOperatorSet(); + +/// Get the modifiable quantization operator set object. +/// @return Reference of the quantization operator set +/// +TORCH_API OperatorSet& getQuantizationOperatorSet(); } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 7674037a0cfb..8cc19997ad64 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -979,6 +979,8 @@ void initJITBindings(PyObject* module) { "_jit_texpr_dynamic_shape_enabled", &tensorExprDynamicShapeFusionEnabled) .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled) +/* Python bindings for the quantization-fusion flag setter and getter. */ .def("_jit_set_texpr_quantization_enabled", &setTexprQuantEnabled) + .def("_jit_texpr_quantization_enabled", &texprQuantEnabled) .def( "_jit_set_te_generate_block_code", [](bool gen_block_code) { diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 4b94b2e4b7f6..2cbe00b7edc6 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -91,6 +91,19 @@ class TORCH_API CodeGen { 
size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); } +/* Allocates a quantized output by forwarding to at::native::empty_affine_quantized. Despite the name there is no stride parameter; layout is expressed through memory_format_opt. */ virtual at::Tensor empty_strided_quantized( + c10::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + double scale, + int64_t zero_point, + c10::optional memory_format_opt) { + return at::native::empty_affine_quantized( + size, dtype_opt, layout_opt, device_opt, pin_memory_opt, scale, zero_point, memory_format_opt); + }; + const std::string& kernel_func_name() const { return kernel_func_name_; } diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index ee18ec7b5d1f..d7c4200c6899 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -5,6 +5,10 @@ #include #include #include +#include +#include +#include +#include namespace torch { namespace jit { @@ -181,6 +185,228 @@ bool OptimizeCat(const std::shared_ptr& graph) { return false; } +/* Creates an aten::dequantize node on quantized_val, types its output as the input tensor type with scalar type Float, names it "<input>.dequant", and inserts it with graph->insertNode. Asserts that the input is a tensor type. */Node* insertDequant(Graph* graph, Value* quantized_val) { + Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val}); + auto tt = quantized_val->type()->cast(); + TORCH_INTERNAL_ASSERT( + tt != nullptr, + buildErrorMessage("Value to be dequanted should be a tensortype.")); + auto output_type = tt->withScalarType(at::ScalarType::Float); + dequant->output() + ->setDebugName(quantized_val->debugName() + ".dequant") + ->setType(output_type); + + return graph->insertNode(dequant); +} +/* Creates an aten::quantize_per_tensor node from inputs (inputs[0] is the fp32 value), names the output "<input>.quant", and inserts it with graph->insertNode. */ +Node* insertQuant(Graph* graph, std::vector & inputs) { + // TODO: also to deal with quantize_per_channel + Node* quant = graph->create(Symbol::aten("quantize_per_tensor"), inputs); + + Value * fp32_value = inputs.at(0); + + // TODO: to select QUInt8 or QInt8 +/* NOTE(review): unlike insertDequant, the result of cast() is dereferenced without a null check, and the output type is always QUInt8. */ auto output_type = fp32_value->type()->cast()->withScalarType(at::ScalarType::QUInt8); + quant->output() + ->setDebugName(fp32_value->debugName() + ".quant") + ->setType(output_type); + + return graph->insertNode(quant); +} +/* For every match of pattern_graph in graph: dequantizes the QInt8/QUInt8 tensor inputs, inserts target_graph on them (followed by post_op_graph when it is set), re-quantizes each tensor result using the pattern's last two inputs plus a dtype constant, then rewires the uses of the old outputs and destroys the matched nodes. */ +void 
DecomposeQuantizedOp(std::shared_ptr graph, + Graph& pattern_graph, + Graph& target_graph, + c10::optional& post_op_graph) { + std::vector values_to_rewrite; + std::unordered_map rewrite_map; + std::unordered_set nodes_to_delete_; + + // Locate the pattern graph and replace with target graph + const auto& matches = findPatternMatches(pattern_graph, *graph); + for (const Match& match : matches) { + // Collect original inputs and insertion point for target graph + Node* ins_point = nullptr; + Value* quant_dtype = nullptr; + std::vector inputs, outputs; + std::vector qtensor_pos_list; + for (const auto idx : c10::irange(pattern_graph.inputs().size())) { + Value* input = match.values_map.at(pattern_graph.inputs().at(idx)); + inputs.push_back(input); + if (!ins_point || ins_point->isBefore(input->node())) { + ins_point = input->node(); + } + c10::TensorTypePtr inputType = input->type()->cast(); + if (inputType) { + auto scalarType = inputType->scalarType(); + if (scalarType == at::ScalarType::QInt8 || scalarType == at::ScalarType::QUInt8) { + qtensor_pos_list.push_back(idx); + } + } + } + AT_ASSERT(ins_point); + + // Prepare the outputs for target graph + for (Value* v : pattern_graph.outputs()) { + Value* output = match.values_map.at(v); + outputs.push_back(output); + if (!quant_dtype) { + c10::TensorTypePtr outputType = output->type()->cast(); + if (outputType) { +/* NOTE(review): scalarType().value() is taken unconditionally - confirm matched outputs always carry a scalar type. */ quant_dtype = graph->insertConstant(IValue(static_cast(outputType->scalarType().value()))); + } + } + } + +/* Matches without any tensor-typed output are skipped. */ if (!quant_dtype) { + continue; + } + + // Insert dequant nodes and replace the outputs in the original inputs + ins_point = ins_point->next(); + WithInsertPoint insert_point(ins_point); + for (const auto idx : c10::irange(qtensor_pos_list.size())) { + size_t qtensor_pos = qtensor_pos_list.at(idx); + ins_point = insertDequant(graph.get(), inputs.at(qtensor_pos)); + inputs.at(qtensor_pos) = ins_point->output(); + } + + // Prepare the inputs for quantization OP (partially) and FP32 OP + // TODO: by 
default -2 (scale and zero_point), while to remap if map existing + std::vector quant_inputs; +/* Slot 0 is a placeholder that is filled with each fp32 result right before its quant node is inserted. */ quant_inputs.push_back(nullptr); + quant_inputs.push_back(inputs.at(inputs.size()-2)); + quant_inputs.push_back(inputs.at(inputs.size()-1)); + quant_inputs.push_back(quant_dtype); +/* Drop the last two inputs so that only the operands are fed to target_graph. */ inputs.pop_back(); + inputs.pop_back(); + + // Insert the according FP32 OP + std::vector new_outputs = insertGraph(*graph, target_graph, inputs); + AT_ASSERT(outputs.size() == new_outputs.size()); + + // Insert post-op nodes + if (post_op_graph != c10::nullopt) { + for (const auto idx : c10::irange(new_outputs.size())) { + Value* v = new_outputs.at(idx); + Value* v_ori = outputs.at(idx); + if (v->type()->cast()) { + v->setType(v_ori->type()->cast()->withScalarType(at::ScalarType::Float)); + std::vector post_op_outputs = insertGraph(*graph, **post_op_graph, {v}); + new_outputs.at(idx) = post_op_outputs[0]; + } + } + } + + // Insert quant nodes + for (const auto idx : c10::irange(new_outputs.size())) { + Value* v = new_outputs.at(idx); + Value* v_ori = outputs.at(idx); + if (v->type()->cast()) { + v->setType(v_ori->type()->cast()->withScalarType(at::ScalarType::Float)); + quant_inputs.at(0) = v; + ins_point = insertQuant(graph.get(), quant_inputs); + new_outputs.at(idx) = ins_point->output(); + } + } + + // Record all planned rewritings + AT_ASSERT(outputs.size() == new_outputs.size()); + for (const auto idx : c10::irange(outputs.size())) { + values_to_rewrite.push_back(outputs[idx]); +/* The replacement value takes over the type of the output it replaces. */ rewrite_map[outputs[idx]] = + new_outputs[idx]->setType(outputs[idx]->type()); + } + // Record all planned deletions + for (Node* pattern_n : pattern_graph.nodes()) { + if (match.nodes_map.count(pattern_n)) { + Node* n = match.nodes_map.at(pattern_n); + nodes_to_delete_.insert(n); + } + } + } + + // Perform planned rewritings + for (auto v : values_to_rewrite) { + v->replaceAllUsesWith(rewrite_map.at(v)); + } + + // Perform planned deletions +/* Two passes: first detach the inputs of every node, then destroy the nodes. */ for (auto n : nodes_to_delete_) { + n->removeAllInputs(); + } + for 
(auto n : nodes_to_delete_) { + n->destroy(); + } + nodes_to_delete_.clear(); +} +/* Runs DecomposeQuantizedOp once per entry of op_list: quantized::add, quantized::mul and quantized::add_relu. post_op_list is indexed in parallel with op_list; only the add_relu entry has a post-op (aten::relu). */ +void DecomposeQuantizedOps(std::shared_ptr graph) { + std::vector> op_list; + // Only support unary OPs with shape not changed + std::vector> post_op_list; + + op_list.push_back({ + R"(graph(%a_quant, %b_quant, %scale, %zero_point): + %add_out = quantized::add(%a_quant, %b_quant, %scale, %zero_point) + return (%add_out) )", + R"(graph(%a, %b): + %1 : int = prim::Constant[value=1]() + %add_out = aten::add(%a, %b, %1) + return (%add_out) )" + }); + post_op_list.push_back(c10::nullopt); + + op_list.push_back({ + R"(graph(%a_quant, %b_quant, %scale, %zero_point): + %mul_out = quantized::mul(%a_quant, %b_quant, %scale, %zero_point) + return (%mul_out) )", + R"(graph(%a, %b): + %mul_out = aten::mul(%a, %b) + return (%mul_out) )" + }); + post_op_list.push_back(c10::nullopt); + + op_list.push_back({ + R"(graph(%a_quant, %b_quant, %scale, %zero_point): + %add_out = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point) + return (%add_out) )", + R"(graph(%a, %b): + %1 : int = prim::Constant[value=1]() + %add_out = aten::add(%a, %b, %1) + return (%add_out) )" + }); + post_op_list.push_back(c10::make_optional(std::string( + R"(graph(%a): + %relu_out = aten::relu(%a) + return (%relu_out) )") + )); + +/* Pattern, target and optional post-op graphs are parsed afresh on each iteration. */ for (const auto idx : c10::irange(op_list.size())) { + Graph pattern_graph; + parseIR(op_list[idx].first, &pattern_graph); + Graph target_graph; + parseIR(op_list[idx].second, &target_graph); + c10::optional post_op_graph; + Graph post_op_g; + if (post_op_list[idx] != c10::nullopt) { + parseIR(post_op_list[idx].value(), &post_op_g); + post_op_graph.emplace(&post_op_g); + } + DecomposeQuantizedOp(graph, pattern_graph, target_graph, post_op_graph); + } +} +/* Decomposes quantized ops, then runs common-subexpression elimination and constant pooling, dumping the graph after each step. Always returns true. */ +bool DecomposeOps(const std::shared_ptr& graph) { + DecomposeQuantizedOps(graph); + GRAPH_DUMP("After DecomposeQuantizedOps:", graph); + EliminateCommonSubexpression(graph); + GRAPH_DUMP("After EliminateCommonSubexpression:", graph); + 
ConstantPooling(graph); + GRAPH_DUMP("After ConstantPooling:", graph); + return true; +} + void annotateInputShapes( const std::shared_ptr& graph, const std::vector>& example_inputs) { diff --git a/torch/csrc/jit/tensorexpr/graph_opt.h b/torch/csrc/jit/tensorexpr/graph_opt.h index 1180d0ac438b..cec1ee4a6daf 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.h +++ b/torch/csrc/jit/tensorexpr/graph_opt.h @@ -57,6 +57,7 @@ namespace tensorexpr { // aten_cat is the output buffer here. bool OptimizeCat(const std::shared_ptr& graph); +/* Declared without TORCH_API, unlike annotateInputShapes below. */bool DecomposeOps(const std::shared_ptr& graph); TORCH_API void annotateInputShapes( const std::shared_ptr& graph, diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 9c22eb2db503..920c789a4998 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1034,6 +1035,19 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { } }; +/* For QInt8/QUInt8 inputs only: records bufferArgs_.size() (the slot the scale arg is about to take) in qtensorInputIndex_ under "<name>_scale", appends two extra kernel args ("<name>_scale" as double, "<name>_zero" as long), and attaches both vars to the buffer as its qscale/qzero. */ auto setQinfoAsNeeded = [&](c10::TensorTypePtr tt, BufHandle inBuffer) { + auto scalarType = tt->scalarType(); + if (scalarType == at::ScalarType::QInt8 || scalarType == at::ScalarType::QUInt8) { + qtensorInputIndex_.emplace(input_name_map_[input] + "_scale", bufferArgs_.size()); + VarHandle scale(input_name_map_[input] + "_scale", kDouble); + bufferArgs_.emplace_back(scale); + VarHandle zero_point(input_name_map_[input] + "_zero", kLong); + bufferArgs_.emplace_back(zero_point); + inBuffer.node()->set_qscale(scale.node()); + inBuffer.node()->set_qzero(zero_point.node()); + } + }; + Tensor result(nullptr, nullptr); switch (t->kind()) { case TypeKind::TensorType: { @@ -1069,6 +1083,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d)); bufs_.emplace(input, inBuffer.node()); bufferArgs_.emplace_back(inBuffer); +/* Called after the data buffer arg is appended, so scale/zero follow it in bufferArgs_. */ setQinfoAsNeeded(tt, inBuffer); break; } @@ 
-1089,7 +1104,7 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { "t" + input_name_map_[input], {flat_size}, ToDtype(static_cast(*tt->scalarType()))); - +/* Same quantization-arg binding for this second input path in bindInput. */ setQinfoAsNeeded(tt, inBuffer); result = Compute( "input" + c10::to_string(bufs_.size() + 1), size_handles, @@ -1295,7 +1310,7 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides( // TODO: call into `convertOutputToCorrectStrides`. Currently this causes a // bug in IRSimplifier to occur. See explanation in // `convertOutputToCorrectStrides` - return Compute( +/* The Compute result is now held in a local so quantization params can be copied onto its buffer before returning. */ Tensor outputTensor = Compute( "output_1", dims, [&](const std::vector& axes_input) { std::vector axes(axes_input.begin(), axes_input.end()); auto absolute_position = ExprHandle(immLike(axes[0], 0)); @@ -1319,6 +1334,11 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides( } return BufHandle(buf).load(new_axes); }); +/* Copy qscale/qzero from the source buffer to the restrided output buffer. */ + auto outputBuf = outputTensor.buf(); + outputBuf->set_qscale(buf->qscale()); + outputBuf->set_qzero(buf->qzero()); + return outputTensor; } void TensorExprKernel::bindConstant(const torch::jit::Value* v) { @@ -1566,6 +1586,10 @@ void TensorExprKernel::optimizeOwningGraph() { // Optimize the concatenation OptimizeCat(graph_); +/* Decomposition runs right after OptimizeCat, with graph dumps before and after. */ GRAPH_DUMP("Before DecomposeOps:", graph_); + DecomposeOps(graph_); + GRAPH_DUMP("After DecomposeOps:", graph_); + // Synchronize the symbolic strides information auto graph_outputs = graph_->outputs(); TORCH_INTERNAL_ASSERT(graph_outputs.size() == _orignal_graph_outputs.size()); @@ -1654,14 +1678,18 @@ void TensorExprKernel::compile() { if (!bufs_.count(output)) { throw malformed_input("cannot find output Tensor"); } +/* Cache the output buffer; it is reassigned further down when a restrided output replaces it. */ BufPtr outputBufPtr = bufs_.at(output); if (!output->type()->cast()) { // Scalar outputs are represented as 0-dim buffers. 
- bufOutputs_.insert(bufs_.at(output)); - bufferArgs_.emplace_back(BufHandle(bufs_.at(output))); + bufOutputs_.insert(outputBufPtr); + bufferArgs_.emplace_back(BufHandle(outputBufPtr)); tensorOutputTensorOptions_.emplace_back( - c10::TensorOptions(tensorType(bufs_.at(output))).device(device_)); + c10::TensorOptions(tensorType(outputBufPtr)).device(device_)); tensorOutputSizes_.emplace_back(); tensorOutputStrides_.emplace_back(); +/* Scalar outputs get placeholder entries so the per-output vectors stay index-aligned. */ tensorOutputQscales_.emplace_back(nullptr); + tensorOutputQzeros_.emplace_back(nullptr); + isOutputQuantized_.push_back(false); isOutputScalar_.push_back(true); bufs_.erase(output); continue; @@ -1678,21 +1706,23 @@ void TensorExprKernel::compile() { tensorOutputStrideDesc_.push_back(stride_desc); Tensor properly_strided_output = convertSymbolicOutputToCorrectStrides(output); + outputBufPtr = properly_strided_output.buf(); if (properly_strided_output.stmt()) { block->append_stmt(properly_strided_output.stmt()); } - bufs_[output] = properly_strided_output.buf(); + bufs_[output] = outputBufPtr; } else { // The "strided" tensor will be incorrect if used in NNC, // since NNC views it as contiguous. 
Only convert it to the right // strides at the end of the kernel (if already contiguous it's a no-op) Tensor properly_strided_output = convertStaticShapeOutputToCorrectStrides(output); + outputBufPtr = properly_strided_output.buf(); if (properly_strided_output.stmt()) { block->append_stmt(properly_strided_output.stmt()); } // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - bufs_[output] = properly_strided_output.buf(); + bufs_[output] = outputBufPtr; auto sizes = *tt->sizes().concrete_sizes(); tensorOutputSizes_.push_back(sizes); auto strides = tt->strides().concrete_sizes(); @@ -1706,10 +1736,24 @@ void TensorExprKernel::compile() { } } - bufOutputs_.insert(bufs_.at(output)); - bufferArgs_.emplace_back(BufHandle(bufs_.at(output))); +/* An output is treated as quantized exactly when its buffer carries a qscale expression. */ if (outputBufPtr->qscale()) { + tensorOutputQscales_.push_back(outputBufPtr->qscale()); + tensorOutputQzeros_.push_back(outputBufPtr->qzero()); + isOutputQuantized_.push_back(true); + } else { + tensorOutputQscales_.emplace_back(nullptr); + tensorOutputQzeros_.emplace_back(nullptr); + isOutputQuantized_.push_back(false); + } + +/* Only ChannelsLast is distinguished here; every other layout is recorded as Contiguous. */ auto memory_format = outputBufPtr->is_contiguous(at::MemoryFormat::ChannelsLast) + ? at::MemoryFormat::ChannelsLast + : at::MemoryFormat::Contiguous; + bufOutputs_.insert(outputBufPtr); + bufferArgs_.emplace_back(BufHandle(outputBufPtr)); tensorOutputTensorOptions_.emplace_back( - c10::TensorOptions(tensorType(bufs_.at(output))).device(device_)); + c10::TensorOptions(tensorType(outputBufPtr)).device(device_) + .memory_format(c10::make_optional(memory_format))); isOutputScalar_.push_back(false); bufs_.erase(output); } @@ -1843,6 +1887,9 @@ std::vector TensorExprKernel::prepareRunArgs( // TODO: preallocate `runArgs` during compilation and fill in values where // possible (e.g. 
for constant tensors) std::vector runArgs; +/* Placeholder quantization params used only to allocate quantized outputs; runKernel installs the real quantizer after the kernel has run. */ const double dummyScale = 0.0f; + const int64_t dummyZero = 1l; + runArgs.reserve( inputs.size() + input_stride_args_.size() + bufOutputs_.size()); @@ -1854,7 +1901,15 @@ std::vector TensorExprKernel::prepareRunArgs( } else if (input.isDouble()) { runArgs.emplace_back(input.toDouble()); } else if (input.isTensor()) { - runArgs.emplace_back(input.toTensor().data_ptr()); + const at::Tensor & inputTensor = input.toTensor(); + runArgs.emplace_back(inputTensor.data_ptr()); +/* A quantized input contributes two more run args (scale, zero_point) right after its data pointer; only per-tensor-affine quantizers are accepted. */ if (inputTensor.is_quantized()) { + at::QuantizerPtr quantizer = inputTensor.quantizer(); + TORCH_INTERNAL_ASSERT(quantizer->qscheme() == c10::QScheme::PER_TENSOR_AFFINE, + "Only per_tensor_affine qtensor is supported in NNC currently."); + runArgs.emplace_back(static_cast(quantizer.get())->scale()); + runArgs.emplace_back(static_cast(quantizer.get())->zero_point()); + } } } @@ -1872,25 +1927,46 @@ std::vector TensorExprKernel::prepareRunArgs( for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) { auto const& opts = tensorOutputTensorOptions_[i]; - outputs.emplace_back(codegen_->empty_strided( - static_sizes[i], - static_strides[i], - opts.dtype, - opts.layout, - opts.device, - opts.pinned_memory)); +/* Quantized outputs are allocated from sizes and memory format (no strides); all others keep the strided allocation. */ outputs.emplace_back(isOutputQuantized_[i] + ? codegen_->empty_strided_quantized( + static_sizes[i], + opts.dtype, + opts.layout, + opts.device, + opts.pinned_memory, + dummyScale, + dummyZero, + opts.memory_format) + : codegen_->empty_strided( + static_sizes[i], + static_strides[i], + opts.dtype, + opts.layout, + opts.device, + opts.pinned_memory)); + runArgs.emplace_back(outputs.back().data_ptr()); } } else { for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) { auto const& opts = tensorOutputTensorOptions_[i]; - outputs.emplace_back(codegen_->empty_strided( - tensorOutputSizes_[i], - tensorOutputStrides_[i], - opts.dtype, - opts.layout, - opts.device, - opts.pinned_memory)); + outputs.emplace_back(isOutputQuantized_[i] + ? 
codegen_->empty_strided_quantized( + tensorOutputSizes_[i], + opts.dtype, + opts.layout, + opts.device, + opts.pinned_memory, + dummyScale, + dummyZero, + opts.memory_format) + : codegen_->empty_strided( + tensorOutputSizes_[i], + tensorOutputStrides_[i], + opts.dtype, + opts.layout, + opts.device, + opts.pinned_memory)); runArgs.emplace_back(outputs.back().data_ptr()); } } @@ -1910,6 +1986,8 @@ void TensorExprKernel::runKernel(Stack& stack) const { // Set up arguments (inputs, then outputs) for kernel call. auto inputs = last(stack, nInputs_); std::vector outputs; +/* Scratch values, overwritten for each quantized output in the loop below. */ double qscale = 0.0f; + int64_t qzero = 1l; std::vector runArgs = prepareRunArgs(inputs, outputs); @@ -1919,6 +1997,29 @@ void TensorExprKernel::runKernel(Stack& stack) const { // Update the stack. drop(stack, nInputs_); + // Reset qzero/qscale for quantized output as we may only know after run + // type I: the quant info is from graph constants originally + // type II: the quant info is from graph input originally + for (size_t i = 0, size = outputs.size(); i < size; ++i) { + if (isOutputQuantized_[i]) { + auto outputTensor = static_cast(outputs[i].unsafeGetTensorImpl()); + if (tensorOutputQscales_[i]->isConstant()) { + qscale = to(tensorOutputQscales_[i])->value(); + qzero = to(tensorOutputQzeros_[i])->value(); + } else { + VarPtr outputVar = to(tensorOutputQscales_[i]); + TORCH_INTERNAL_ASSERT(outputVar != nullptr, "Expect a VarPtr but does not get it"); +/* NOTE(review): the index recorded in bindInput is read back with +1/+2 offsets into runArgs - confirm this matches how bufferArgs_ and runArgs are laid out. */ size_t idx = qtensorInputIndex_.at(outputVar->name_hint()); + qscale = *runArgs[idx+1].DoublePtr(); + qzero = *runArgs[idx+2].LongPtr(); + } +/* Replace the placeholder quantizer while keeping the scalar type the output was allocated with. */ auto quantizer = at::make_per_tensor_affine_quantizer(qscale, + qzero, + outputTensor->quantizer()->scalar_type()); + outputTensor->set_quantizer_(quantizer); + } + } + int64_t idx = 0; for (auto& o : outputs) { if (isOutputScalar_[idx++]) { diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index c2bbc20b305e..ccd886fde945 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h 
+++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -275,12 +275,14 @@ class TORCH_API TensorExprKernel { c10::optional layout; c10::optional device; c10::optional pinned_memory; + c10::optional memory_format; UnpackedTensorOptions(const c10::TensorOptions& opts) : dtype(optTypeMetaToScalarType(opts.dtype_opt())), layout(opts.layout_opt()), device(opts.device_opt()), - pinned_memory(opts.pinned_memory_opt()) {} + pinned_memory(opts.pinned_memory_opt()), + memory_format(opts.memory_format_opt()) {} }; ExprHandle getVarForShape(const c10::ShapeSymbol& ss); @@ -305,8 +307,12 @@ class TORCH_API TensorExprKernel { std::vector bufferArgs_; std::vector> tensorOutputSizes_; std::vector> tensorOutputStrides_; + std::vector tensorOutputQscales_; + std::vector tensorOutputQzeros_; + std::unordered_map qtensorInputIndex_; std::vector tensorOutputStrideDesc_; std::vector isOutputScalar_; + std::vector isOutputQuantized_; std::vector tensorOutputTensorOptions_; std::unordered_set bufOutputs_; std::unordered_map bufs_; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 9c84137d330d..3330b6c45586 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -404,6 +404,19 @@ at::Tensor LLVMCodeGen::empty_strided( size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); } +at::Tensor LLVMCodeGen::empty_strided_quantized( + c10::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + double scale, + int64_t zero_point, + c10::optional memory_format_opt) { + return at::native::empty_affine_quantized( + size, dtype_opt, layout_opt, device_opt, pin_memory_opt, scale, zero_point, memory_format_opt); +} + void* LLVMCodeGen::getKernelAddress(LLVMCodeGenCallee* callee) { return (void*)callee->getKernelAddress(); } diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h 
index 7ab506fa8fe1..8d334b986d6a 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.h +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -53,6 +53,16 @@ class TORCH_API LLVMCodeGen : public CodeGen { c10::optional device_opt, c10::optional pin_memory_opt) override; +at::Tensor empty_strided_quantized( + c10::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt, + double scale, + int64_t zero_point, + c10::optional memory_format_opt) override; + template T value() { return value(nullptr); diff --git a/torch/csrc/jit/tensorexpr/operators/quantization.cpp b/torch/csrc/jit/tensorexpr/operators/quantization.cpp index c5d9c00523f0..6e0c7e7045c9 100644 --- a/torch/csrc/jit/tensorexpr/operators/quantization.cpp +++ b/torch/csrc/jit/tensorexpr/operators/quantization.cpp @@ -169,7 +169,8 @@ Tensor computeQuantizePerTensor( ExprHandleVectorToExprVector(outputShape), dtype, nullptr, - c10::nullopt, + isChannelsLast(c10::get(inputs[0])) ? 
make_channels_last_strides(outputShape) + : make_contiguous_strides(outputShape), qscale.node(), qzero.node()); return Tensor(buf, vars, e.node()); @@ -319,8 +320,8 @@ Tensor computeQuantizedConv2dPrepack( // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) dilation[1], groups, - immQScale(qw), - immQZero(qw), + ExprHandle(qw.node()->qscale()), + ExprHandle(qw.node()->qzero()), (int64_t)immQDType(qw)}); return Tensor(ResultBuf.node(), s); } @@ -349,8 +350,8 @@ Tensor computeQuantizedConv1d( ResultBuf, "nnc_aten_quantized_conv1d", {qx, prepacked}, - {immQScale(qx), - immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero}); @@ -381,8 +382,8 @@ Tensor computeQuantizedConv2d( ResultBuf, "nnc_aten_quantized_conv2d", {qx, prepacked}, - {immQScale(qx), - immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero}); @@ -413,8 +414,8 @@ Tensor computeQuantizedConv2dRelu( ResultBuf, "nnc_aten_quantized_conv2d_relu", {qx, prepacked}, - {immQScale(qx), - immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero}); @@ -445,8 +446,8 @@ Tensor computeQuantizedLinear( ResultBuf, "nnc_aten_quantized_linear", {qx, prepacked}, - {immQScale(qx), - immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero}); @@ -477,8 +478,8 @@ Tensor computeQuantizedLinearRelu( ResultBuf, "nnc_aten_quantized_linear_relu", {qx, prepacked}, - {immQScale(qx), - immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero}); @@ -518,11 +519,11 @@ Tensor computeQuantizedAddExternalCall( ResultBuf, "nnc_aten_quantized_add", {qa, qb}, - {immQScale(qa), - immQZero(qa), + {ExprHandle(qa.node()->qscale()), + ExprHandle(qa.node()->qzero()), 
(int64_t)immQDType(qa), - immQScale(qb), - immQZero(qb), + ExprHandle(qb.node()->qscale()), + ExprHandle(qb.node()->qzero()), (int64_t)immQDType(qb), out_qscale, out_qzero}); @@ -549,11 +550,11 @@ Tensor computeQuantizedMul( ResultBuf, "nnc_aten_quantized_mul", {qa, qb}, - {immQScale(qa), - immQZero(qa), + {ExprHandle(qa.node()->qscale()), + ExprHandle(qa.node()->qzero()), (int64_t)immQDType(qa), - immQScale(qb), - immQZero(qb), + ExprHandle(qb.node()->qscale()), + ExprHandle(qb.node()->qzero()), (int64_t)immQDType(qb), out_qscale, out_qzero}); @@ -572,18 +573,19 @@ Tensor computeQuantizedMulScalar( const auto scalar = c10::get(inputs[1]); // Change to dtype based on outputType when dtype propagation implemented const auto out_qdtype = immQDType(qa); - double scale1 = immQScale(qa); + auto scale1 = ExprHandle(qa.node()->qscale()); + auto out_scale = scale1 * DoubleImm::make(scalar); auto ResultBuf = makeQBufHandleContiguous( "quantized_mul_scalar", outputShape, Dtype(out_qdtype), - scale1 * scalar, - immQZero(qa)); + out_scale.node(), + qa.node()->qzero()); StmtPtr s = ExternalCall::make( ResultBuf, "nnc_aten_quantized_mul_scalar", {qa}, - {scale1, immQZero(qa), (int64_t)immQDType(qa), scalar}); + {scale1, ExprHandle(qa.node()->qzero()), (int64_t)immQDType(qa), scalar}); return Tensor(ResultBuf.node(), s); } @@ -602,19 +604,19 @@ Tensor computeQuantizedRelu( "quantized_relu", outputShape, Dtype(out_qdtype), - immQScale(qa), - immQZero(qa)) + qa.node()->qscale(), + qa.node()->qzero()) : makeQBufHandleContiguous( "quantized_relu", outputShape, Dtype(out_qdtype), - immQScale(qa), - immQZero(qa)); + qa.node()->qscale(), + qa.node()->qzero()); StmtPtr s = ExternalCall::make( ResultBuf, "nnc_aten_quantized_relu", {qa}, - {immQScale(qa), immQZero(qa), (int64_t)immQDType(qa)}); + {ExprHandle(qa.node()->qscale()), ExprHandle(qa.node()->qzero()), (int64_t)immQDType(qa)}); return Tensor(ResultBuf.node(), s); } @@ -639,8 +641,8 @@ Tensor computeQuantizedCat( for (const auto i 
: c10::irange(n)) { const BufHandle& bh = inputList[i]; args.emplace_back(bh); - extra_args.emplace_back(immQScale(bh)); - extra_args.emplace_back(immQZero(bh)); + extra_args.emplace_back(ExprHandle(bh.node()->qscale())); + extra_args.emplace_back(ExprHandle(bh.node()->qzero())); extra_args.emplace_back((int64_t)immQDType(bh)); } extra_args.emplace_back(argDim); @@ -684,9 +686,16 @@ Tensor computeDequantize( indices.push_back(VarHandle(var)); } auto y = dequant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero); - BufPtr buf = alloc( - "dequantize", ExprHandleVectorToExprVector(outputShape), dtype); - return Tensor(buf, vars, y.node()); + + auto strides = isChannelsLast(qx) ? make_channels_last_strides(outputShape) + : make_contiguous_strides(outputShape); + BufHandle buf = Buf::make( + "dequantize", + outputShape, + dtype, + c10::nullopt, // initializer + fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); })); + return Tensor(buf.node(), vars, y.node()); } Tensor computeUpsampleNearest2d( @@ -759,12 +768,13 @@ Tensor computeUpsampleNearest2dExternalCall( scale_factor_w = (*scale_factors)[1]; } const BufHandle& x = c10::get(inputs[0]); - double qx_qscale = -1.f; - int64_t qx_qzero = -1l; + // dummy initialization needed as for non-quant scenarios + ExprHandle qx_qscale = DoubleImm::make(0.0f); + ExprHandle qx_qzero = LongImm::make(1l); int64_t qx_qdtype = -1l; if (isQuantized(x)) { - qx_qscale = immQScale(x); - qx_qzero = immQZero(x); + qx_qscale = ExprHandle(x.node()->qscale()); + qx_qzero = ExprHandle(x.node()->qzero()); qx_qdtype = (int64_t)immQDType(x); } @@ -774,8 +784,8 @@ Tensor computeUpsampleNearest2dExternalCall( "upsample_nearest2d", outputShape, Dtype(immQDType(x)), - qx_qscale, - qx_qzero); + qx_qscale.node(), + qx_qzero.node()); } return BufHandle("upsample_nearest2d", outputShape, dtype); }(); @@ -823,8 +833,8 @@ Tensor computeQuantizedSigmoidExternalCall( ResultBuf, "nnc_aten_quantized_sigmoid", {qx}, - {immQScale(qx), - 
immQZero(qx), + {ExprHandle(qx.node()->qscale()), + ExprHandle(qx.node()->qzero()), (int64_t)immQDType(qx), out_qscale, out_qzero});