diff --git a/sycl/include/CL/sycl/types.hpp b/sycl/include/CL/sycl/types.hpp index d635defe870aa..64f93b8e6f5ea 100644 --- a/sycl/include/CL/sycl/types.hpp +++ b/sycl/include/CL/sycl/types.hpp @@ -372,12 +372,34 @@ template class vec { return *this; } +#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ + explicit vec(const DataT &arg) { + m_Data = (DataType)arg; + } + + template + typename std::enable_if::value, vec &>::type + operator=(const DataT &Rhs) { + m_Data = (DataType)Rhs; + return *this; + } +#else explicit vec(const DataT &arg) { for (int i = 0; i < NumElements; ++i) { setValue(i, arg); } } + template + typename std::enable_if::value, vec &>::type + operator=(const DataT &Rhs) { + for (int i = 0; i < NumElements; ++i) { + setValue(i, Rhs); + } + return *this; + } +#endif + #ifdef __SYCL_USE_EXT_VECTOR_TYPE__ // Optimized naive constructors with NumElements of DataT values. // We don't expect compilers to optimize vararg recursive functions well. diff --git a/sycl/test/basic_tests/vectors.cpp b/sycl/test/basic_tests/vectors.cpp index 46bc40cfec830..1f349c9737cad 100644 --- a/sycl/test/basic_tests/vectors.cpp +++ b/sycl/test/basic_tests/vectors.cpp @@ -50,5 +50,13 @@ int main() { int64_t(vec_2.x()); cl::sycl::int4(vec_2.x()); + // Check broadcasting operator= + cl::sycl::vec b_vec(1.0); + b_vec = 0.5; + assert(static_cast(b_vec.x()) == static_cast(0.5)); + assert(static_cast(b_vec.y()) == static_cast(0.5)); + assert(static_cast(b_vec.z()) == static_cast(0.5)); + assert(static_cast(b_vec.w()) == static_cast(0.5)); + return 0; }