@@ -220,6 +220,10 @@ struct ValueSelectionIntersectionKernel {

template <typename index_t, typename hash_coeffs_t>
struct SparseBinaryOpIntersectionAFunctor {
// For std::is_trivially_copyable
SparseBinaryOpIntersectionAFunctor& operator=(
const SparseBinaryOpIntersectionAFunctor&) = delete;

int64_t operator()(index_t nnz_idx) const {
int64_t hash = 0;
if (!ptr_indices_) {
@@ -256,6 +260,10 @@ struct SparseBinaryOpIntersectionAFunctor {

template <typename index_t, typename hash_coeffs_t, typename hash_t = int64_t>
struct SparseBinaryOpIntersectionBFunctor {
// For std::is_trivially_copyable
SparseBinaryOpIntersectionBFunctor& operator=(
const SparseBinaryOpIntersectionBFunctor&) = delete;

index_t operator()(index_t nnz_idx) const {
int64_t hash = 0;
if (hash_ptr_) {
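The deleted copy-assignment operators keep these functors within std::is_trivially_copyable: a deleted special member does not count against triviality, whereas a user-provided one would. A minimal sketch of the distinction (hypothetical types, not from this patch):

#include <type_traits>

struct DeletedAssign {
  // Deleted: not "user-provided", so the type stays trivially copyable.
  DeletedAssign& operator=(const DeletedAssign&) = delete;
  int x;
};

struct UserAssign {
  // User-provided: makes the type non-trivially copyable.
  UserAssign& operator=(const UserAssign&) { return *this; }
  int x;
};

static_assert(std::is_trivially_copyable_v<DeletedAssign>);
static_assert(!std::is_trivially_copyable_v<UserAssign>);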
3 changes: 3 additions & 0 deletions src/ATen/native/sparse/xpu/sycl/SparseTensorKernels.cpp
@@ -211,6 +211,9 @@ struct KernelLauncher {

template <typename index_t, typename hash_coeffs_t>
struct FlattenIndicesFunctor {
// For std::is_trivially_copyable
FlattenIndicesFunctor& operator=(const FlattenIndicesFunctor&) = delete;

int64_t operator()(int64_t nnz_idx) const {
const auto* ptr_indices_dim = ptr_indices_ + nnz_idx * indices_nnz_stride_;
auto hash = static_cast<int64_t>(0);
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/ActivationHardshrinkKernels.cpp
@@ -11,10 +11,10 @@ struct HardshrinkFunctor {
return (a >= -lambd_ && a <= lambd_) ? scalar_t(0) : a;
}

-  HardshrinkFunctor(const scalar_t lambd) : lambd_(lambd) {}
+  HardshrinkFunctor(scalar_t lambd) : lambd_(lambd) {}

 private:
-  const scalar_t lambd_;
+  scalar_t lambd_;
};

void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
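Dropping const from lambd_ restores the implicit copy-assignment operator, which a const data member forces to be deleted; the functor is trivially copyable either way, but presumably the kernel machinery also wants it assignable. A minimal illustration (hypothetical types, not from this patch):

#include <type_traits>

struct ConstMember {
  const float lambd;  // const member: implicit operator= is deleted
};

struct PlainMember {
  float lambd;  // implicit operator= is available
};

static_assert(!std::is_copy_assignable_v<ConstMember>);
static_assert(std::is_copy_assignable_v<PlainMember>);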
183 changes: 183 additions & 0 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -78,6 +78,189 @@ class IndexSelectScalarFunctor {
}
};

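// Splits a work-item's global batch id into its logical position within the
// index tensor (idx_logical_off), the batch group it belongs to, and the
// index value read from the index tensor, with negative indices wrapped into
// [0, indexing_dimension_size_).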
template <
    class IdxConfig,
    bool TrivialOffCal = false,
    bool KnownProblemInner = false>
void init_global_batch_info(
    BatchKernelConfig::ItemDesc& id,
    int64_t& idx_logical_off,
    int64_t& glb_batch_group,
    int64_t& glb_batch_group_loc_off,
    IdxConfig& cfg) {
using IdxType = typename IdxConfig::IdxType;

idx_logical_off = id.glb_batch % cfg.index_num_;
int64_t idx_off;
if constexpr (TrivialOffCal) {
idx_off = idx_logical_off;
} else {
idx_off = IndexToOffset<IdxType, int64_t>::get(
idx_logical_off,
cfg.iinfo_,
IndexToOffset<IdxType, int64_t>::NON_STRICT_CONTIGUOUS);
}
glb_batch_group = id.glb_batch / cfg.index_num_;
glb_batch_group_loc_off = cfg.iinfo_.data[idx_off];
glb_batch_group_loc_off = glb_batch_group_loc_off >= 0
? glb_batch_group_loc_off
: cfg.indexing_dimension_size_ + glb_batch_group_loc_off;
}

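// Maps the batch group and the fetched index value to a logical offset into
// the indexed operand, handling both problem-inner and problem-outer layouts.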
template <
    class IdxConfig,
    bool TrivialOffCal = false,
    bool KnownProblemInner = false>
inline int64_t indexing_logical_off(
    BatchKernelConfig::ItemDesc& id,
    int64_t glb_batch_group,
    int64_t glb_batch_group_loc_off,
    IdxConfig& cfg) {
int64_t si, pi, bi;
int64_t glb_batch_group_glb_off =
glb_batch_group * cfg.indexing_dimension_size_ +
glb_batch_group_loc_off;
auto stride = cfg.stride_;
if constexpr (KnownProblemInner) {
si = 0;
pi = id.glb_problem;
bi = glb_batch_group_glb_off;
return (pi + bi * cfg.problem_) * stride;
} else {
if (cfg.problem_inner_) {
si = 0;
pi = id.glb_problem;
bi = glb_batch_group_glb_off;
return (pi + bi * cfg.problem_) * stride;
} else {
si = glb_batch_group_glb_off;
pi = id.glb_problem;
bi = 0;
return si + pi * stride;
}
}
}

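// Maps the batch group and the index position to a logical offset into the
// fixed-size operand (the source of index_copy/index_add, or the destination
// of index_select), which is laid out along the index dimension.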
template <
    class IdxConfig,
    bool TrivialOffCal = false,
    bool KnownProblemInner = false>
inline int64_t fixing_logical_off(
    BatchKernelConfig::ItemDesc& id,
    int64_t glb_batch_group,
    int64_t idx_logical_off,
    IdxConfig& cfg) {
int64_t si, pi, bi, stride;
int64_t glb_batch_group_glb_off =
glb_batch_group * cfg.index_num_ + idx_logical_off;
if constexpr (KnownProblemInner) {
si = 0;
stride = 1;
pi = id.glb_problem;
bi = glb_batch_group_glb_off;
return pi + bi * cfg.problem_;
} else {
if (cfg.problem_inner_) {
si = 0;
stride = 1;
pi = id.glb_problem;
bi = glb_batch_group_glb_off;
return pi + bi * cfg.problem_;
} else {
bi = 0;
si = glb_batch_group_glb_off;
pi = id.glb_problem;
stride = cfg.index_num_;
return si + pi * stride;
}
}
}

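// Generic indexing kernel: each work-item resolves its (batch, problem)
// coordinates to a dst/src offset pair and applies cfg.func_ to them.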
template <
class IdxConfig,
bool TrivialOffCal = false,
bool KnownProblemInner = false>
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
void index_kernel(IdxConfig cfg) {
using ValType = typename IdxConfig::ValType;

auto item = syclext::this_work_item::get_nd_item<2>();
auto id = cfg.get_item_desc(item);

if (id.glb_problem >= cfg.problem_ ||
id.glb_batch >= cfg.problem_batch_) {
return;
}

// Indexing kernels have three operands:
//   1. the index operand,
//   2. the operand being indexed,
//   3. optionally, an operand whose size is fixed by the index.
int64_t idx_logical_off, glb_batch_group, glb_batch_group_loc_off;
int64_t glb_indexing_logical_off = 0, glb_fixing_logical_off = 0;
int64_t dst_off = 0, src_off = 0;

init_global_batch_info<IdxConfig, TrivialOffCal, KnownProblemInner>(
    id, idx_logical_off, glb_batch_group, glb_batch_group_loc_off, cfg);

glb_indexing_logical_off =
    indexing_logical_off<IdxConfig, TrivialOffCal, KnownProblemInner>(
        id, glb_batch_group, glb_batch_group_loc_off, cfg);

if (cfg.sinfo_.data != nullptr && cfg.dinfo_.data != nullptr) {
glb_fixing_logical_off =
    fixing_logical_off<IdxConfig, TrivialOffCal, KnownProblemInner>(
        id, glb_batch_group, idx_logical_off, cfg);
}

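// With TrivialOffCal the operands are known to be contiguous, so logical
// offsets are memory offsets; otherwise they go through IndexToOffset.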
if constexpr (TrivialOffCal) {
if (cfg.indexing_dst_) {
dst_off = glb_indexing_logical_off;
if (cfg.sinfo_.data != nullptr) {
src_off = glb_fixing_logical_off;
}
} else {
src_off = glb_indexing_logical_off;
dst_off = glb_fixing_logical_off;
}
} else {
if (cfg.indexing_dst_) {
// index_copy, index_add, index_fill
dst_off = IndexToOffset<ValType, int64_t>::get(
glb_indexing_logical_off,
cfg.dinfo_,
IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
if (cfg.sinfo_.data != nullptr) {
src_off = IndexToOffset<const ValType, int64_t>::get(
glb_fixing_logical_off,
cfg.sinfo_,
IndexToOffset<const ValType, int64_t>::NON_STRICT_CONTIGUOUS);
}
} else {
// index_select
src_off = IndexToOffset<const ValType, int64_t>::get(
glb_indexing_logical_off,
cfg.sinfo_,
IndexToOffset<const ValType, int64_t>::NON_STRICT_CONTIGUOUS);
dst_off = IndexToOffset<ValType, int64_t>::get(
glb_fixing_logical_off,
cfg.dinfo_,
IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
}
}
cfg.func_(
cfg.dinfo_.data,
cfg.sinfo_.data,
dst_off,
src_off,
glb_batch_group_loc_off,
cfg.alpha_);
}

template <
    class IdxConfig,
    bool TrivialOffCal = false,
    bool KnownProblemInner = false>
static inline void launch_index_kernel(IdxConfig cfg) {
auto& queue = getCurrentSYCLQueue();
sycl_kernel_submit<index_kernel<IdxConfig, TrivialOffCal, KnownProblemInner>>(
cfg.global_size(), cfg.group_size(), queue, 0, cfg);
}
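launch_index_kernel submits index_kernel through the project's sycl_kernel_submit helper as a SYCL free-function kernel (the sycl_ext_oneapi_free_function_kernels extension). A minimal standalone sketch of that kernel style, independent of this patch (plain_copy and the bundle boilerplate are illustrative, not part of the PR):

#include <sycl/sycl.hpp>
namespace syclexp = sycl::ext::oneapi::experimental;
namespace syclext = sycl::ext::oneapi;

SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
void plain_copy(const int* in, int* out) {
  auto item = syclext::this_work_item::get_nd_item<1>();
  auto i = item.get_global_id(0);
  out[i] = in[i];
}

// Submission goes through a kernel bundle rather than a lambda:
//   sycl::kernel_id kid = syclexp::get_kernel_id<plain_copy>();
//   auto bundle = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
//       q.get_context(), {kid});
//   sycl::kernel k = bundle.get_kernel(kid);
//   q.submit([&](sycl::handler& cgh) {
//     cgh.set_args(in, out);
//     cgh.parallel_for(sycl::nd_range<1>{n, wg}, k);
//   });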

template <
class SrcInfo,
class DstInfo,