Skip to content

Commit e3da4ef

Browse files
[SYCL] Add sycl::detail::bit_cast implementation (#1762)
Implementing SYCL_INTEL_bitcast extension which reinterprets the bits of an object as an object of a different data type calling the sycl::detail::bit_cast function. This function has the same specification as std::bit_cast, defined by C++20. sycl::detail::bit_cast calls std::bit_cast if __cpp_lib_bit_cast macro is enabled, otherwise, it calls __builtin_bit_cast builtin for oneAPI Data Parallel C++ compiler and sycl::detail::memcpy function for other compilers. Signed-off-by: Dmitry Vodopyanov <[email protected]>
1 parent 2c5822e commit e3da4ef

File tree

4 files changed

+115
-22
lines changed

4 files changed

+115
-22
lines changed

sycl/doc/extensions/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ DPC++ extensions status:
77

88
| Extension | Status | Comment |
99
|-------------|:------------|:------------|
10-
| [SYCL_INTEL_bitcast](Bitcast/SYCL_INTEL_bitcast.asciidoc) | Proposal | |
10+
| [SYCL_INTEL_bitcast](Bitcast/SYCL_INTEL_bitcast.asciidoc) | Supported | As sycl::detail::bit_cast |
1111
| [C and C++ Standard libraries support](C-CXX-StandardLibrary/C-CXX-StandardLibrary.rst) | Partially supported(OpenCL: CPU, GPU) | |
1212
| [SYCL_INTEL_data_flow_pipes](DataFlowPipes/data_flow_pipes.asciidoc) | Partially supported(OpenCL: ACCELERATOR) | kernel_host_pipe_support part is not implemented |
1313
| [SYCL_INTEL_deduction_guides](deduction_guides/SYCL_INTEL_deduction_guides.asciidoc) | Supported | |

sycl/include/CL/sycl/detail/helpers.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <CL/sycl/detail/type_traits.hpp>
1818

1919
#include <memory>
20+
#include <numeric> // std::bit_cast
2021
#include <stdexcept>
2122
#include <type_traits>
2223
#include <vector>
@@ -42,6 +43,31 @@ inline void memcpy(void *Dst, const void *Src, size_t Size) {
4243
}
4344
}
4445

46+
template <typename To, typename From>
47+
constexpr To bit_cast(const From &from) noexcept {
48+
static_assert(sizeof(To) == sizeof(From),
49+
"Sizes of To and From must be equal");
50+
static_assert(std::is_trivially_copyable<From>::value,
51+
"From must be trivially copyable");
52+
static_assert(std::is_trivially_copyable<To>::value,
53+
"To must be trivially copyable");
54+
#if __cpp_lib_bit_cast
55+
return std::bit_cast<To>(from);
56+
#else // __cpp_lib_bit_cast
57+
58+
#if __has_builtin(__builtin_bit_cast)
59+
return __builtin_bit_cast(To, from);
60+
#else // __has_builtin(__builtin_bit_cast)
61+
static_assert(std::is_trivially_default_constructible<To>::value,
62+
"To must be trivially default constructible");
63+
To to;
64+
sycl::detail::memcpy(&to, &from, sizeof(To));
65+
return to;
66+
#endif // __has_builtin(__builtin_bit_cast)
67+
68+
#endif // __cpp_lib_bit_cast
69+
}
70+
4571
class context_impl;
4672
// The function returns list of events that can be passed to OpenCL API as
4773
// dependency list and waits for others.

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <CL/sycl/range.hpp>
2222
#include <CL/sycl/types.hpp>
2323

24-
#include <numeric> // std::bit_cast
2524
#include <type_traits>
2625

2726
#ifdef __SYCL_DEVICE_ONLY__
@@ -72,22 +71,6 @@ using AcceptableForLocalLoadStore =
7271
bool_constant<!std::is_same<void, SelectBlockT<T>>::value &&
7372
Space == access::address_space::local_space>;
7473

75-
// TODO: move this to public cl::sycl::bit_cast as extension?
76-
template <typename To, typename From> To bit_cast(const From &from) {
77-
#if __cpp_lib_bit_cast
78-
return std::bit_cast<To>(from);
79-
#else
80-
81-
#if __has_builtin(__builtin_bit_cast)
82-
return __builtin_bit_cast(To, from);
83-
#else
84-
To to;
85-
sycl::detail::memcpy(&to, &from, sizeof(To));
86-
return to;
87-
#endif // __has_builtin(__builtin_bit_cast)
88-
#endif // __cpp_lib_bit_cast
89-
}
90-
9174
template <typename T, access::address_space Space>
9275
T load(const multi_ptr<T, Space> src) {
9376
using BlockT = SelectBlockT<T>;
@@ -97,7 +80,7 @@ T load(const multi_ptr<T, Space> src) {
9780
BlockT Ret =
9881
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));
9982

100-
return bit_cast<T>(Ret);
83+
return sycl::detail::bit_cast<T>(Ret);
10184
}
10285

10386
template <int N, typename T, access::address_space Space>
@@ -110,7 +93,7 @@ vec<T, N> load(const multi_ptr<T, Space> src) {
11093
VecT Ret =
11194
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));
11295

113-
return bit_cast<typename vec<T, N>::vector_t>(Ret);
96+
return sycl::detail::bit_cast<typename vec<T, N>::vector_t>(Ret);
11497
}
11598

11699
template <typename T, access::address_space Space>
@@ -119,7 +102,7 @@ void store(multi_ptr<T, Space> dst, const T &x) {
119102
using PtrT = sycl::detail::ConvertToOpenCLType_t<multi_ptr<BlockT, Space>>;
120103

121104
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
122-
bit_cast<BlockT>(x));
105+
sycl::detail::bit_cast<BlockT>(x));
123106
}
124107

125108
template <int N, typename T, access::address_space Space>
@@ -130,7 +113,7 @@ void store(multi_ptr<T, Space> dst, const vec<T, N> &x) {
130113
sycl::detail::ConvertToOpenCLType_t<const multi_ptr<BlockT, Space>>;
131114

132115
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
133-
bit_cast<VecT>(x));
116+
sycl::detail::bit_cast<VecT>(x));
134117
}
135118

136119
} // namespace sub_group

sycl/test/bit_cast/bit_cast.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
3+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
4+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
5+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
6+
7+
#include <CL/sycl.hpp>
8+
9+
#include <iostream>
10+
#include <math.h>
11+
#include <type_traits>
12+
13+
constexpr cl::sycl::access::mode sycl_write = cl::sycl::access::mode::write;
14+
15+
template <typename To, typename From>
16+
class BitCastKernel;
17+
18+
template <typename To, typename From>
19+
To doBitCast(const From &ValueToConvert) {
20+
std::vector<To> Vec(1);
21+
{
22+
sycl::buffer<To, 1> Buf(Vec.data(), 1);
23+
sycl::queue Queue;
24+
Queue.submit([&](sycl::handler &cgh) {
25+
auto acc = Buf.template get_access<sycl_write>(cgh);
26+
cgh.single_task<class BitCastKernel<To, From>>([=]() {
27+
// TODO: change to sycl::bit_cast in the future
28+
acc[0] = sycl::detail::bit_cast<To>(ValueToConvert);
29+
});
30+
});
31+
}
32+
return Vec[0];
33+
}
34+
35+
template <typename To, typename From>
36+
int test(const From &Value) {
37+
auto ValueConvertedTwoTimes = doBitCast<From>(doBitCast<To>(Value));
38+
bool isOriginalValueEqualsToConvertedTwoTimes = false;
39+
if (std::is_integral<From>::value) {
40+
isOriginalValueEqualsToConvertedTwoTimes = Value == ValueConvertedTwoTimes;
41+
} else if ((std::is_floating_point<From>::value) || std::is_same<From, cl::sycl::half>::value) {
42+
static const float Epsilon = 0.0000001f;
43+
isOriginalValueEqualsToConvertedTwoTimes = fabs(Value - ValueConvertedTwoTimes) < Epsilon;
44+
} else {
45+
std::cerr << "Type " << typeid(From).name() << " neither integral nor floating point nor cl::sycl::half\n";
46+
return 1;
47+
}
48+
if (!isOriginalValueEqualsToConvertedTwoTimes) {
49+
std::cerr << "FAIL: Original value which is " << Value << " != value converted two times which is " << ValueConvertedTwoTimes << "\n";
50+
return 1;
51+
}
52+
std::cout << "PASS\n";
53+
return 0;
54+
}
55+
56+
int main() {
57+
int ReturnCode = 0;
58+
59+
std::cout << "cl::sycl::half to unsigned short ...\n";
60+
ReturnCode += test<unsigned short>(cl::sycl::half(1.0f));
61+
62+
std::cout << "unsigned short to cl::sycl::half ...\n";
63+
ReturnCode += test<cl::sycl::half>(static_cast<unsigned short>(16384));
64+
65+
std::cout << "cl::sycl::half to short ...\n";
66+
ReturnCode += test<short>(cl::sycl::half(1.0f));
67+
68+
std::cout << "short to cl::sycl::half ...\n";
69+
ReturnCode += test<cl::sycl::half>(static_cast<short>(16384));
70+
71+
std::cout << "int to float ...\n";
72+
ReturnCode += test<float>(static_cast<int>(2));
73+
74+
std::cout << "float to int ...\n";
75+
ReturnCode += test<int>(static_cast<float>(-2.4f));
76+
77+
std::cout << "unsigned int to float ...\n";
78+
ReturnCode += test<float>(static_cast<unsigned int>(6));
79+
80+
std::cout << "float to unsigned int ...\n";
81+
ReturnCode += test<unsigned int>(static_cast<float>(-2.4f));
82+
83+
return ReturnCode;
84+
}

0 commit comments

Comments
 (0)