diff --git a/sycl/include/CL/sycl/detail/spirv.hpp b/sycl/include/CL/sycl/detail/spirv.hpp
index d662e2afc7880..3ba8ff09b8828 100644
--- a/sycl/include/CL/sycl/detail/spirv.hpp
+++ b/sycl/include/CL/sycl/detail/spirv.hpp
@@ -33,6 +33,32 @@ template <> struct group_scope<::cl::sycl::intel::sub_group> {
   static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
 };
 
+// Generic shuffles and broadcasts may require multiple calls to SPIR-V
+// intrinsics, and should use the fewest broadcasts possible
+// - Loop over 64-bit chunks until remaining bytes < 64-bit
+// - At most one 32-bit, 16-bit and 8-bit chunk left over
+template <typename T, typename Functor>
+void GenericCall(const Functor &ApplyToBytes) {
+  if (sizeof(T) >= sizeof(uint64_t)) {
+#pragma unroll
+    for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
+      ApplyToBytes(Offset, sizeof(uint64_t));
+    }
+  }
+  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
+    ApplyToBytes(Offset, sizeof(uint32_t));
+  }
+  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
+    ApplyToBytes(Offset, sizeof(uint16_t));
+  }
+  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
+    ApplyToBytes(Offset, sizeof(uint8_t));
+  }
+}
+
 template <typename Group> bool GroupAll(bool pred) {
   return __spirv_GroupAll(group_scope<Group>::value, pred);
 }
@@ -41,47 +67,137 @@ template <typename Group> bool GroupAny(bool pred) {
   return __spirv_GroupAny(group_scope<Group>::value, pred);
 }
 
+// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
+template <typename T>
+using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfNativeBroadcast = detail::enable_if_t<
+    is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
+// Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
+// intrinsic, but require type-punning via an appropriate integer type
+template <typename T>
+using is_bitcast_broadcast = bool_constant<
+    !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
+    (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfBitcastBroadcast = detail::enable_if_t<
+    is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
+template <typename T>
+using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
+
+// Generic broadcasts may require multiple calls to SPIR-V GroupBroadcast
+// intrinsics, and should use the fewest broadcasts possible
+// - Loop over 64-bit chunks until remaining bytes < 64-bit
+// - At most one 32-bit, 16-bit and 8-bit chunk left over
+template <typename T>
+using is_generic_broadcast =
+    bool_constant<!is_native_broadcast<T>::value &&
+                  !is_bitcast_broadcast<T>::value &&
+                  std::is_trivially_copyable<T>::value>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfGenericBroadcast = detail::enable_if_t<
+    is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
 // Broadcast with scalar local index
 // Work-group supports any integral type
 // Sub-group currently supports only uint32_t
+template <typename Group> struct GroupId { using type = size_t; };
+template <> struct GroupId<::cl::sycl::intel::sub_group> {
+  using type = uint32_t;
+};
 template <typename Group, typename T, typename IdT>
-detail::enable_if_t<is_group<Group>::value && std::is_integral<IdT>::value, T>
-GroupBroadcast(T x, IdT local_id) {
+EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  using GroupIdT = typename GroupId<Group>::type;
+  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
   using OCLT = detail::ConvertToOpenCLType_t<T>;
-  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
+  OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
+  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
+  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
 }
 template <typename Group, typename T, typename IdT>
-detail::enable_if_t<is_sub_group<Group>::value && std::is_integral<IdT>::value,
-                    T>
-GroupBroadcast(T x, IdT local_id) {
-  using SGIdT = uint32_t;
-  SGIdT sg_local_id = static_cast<SGIdT>(local_id);
-  using OCLT = detail::ConvertToOpenCLType_t<T>;
-  using OCLIdT = detail::ConvertToOpenCLType_t<SGIdT>;
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<SGIdT, OCLIdT>(sg_local_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  using GroupIdT = typename GroupId<Group>::type;
+  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
+  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
+  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
+  auto BroadcastX = detail::bit_cast<BroadcastT>(x);
+  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
+  BroadcastT Result =
+      __spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
+  return detail::bit_cast<T>(Result);
+}
+template <typename Group, typename T, typename IdT>
+EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  T Result;
+  char *XBytes = reinterpret_cast<char *>(&x);
+  char *ResultBytes = reinterpret_cast<char *>(&Result);
+  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
+    uint64_t BroadcastX, BroadcastResult;
+    detail::memcpy(&BroadcastX, XBytes + Offset, Size);
+    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
+    detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
+  };
+  GenericCall<T>(BroadcastBytes);
+  return Result;
 }
 
 // Broadcast with vector local index
 template <typename Group, typename T, int Dimensions>
-T GroupBroadcast(T x, id<Dimensions> local_id) {
+EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
   if (Dimensions == 1) {
     return GroupBroadcast<Group>(x, local_id[0]);
   }
   using IdT = vec<size_t, Dimensions>;
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
-  IdT vec_id;
+  IdT VecId;
+  for (int i = 0; i < Dimensions; ++i) {
+    VecId[i] = local_id[Dimensions - i - 1];
+  }
+  OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
+  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
+  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
+}
+template <typename Group, typename T, int Dimensions>
+EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
+  if (Dimensions == 1) {
+    return GroupBroadcast<Group>(x, local_id[0]);
+  }
+  using IdT = vec<size_t, Dimensions>;
+  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
+  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
+  IdT VecId;
   for (int i = 0; i < Dimensions; ++i) {
-    vec_id[i] = local_id[Dimensions - i - 1];
+    VecId[i] = local_id[Dimensions - i - 1];
   }
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(vec_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+  auto BroadcastX = detail::bit_cast<BroadcastT>(x);
+  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
+  BroadcastT Result =
+      __spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
+  return detail::bit_cast<T>(Result);
+}
+template <typename Group, typename T, int Dimensions>
+EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
+  if (Dimensions == 1) {
+    return GroupBroadcast<Group>(x, local_id[0]);
+  }
+  T Result;
+  char *XBytes = reinterpret_cast<char *>(&x);
+  char *ResultBytes = reinterpret_cast<char *>(&Result);
+  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
+    uint64_t BroadcastX, BroadcastResult;
+    detail::memcpy(&BroadcastX, XBytes + Offset, Size);
+    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
+    detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
+  };
+  GenericCall<T>(BroadcastBytes);
+  return Result;
 }
 
 // Single happens-before means semantics should always apply to all spaces
@@ -400,28 +516,6 @@ using EnableIfGenericShuffle =
                            sizeof(T) == 4 || sizeof(T) == 8)),
                         T>;
 
-template <typename T, typename ShuffleFunctor>
-void GenericShuffle(const ShuffleFunctor &ShuffleBytes) {
-  if (sizeof(T) >= sizeof(uint64_t)) {
-#pragma unroll
-    for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
-      ShuffleBytes(Offset, sizeof(uint64_t));
-    }
-  }
-  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
-    ShuffleBytes(Offset, sizeof(uint32_t));
-  }
-  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
-    ShuffleBytes(Offset, sizeof(uint16_t));
-  }
-  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
-    ShuffleBytes(Offset, sizeof(uint8_t));
-  }
-}
-
 template <typename T>
 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
   T Result;
@@ -433,7 +527,7 @@ EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
     ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -448,7 +542,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
     ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -465,7 +559,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
     ShuffleResult = SubgroupShuffleDown(ShuffleX, ShuffleY, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -482,7 +576,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
     ShuffleResult = SubgroupShuffleUp(ShuffleX, ShuffleY, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
diff --git a/sycl/include/CL/sycl/intel/group_algorithm.hpp b/sycl/include/CL/sycl/intel/group_algorithm.hpp
index 932a53ba07675..c8f6faa2a08a6 100644
--- a/sycl/include/CL/sycl/intel/group_algorithm.hpp
+++ b/sycl/include/CL/sycl/intel/group_algorithm.hpp
@@ -138,6 +138,12 @@ template <typename Ptr, typename T>
 using EnableIfIsPointer =
     cl::sycl::detail::enable_if_t<cl::sycl::detail::is_pointer<Ptr>::value, T>;
 
+template <typename T>
+using EnableIfIsTriviallyCopyable = cl::sycl::detail::enable_if_t<
+    std::is_trivially_copyable<T>::value &&
+        !cl::sycl::detail::is_vector_arithmetic<T>::value,
+    T>;
+
 // EnableIf shorthands for algorithms that depend on type and an operator
 template <typename T, typename BinaryOperation>
 using EnableIfIsScalarArithmeticNativeOp = cl::sycl::detail::enable_if_t<
@@ -286,8 +292,8 @@ EnableIfIsPointer<Ptr, bool> none_of(Group g, Ptr first, Ptr last,
 }
 
 template <typename Group, typename T>
-EnableIfIsScalarArithmetic<T> broadcast(Group, T x,
-                                        typename Group::id_type local_id) {
+EnableIfIsTriviallyCopyable<T> broadcast(Group, T x,
+                                         typename Group::id_type local_id) {
   static_assert(sycl::detail::is_generic_group<Group>::value,
                 "Group algorithms only support the sycl::group and "
                 "intel::sub_group class.");
@@ -323,7 +329,7 @@ EnableIfIsVectorArithmetic<T> broadcast(Group g, T x,
 }
 template <typename Group, typename T>
-EnableIfIsScalarArithmetic<T>
+EnableIfIsTriviallyCopyable<T>
 broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
   static_assert(sycl::detail::is_generic_group<Group>::value,
                 "Group algorithms only support the sycl::group and "
                 "intel::sub_group class.");
@@ -363,7 +369,7 @@ broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
 }
 
 template <typename Group, typename T>
-EnableIfIsScalarArithmetic<T> broadcast(Group g, T x) {
+EnableIfIsTriviallyCopyable<T> broadcast(Group g, T x) {
   static_assert(sycl::detail::is_generic_group<Group>::value,
                 "Group algorithms only support the sycl::group and "
                 "intel::sub_group class.");
diff --git a/sycl/test/group-algorithm/broadcast.cpp b/sycl/test/group-algorithm/broadcast.cpp
index df0887a40d4a0..04028fade9669 100644
--- a/sycl/test/group-algorithm/broadcast.cpp
+++ b/sycl/test/group-algorithm/broadcast.cpp
@@ -10,17 +10,19 @@
 #include <CL/sycl.hpp>
 #include <algorithm>
 #include <cassert>
+#include <complex>
 #include <numeric>
 using namespace sycl;
 using namespace sycl::intel;
 
+template <typename... Ts> class broadcast_kernel;
 template <typename InputContainer, typename OutputContainer>
 void test(queue q, InputContainer input, OutputContainer output) {
   typedef typename InputContainer::value_type InputT;
   typedef typename OutputContainer::value_type OutputT;
-  typedef class broadcast_kernel kernel_name;
+  typedef class broadcast_kernel<InputT, OutputT> kernel_name;
   size_t N = input.size();
   size_t G = 4;
   range<2> R(G, G);
@@ -54,12 +56,49 @@ int main() {
   }
 
   constexpr int N = 16;
-  std::array<int, N> input;
-  std::array<int, 3> output;
-  std::iota(input.begin(), input.end(), 1);
-  std::fill(output.begin(), output.end(), false);
-  test(q, input, output);
+  // Test built-in scalar type
+  {
+    std::array<int, N> input;
+    std::array<int, 3> output;
+    std::iota(input.begin(), input.end(), 1);
+    std::fill(output.begin(), output.end(), false);
+    test(q, input, output);
+  }
+
+  // Test pointer type
+  {
+    std::array<int *, N> input;
+    std::array<int *, 3> output;
+    for (int i = 0; i < N; ++i) {
+      input[i] = static_cast<int *>(0x0) + i;
+    }
+    std::fill(output.begin(), output.end(), static_cast<int *>(0x0));
+    test(q, input, output);
+  }
+  // Test user-defined type
+  // - Use complex as a proxy for this
+  // - Test float and double to test 64-bit and 128-bit types
+  {
+    std::array<std::complex<float>, N> input;
+    std::array<std::complex<float>, 3> output;
+    for (int i = 0; i < N; ++i) {
+      input[i] =
+          std::complex<float>(0, 1) + (float)i * std::complex<float>(2, 2);
+    }
+    std::fill(output.begin(), output.end(), std::complex<float>(0, 0));
+    test(q, input, output);
+  }
+  {
+    std::array<std::complex<double>, N> input;
+    std::array<std::complex<double>, 3> output;
+    for (int i = 0; i < N; ++i) {
+      input[i] =
+          std::complex<double>(0, 1) + (double)i * std::complex<double>(2, 2);
+    }
+    std::fill(output.begin(), output.end(), std::complex<double>(0, 0));
+    test(q, input, output);
+  }
 
   std::cout << "Test passed." << std::endl;
 }
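
Illustrative usage (not part of the patch): a minimal sketch of what the relaxed constraint enables, namely calling intel::broadcast on a trivially copyable user-defined type. The Point struct and the broadcast_point kernel name below are hypothetical; since sizeof(Point) == 8, this call would take the single-intrinsic bitcast path rather than the generic multi-chunk path.

#include <CL/sycl.hpp>
#include <cassert>
using namespace sycl;

// Trivially copyable, but neither an arithmetic nor a vector type, so it was
// previously rejected by the EnableIfIsScalarArithmetic constraint.
struct Point {
  float x;
  float y;
};

int main() {
  queue q;
  Point result{-1.f, -1.f};
  {
    buffer<Point, 1> out_buf(&result, range<1>(1));
    q.submit([&](handler &cgh) {
      auto out = out_buf.get_access<access::mode::write>(cgh);
      cgh.parallel_for<class broadcast_point>(
          nd_range<1>(range<1>(8), range<1>(8)), [=](nd_item<1> it) {
            Point p{static_cast<float>(it.get_local_id(0)), 2.f};
            // Every work-item receives the value held by the lowest id (0).
            Point b = intel::broadcast(it.get_group(), p);
            // Write back from a non-zero work-item to show it got item 0's value.
            if (it.get_local_id(0) == 7)
              out[0] = b;
          });
    });
  }
  assert(result.x == 0.f && result.y == 2.f);
  return 0;
}

Larger trivially copyable types, such as the std::complex<double> values used in the updated test, would instead go through the GenericCall path and be broadcast in multiple chunks.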
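For reference, the (offset, size) sequence produced by GenericCall can be reproduced on the host. The sketch below copies the patch's offset arithmetic into a local helper named ChunkDecomposition (a name specific to this example, not part of the library): std::complex<double> at 16 bytes splits into two 64-bit chunks, and a hypothetical 6-byte struct splits into one 32-bit chunk followed by one 16-bit chunk.

#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Host-side copy of the offset/size sequence used by GenericCall in the patch.
template <typename T, typename Functor>
void ChunkDecomposition(const Functor &ApplyToBytes) {
  if (sizeof(T) >= sizeof(uint64_t)) {
    for (std::size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t))
      ApplyToBytes(Offset, sizeof(uint64_t));
  }
  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t))
    ApplyToBytes(sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t),
                 sizeof(uint32_t));
  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t))
    ApplyToBytes(sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t),
                 sizeof(uint16_t));
  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t))
    ApplyToBytes(sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t),
                 sizeof(uint8_t));
}

int main() {
  auto Print = [](std::size_t Offset, std::size_t Size) {
    std::printf("  chunk at offset %zu, %zu bytes\n", Offset, Size);
  };
  struct Packed {
    uint16_t a, b, c; // 6 bytes total
  };
  std::printf("std::complex<double>, %zu bytes:\n",
              sizeof(std::complex<double>));
  ChunkDecomposition<std::complex<double>>(Print); // (0, 8) and (8, 8)
  std::printf("Packed, %zu bytes:\n", sizeof(Packed));
  ChunkDecomposition<Packed>(Print); // (0, 4) and (4, 2)
  return 0;
}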