Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ggml/src/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@ void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
const uint tid = gl_LocalInvocationID.x;

// There are not enough cols to use all threads
if (tid >= p.ncols) {
return;
}

const uint block_size = min(p.ncols, BLOCK_SIZE);

uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

tmp[tid] = FLOAT_TYPE(0.0f);

[[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
const uint col = i*BLOCK_SIZE + 2*tid;
[[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
const uint col = i*block_size + 2*tid;
const uint ib = (row*p.ncols + col)/QUANT_K; // block index
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
Expand All @@ -38,7 +45,7 @@ void main() {

// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
[[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
Expand Down
7 changes: 4 additions & 3 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2271,9 +2271,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
if (ggml_blck_size(type_a) != 256) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
}
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
}
}

Expand Down