Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions test/cpp/tensorexpr/test_external_calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,7 @@ TEST(ExternalCall, BinaryFloat) {
TEST(ExternalCall, UnaryFloat) {
using TensorFunc = std::function<at::Tensor(at::Tensor)>;
auto toExprHandleVec = [](std::vector<int64_t> v) {
auto intV = std::vector<int>(v.begin(), v.end());
return std::vector<ExprHandle>(intV.begin(), intV.end());
return std::vector<ExprHandle>(v.begin(), v.end());
};
using Test = std::tuple<
std::vector<int64_t>,
Expand All @@ -722,7 +721,7 @@ TEST(ExternalCall, UnaryFloat) {
return at::adaptive_avg_pool2d(x, {5, 7});
},
"nnc_aten_adaptive_avg_pool2d",
toExprHandleVec({5, 7})});
toExprHandleVec({0, -1, -1, 5, 7})});
tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
{100, 200},
{100},
Expand Down
39 changes: 39 additions & 0 deletions test/cpp/tensorexpr/test_quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,45 @@ TEST_F(Quantization, UpsampleNearst2d) {
CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantAdaptiveAvgPool2dDequantUInt8) {
  // Lower a quantize -> adaptive_avg_pool2d -> dequantize graph (channels-last
  // uint8 input) through TensorExprKernel and compare the result against the
  // eager-mode reference computed with the same ops.
  const auto graph_string = R"IR(
graph(%x : Float(4, 2, 9, 9, strides=[162, 1, 18, 2], device=cpu)):
%2 : int = prim::Constant[value=13]()
%3 : int[] = prim::Constant[value=[3, 4]]()
%qz : int = prim::Constant[value=53]()
%qs : float = prim::Constant[value=0.0543662]()
%q : QUInt8(4, 2, 9, 9, strides=[162, 1, 18, 2], device=cpu) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
%qu : QUInt8(4, 2, 3, 4, strides=[24, 1, 8, 2], device=cpu) = aten::adaptive_avg_pool2d(%q, %3)
%6 : Float(4, 2, 3, 4, strides=[24, 1, 8, 2], device=cpu) = aten::dequantize(%qu)
return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  // Eager-mode reference: same constants as in the IR above.
  auto x = at::rand({4, 2, 9, 9}, TensorOptions(kCPU).dtype(at::kFloat))
               .contiguous(c10::MemoryFormat::ChannelsLast);
  auto q = at::quantize_per_tensor(x, 0.0543662f, 53, at::kQUInt8);
  auto qu = at::adaptive_avg_pool2d(q, {3, 4});
  auto y_expected = at::dequantize(qu);

  // Compile the graph with NNC and run it on the same input.
  TensorExprKernel kernel(graph);
  std::vector<at::Tensor> kernel_inputs = {x};
  // Kept (though unused) to exercise codegen statement construction, as in
  // the sibling tests in this file.
  StmtPtr codegen_stmt = kernel.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(kernel_inputs);
  kernel.run(stack);
  auto y = stack[0].toTensor();

  bool ok = at::allclose(y_expected, y);
  if (!ok) {
    // Dump all intermediates to ease debugging on mismatch.
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(ok, 1);
}

at::Tensor quantized_cat(
c10::List<at::Tensor> const& xs,
int64_t dim,
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/tensorexpr/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ std::unordered_map<std::string, std::string> ExtCallMemoryReuse::
{"nnc_aten_quantized_mul_scalar", "nnc_aten_quantized_mul_scalar_out"},
{"nnc_aten_max_red", "nnc_aten_max_red_out"},
{"nnc_aten_conv1d", "nnc_aten_conv1d_out"},
{"nnc_aten_adaptive_avg_pool2d", "nnc_aten_adaptive_avg_pool2d_out"},
};
}

Expand Down
63 changes: 54 additions & 9 deletions torch/csrc/jit/tensorexpr/external_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,25 +1268,67 @@ void nnc_aten_adaptive_avg_pool2d(
int64_t* buf_dims,
int64_t* buf_strides,
int8_t* buf_dtypes,
int64_t args_num,
int64_t,
int64_t* extra_args) {
const int64_t x_qdtype = extra_args[2];
c10::optional<std::vector<std::pair<size_t, QIData>>> qdata;
if (x_qdtype != -1) {
qdata = {{1u,
{((double*)extra_args)[0],
extra_args[1],
at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
}
auto tensors = constructTensors(
bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata);

at::Tensor& r = tensors[0];
const at::Tensor& x = tensors[1];
int64_t H = extra_args[0];
int64_t W = H;
if (args_num > 1) {
W = extra_args[1];
}
try {
r = at::adaptive_avg_pool2d(x, {H, W});
r = at::adaptive_avg_pool2d(tensors[1], {extra_args[3], extra_args[4]});
} catch (...) {
}
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
}

// External-call shim for aten::adaptive_avg_pool2d that allocates its own
// output ("_out" memory-reuse flavor): instead of memcpy-ing into a
// preallocated buffer, it publishes the result tensor's storage back through
// buf_data.
//
// extra_args layout (as read below):
//   extra_args[0]  - input qscale, double bits stored in an int64 slot
//   extra_args[1]  - input qzero point
//   extra_args[2]  - input quantized ScalarType as int64, or -1 when the
//                    input is not quantized
//   extra_args[3]  - output height
//   extra_args[4]  - output width
void nnc_aten_adaptive_avg_pool2d_out(
int64_t bufs_in_num,
void** buf_data,
int64_t* buf_ranks,
int64_t* buf_dims,
int64_t* buf_strides,
int8_t* buf_dtypes,
int64_t,
int64_t* extra_args) {
// -1 is the sentinel for "input is not quantized"; only then is qdata left
// disengaged.
const int64_t x_qdtype = extra_args[2];
c10::optional<std::vector<std::pair<size_t, QIData>>> qdata;
if (x_qdtype != -1) {
// Index 1u targets the input tensor (slot 0 is the output).
qdata = {{1u,
{((double*)extra_args)[0],
extra_args[1],
at::toQIntType(static_cast<c10::ScalarType>(x_qdtype))}}};
}

const size_t bufs_out_num = 1u;
auto tensors = constructTensors2(
bufs_in_num,
buf_data,
buf_ranks,
buf_dims,
buf_strides,
buf_dtypes,
qdata,
bufs_out_num);

at::Tensor r;
try {
// tensors[1] is the (possibly quantized) input; {H, W} is the target size.
r = at::adaptive_avg_pool2d(tensors[1], /*{H, W}=*/{extra_args[3], extra_args[4]});
} catch (...) {
// NOTE(review): swallowing the exception leaves `r` undefined, and the
// r.data_ptr() / getIntrusivePtr() calls below will then fail anyway —
// consider propagating or handling the error explicitly.
}

// Hand the result's storage to the caller. The manual incref keeps the
// TensorImpl alive past this function's scope; the raw pointer stashed at
// buf_data[bufs_in_num + bufs_out_num] presumably lets the caller decref
// later — TODO confirm against the NNC external-call ABI.
buf_data[0] = r.data_ptr();
c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get());
buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get();
}

void nnc_aten_mean(
int64_t bufs_num,
void** buf_data,
Expand Down Expand Up @@ -1568,6 +1610,9 @@ const static RegisterNNCExternalFunction nnc_conv1d_out(
// Static registrations: each RegisterNNCExternalFunction instance binds an
// external-call name (as referenced from NNC IR ExternalCall nodes) to its C
// shim at program load time. The "_out" variant is the memory-reuse flavor
// that allocates and publishes its own output buffer.
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
"nnc_aten_adaptive_avg_pool2d",
nnc_aten_adaptive_avg_pool2d);
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d_out(
"nnc_aten_adaptive_avg_pool2d_out",
nnc_aten_adaptive_avg_pool2d_out);
const static RegisterNNCExternalFunction nnc_mean(
"nnc_aten_mean",
nnc_aten_mean);
Expand Down
46 changes: 38 additions & 8 deletions torch/csrc/jit/tensorexpr/operators/reduction.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/tensorexpr/operators/reduction.h>
#include <torch/csrc/jit/tensorexpr/operators/quantization.h>

using namespace torch::jit::tensorexpr;

Expand Down Expand Up @@ -171,16 +172,45 @@ Tensor computeAdaptiveAvgPool2d(
if (outputType) {
dtype = Dtype(*outputType);
}
BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype);

auto x = c10::get<BufHandle>(inputs[0]);
// dummy initialization needed as for non-quant scenarios
ExprHandle qx_qscale = DoubleImm::make(0.0f);
ExprHandle qx_qzero = LongImm::make(1l);
int64_t qx_qdtype = -1l;
if (isQuantized(x)) {
qx_qscale = ExprHandle(x.node()->qscale());
qx_qzero = ExprHandle(x.node()->qzero());
qx_qdtype = (int64_t)immQDType(x);
}

auto strides = x.is_contiguous(c10::MemoryFormat::ChannelsLast)
? make_channels_last_strides(outputShape)
: make_contiguous_strides(outputShape);

BufHandle ResultBuf = Buf::make(
"adaptive_avgpool2d",
outputShape,
Dtype(*outputType),
c10::nullopt, // initializer
ExprVectorToExprHandleVector(strides),
qx_qscale,
qx_qzero);

// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto out_size_param = c10::get<IntList>(inputs[1]);
return Tensor(
ResultBuf.node(),
ExternalCall::make(
ResultBuf,
"nnc_aten_adaptive_avg_pool2d",
{c10::get<BufHandle>(inputs[0])},
c10::fmap<ExprHandle>(out_size_param)));
// Expand the dims as needed, to facilitate external call params processing
if (out_size_param.size() == 1) {out_size_param.push_back(out_size_param[0]);}
StmtPtr s = ExternalCall::make(
ResultBuf,
"nnc_aten_adaptive_avg_pool2d",
{x},
{qx_qscale,
qx_qzero,
qx_qdtype,
out_size_param[0],
out_size_param[1]});
return Tensor(ResultBuf.node(), s);
}

} // namespace tensorexpr
Expand Down