diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index a02f47fca3250..80edf2c9eb652 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -105,6 +105,13 @@ template * __spirv_CompositeConstruct(const T v); +template +extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t +__spirv_JointMatrixGetElementCoordINTEL( + __spv::__spirv_JointMatrixINTEL *, size_t i); + template diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index 07cc3c477ed43..ba11632f5d371 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -88,6 +88,20 @@ class wi_element { Group, T, Use, NumRows, NumCols, Layout> &Mat, std::size_t i) : M(Mat), idx(i) {} + + inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { +#if defined(__SYCL_DEVICE_ONLY__) + __ocl_vec_t coord = + __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx); + const uint32_t row = coord[0]; + const uint32_t col = coord[1]; + return std::make_tuple(row, col); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + operator T() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(M.spvm, idx); @@ -171,6 +185,20 @@ class wi_element &Mat, std::size_t i) : M(Mat), idx(i) {} + + inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { +#if defined(__SYCL_DEVICE_ONLY__) + __ocl_vec_t coord = + __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx); + const uint32_t row = coord[0]; + const uint32_t col = coord[1]; + return std::make_tuple(row, col); +#else + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // __SYCL_DEVICE_ONLY__ + } + operator sycl::ext::oneapi::bfloat16() { #ifdef __SYCL_DEVICE_ONLY__ return __spirv_VectorExtractDynamic(M.spvm, idx); diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp new file mode 100644 index 0000000000000..da034be7a8fc7 --- /dev/null +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -0,0 +1,235 @@ +// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out + +// Kernel B sum by col +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_cols_ref(host_accessor B, + host_accessor sum_cols) { + int sum_cols_ref[N] = {0}; + for (size_t j = 0; j < N; j++) { + for (size_t i = 0; i < M; i++) { + sum_cols_ref[j] += B[i][j]; + } + auto diff = sum_cols[j] - sum_cols_ref[j]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +// clang-format off +/* + Here is a demonstration of how matrix B will be divided across + work items for this test case. + < --------------- 128 ----------------------------------> + x x x x x x x x x x x x x x x x .......... x x x x x x ^ + x x x x x x x x x x x x x x x x .......... x x x x x x 16 + x x x x x x x x x x x x x x x x .......... x x x x x x | + ..... | + x x x x x x x x x x x x x x x x .......... x x x x x x | + x x x x x x x x x x x x x x x x .......... x x x x x x v + + + --------------- 64 ----------------> + x x x x x x .......... x x x x x x ^ + x x x x x x .......... x x x x x x 8 + x x x x x x .......... x x x x x x | <-- part of (VNNI-ed) + ..... | original matrix each SG + x x x x x x .......... x x x x x x | holds + x x x x x x .......... x x x x x x v + < WI0 > < WI15 > + + + <-------- 16 -------------> + x x x .......... x x x ^ + x x x .......... x x x | + x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix + ..... | each SG holds + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x 32 + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x | + x x x .......... x x x v + + If we dividie the above matrix across 16 (SG_SZ) work items, + each WI will hold 32 elements. And these 32 elements will be + 8x4 chunks as shown in the VNNI-ed matrix figure. +*/ + +// The total distribution among the WIs in ALL the sub-groups is as follows: +// This is useful to figure out the the global index is to be calculated + +/* +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements +wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64] + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65] + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66] + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67] + + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | .... + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | + ... ... .... | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124] + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125] + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126] + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127] +---------------------------------------------------------------------------------------- --------------------------- +wi [1,0] --> i=0, [8, 0] + i=1, [8, 1] + i=2, [8, 2] + i=3, [8, 2] + ... + i=28, [15, 0] + i=29, [15, 1] + i=30, [15, 2] + i=31, [15, 3] +*/ + +// The following is the distribution among WIs in a SINGLE SG. +/* +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements + +wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | + + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | + ... ... .... | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | + +*/ +// clang-format on + +template +void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { + buffer bufB(B.get_data(), range<2>(M, N)); + // size of vector is known because SG size of set by the user in this case + int sum_cols[N] = {0}; + buffer sum_cols_v(sum_cols, N); // there are total of tK/4 * 2, 16 rows + q.submit([&](handler &cgh) { + auto accB = bufB.get_access(cgh); + + auto v = sum_cols_v.get_access(cgh); + auto os = sycl::stream(100000, 6144, cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + + // TK = 32, TN = 16 + joint_matrix + sub_b; + + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (global_idx * (TK / 4) * N) + + sg_starty / SG_SZ * TN * 4, + N); + + int32_t sum_local_cols[N] = {0}; // 4 local cols, N total + // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row + auto wiData = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + + size_t + global_index; // Index into the result array that holds the sums. + + // Keep track of cols handled in this WI + int32_t handled_cols[N] = {-1}; + + // each WI calculates local sum of cols + for (int i = 0; i < wiData.length(); ++i) { + // get the index of the element in the submatrix + auto dataItem = wiData[i]; + auto [row, col] = dataItem.get_coord(); + + // Calculation of global index + int sg_idx = (int)global_idy / SG_SZ; + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; + sum_local_cols[global_index] += wiData[i]; + handled_cols[global_index] = 1; + } + + for (int j = 0; j < N; j++) { + if (handled_cols[j] == 1) { + global_index = j; + sum_local_cols[global_index] = reduce_over_group( + sg, sum_local_cols[global_index], sycl::plus<>()); + atomic_fetch_add(v[global_index], sum_local_cols[global_index]); + } + } + }); // parallel for + }).wait(); + sum_cols_ref(bufB.get_host_access(), sum_cols_v.get_host_access()); +} + +// TK = 32, TN = 16 +static constexpr size_t MATRIX_K = TK / 4 * 2; // 16 +static constexpr size_t MATRIX_N = TN * 4 * 2; // 128 +int8_t B[MATRIX_K][MATRIX_N]; + +/* < --------------- 128 ----------------------------------> + x x x x x x x x x x x x x x x x .......... x x x x x x ^ + x x x x x x x x x x x x x x x x .......... x x x x x x 16 + x x x x x x x x x x x x x x x x .......... x x x x x x | + ..... | + x x x x x x x x x x x x x x x x .......... x x x x x x | + x x x x x x x x x x x x x x x x .......... x x x x x x v +*/ +int main() { + big_matrix MB((int8_t *)&B); + + size_t NDRangeK = MATRIX_K / (TK / 4); + size_t NDRangeN = (MATRIX_N / 4) / TN; + queue q; + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i; + } + } + + matrix_sum_cols(q, MB, r); + + std::cout << "Passed\n"; + + return 0; +}