From a908c37ce97f9d78b0925cea1180c79845e94b47 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 10 Apr 2023 09:49:40 +0200 Subject: [PATCH 01/19] Allow use of OpenCL GPU-based BLAS using ClBlast instead of OpenBLAS for context processing --- Makefile | 4 ++ ggml.c | 132 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 128 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 8fbb19c46cc10..34824d6907277 100644 --- a/Makefile +++ b/Makefile @@ -113,6 +113,10 @@ ifdef LLAMA_CUBLAS ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif +ifdef LLAMA_CLBLAST + CFLAGS += -DGGML_USE_CLBLAST + LDFLAGS += -lclblast -lOpenCL +endif ifdef LLAMA_GPROF CFLAGS += -pg CXXFLAGS += -pg diff --git a/ggml.c b/ggml.c index 6e46c0e5ad1da..bb618ea43e0d3 100644 --- a/ggml.c +++ b/ggml.c @@ -143,6 +143,22 @@ inline static void* ggml_aligned_malloc(size_t size) { } \ } while (0) +#if GGML_USE_CLBLAST +#ifndef GGML_USE_OPENBLAS +#define GGML_USE_OPENBLAS +#endif + +#define CL_TARGET_OPENCL_VERSION 110 +#include + +cl_platform_id platform; +cl_device_id device; +cl_context context; +cl_command_queue queue; +cl_event event; +bool cl_initialized = false; +#endif + #if defined(GGML_USE_ACCELERATE) #include #elif defined(GGML_USE_OPENBLAS) @@ -7422,7 +7438,7 @@ static void ggml_compute_forward_rms_norm( // ggml_compute_forward_mul_mat -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) // helper function to determine if it is better to use BLAS or not // for large matrices, BLAS is faster static bool ggml_compute_forward_mul_mat_use_blas( @@ -7447,6 +7463,85 @@ static bool ggml_compute_forward_mul_mat_use_blas( return false; } + +#ifdef GGML_USE_CLBLAST +static bool ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { + cl_int err = 0; + + if (!cl_initialized) { + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[0]; + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[0]; + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL context: %d\n", err); + fflush(stdout); + } + queue = clCreateCommandQueue(context, device, 0, &err); + event = NULL; + + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Command Queue: %d\n", err); + fflush(stdout); + } + + free(platforms); + free(devices); + cl_initialized = true; + } + + // Prepare buffers + cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_WRITE, m*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer A: %d\n", err); + fflush(stdout); + } + cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer B: %d\n", err); + fflush(stdout); + } + cl_mem cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer C: %d\n", err); + fflush(stdout); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, m*k*sizeof(float), host_a, 0, NULL, NULL); + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, n*k*sizeof(float), host_b, 0, NULL, NULL); + clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); + + // Call the SGEMM routine. + CLBlastStatusCode status = CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, &event); + + // Wait for completion + if (status == CLBlastSuccess) { + clWaitForEvents(1, &event); + clReleaseEvent(event); + } + + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); + + clReleaseMemObject(cl_buffer_a); + clReleaseMemObject(cl_buffer_b); + clReleaseMemObject(cl_buffer_c); +} +#endif #endif static void ggml_compute_forward_mul_mat_f32( @@ -7462,7 +7557,7 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int64_t ne10 = src1->ne[0]; #endif const int64_t ne11 = src1->ne[1]; @@ -7519,7 +7614,7 @@ static void ggml_compute_forward_mul_mat_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -7570,6 +7665,13 @@ static void ggml_compute_forward_mul_mat_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #else // zT = y * xT +#ifdef GGML_USE_CLBLAST + ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); +#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -7713,7 +7815,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { GGML_ASSERT(nb10 == sizeof(float)); @@ -7797,6 +7899,13 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT +#ifdef GGML_USE_CLBLAST + ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); +#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -7963,7 +8072,7 @@ static void ggml_compute_forward_mul_mat_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -8053,6 +8162,13 @@ static void ggml_compute_forward_mul_mat_q_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #else // zT = y * xT +#ifdef GGML_USE_CLBLAST + ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); +#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -10885,7 +11001,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning @@ -10902,7 +11018,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); @@ -12303,7 +12419,7 @@ int ggml_cpu_has_wasm_simd(void) { } int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) return 1; #else return 0; From b7143c1a2ed104ef2a89ad6183e15e4af1e995c5 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 11 Apr 2023 21:53:50 +0200 Subject: [PATCH 02/19] Improve ClBlast implementation, avoid recreating buffers, remove redundant transfers --- ggml_blas_adapter.c | 153 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 ggml_blas_adapter.c diff --git a/ggml_blas_adapter.c b/ggml_blas_adapter.c new file mode 100644 index 0000000000000..50faf3807bf02 --- /dev/null +++ b/ggml_blas_adapter.c @@ -0,0 +1,153 @@ +//this is a drop-in for all CLBlast related code, to keep the main ggml.c unmodified +// we will imitate the function definition from OpenBLAS instead, replaced as necessary. + +//windows binaries for clblast obtained from https://github.com/CNugteren/CLBlast (apache license) +//windows binaries for opencl obtained from https://github.com/KhronosGroup/OpenCL-SDK (apache license) + +#if GGML_USE_OPENBLAS +#include +#include +#include + +#if GGML_USE_CLBLAST + +#define CL_TARGET_OPENCL_VERSION 110 +#include + +cl_platform_id platform; +cl_device_id device; +cl_context context; +cl_command_queue queue; +bool cl_initialized = false; +size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; +cl_mem cl_buffer_a, cl_buffer_b, cl_buffer_c; + +static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { + cl_int err = 0; + + cl_event events[2]; + events[0] = NULL; + events[1] = NULL; + + if (!cl_initialized) { + char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); + char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); + int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); + int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); + printf("\nInitializing CLBlast (First Run)..."); + printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[plat_num]; + char platform_buffer[1024]; + clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[dev_num]; + char device_buffer[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); + printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL context: %d\n", err); + fflush(stdout); + } + queue = clCreateCommandQueue(context, device, 0, &err); + + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Command Queue: %d\n", err); + fflush(stdout); + } + + free(platforms); + free(devices); + + cl_size_a = m * k * sizeof(float); + cl_size_b = n * k * sizeof(float); + cl_size_c = m * n * sizeof(float); + + // Prepare buffers + cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer A: %d\n", err); + fflush(stdout); + } + cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer B: %d\n", err); + fflush(stdout); + } + cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer C: %d\n", err); + fflush(stdout); + } + + cl_initialized = true; + } + + // Resize buffers if too small + if (m * k * sizeof(float) > cl_size_a) { + clReleaseMemObject(cl_buffer_a); + cl_size_a = m * k * sizeof(float); + cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, NULL); + } + if (n * k * sizeof(float) > cl_size_b) { + clReleaseMemObject(cl_buffer_b); + cl_size_b = n * k * sizeof(float); + cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, NULL); + } + if (m * n * sizeof(float) > cl_size_c) { + clReleaseMemObject(cl_buffer_c); + cl_size_c = m * n * sizeof(float); + cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, NULL); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, cl_size_a, host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, cl_size_b, host_b, 0, NULL, events + 1); + // buffer c is not required for this use case + // clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, cl_size_c, host_c, 0, NULL, NULL); + + clWaitForEvents(2, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + + // Call the SGEMM routine. + CLBlastStatusCode status = CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, events); + + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); + + // Wait for completion + if (status == CLBlastSuccess) { + clWaitForEvents(2, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + } +} + +#endif +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +#if GGML_USE_CLBLAST +#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ +ggml_cl_sgemm_wrapper(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ +}) +#else +#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ +cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ +}) +#endif +#endif From 6f668707266e498a551fa2f5655ee2fb29a35434 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 15 Apr 2023 12:03:11 +0200 Subject: [PATCH 03/19] Finish merge of ClBlast support --- ggml.c | 319 +++++++++++++++++++++++++++------------- ggml_blas_adapter.c | 153 ------------------- ggml_clblast_dequant.cl | 46 ++++++ 3 files changed, 265 insertions(+), 253 deletions(-) delete mode 100644 ggml_blas_adapter.c create mode 100644 ggml_clblast_dequant.cl diff --git a/ggml.c b/ggml.c index bb618ea43e0d3..ad4526298cccf 100644 --- a/ggml.c +++ b/ggml.c @@ -143,28 +143,213 @@ inline static void* ggml_aligned_malloc(size_t size) { } \ } while (0) -#if GGML_USE_CLBLAST -#ifndef GGML_USE_OPENBLAS -#define GGML_USE_OPENBLAS -#endif - +#if defined(GGML_USE_ACCELERATE) +#include +#elif defined(GGML_USE_OPENBLAS) +#include +#elif defined(GGML_USE_CUBLAS) +#include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) #define CL_TARGET_OPENCL_VERSION 110 #include +#include +#include cl_platform_id platform; cl_device_id device; cl_context context; cl_command_queue queue; -cl_event event; +cl_program program; +cl_kernel kernel_q4_0, kernel_q4_1; bool cl_initialized = false; -#endif +size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; -#if defined(GGML_USE_ACCELERATE) -#include -#elif defined(GGML_USE_OPENBLAS) -#include -#elif defined(GGML_USE_CUBLAS) -#include "ggml-cuda.h" +cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { + cl_program program; + char *program_log; + size_t program_size, log_size; + int err; + + program_size = strlen(program_buffer); + + program = clCreateProgramWithSource(ctx, 1, + (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + perror("OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL); + if(err < 0) { + + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + log_size + 1, program_log, NULL); + printf("%s\n", program_log); + free(program_log); + exit(1); + } + + return program; +} + +static void ggml_cl_init() { + cl_int err = 0; + char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); + char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); + int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); + int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); + printf("\nInitializing CLBlast (First Run)..."); + printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[plat_num]; + char platform_buffer[1024]; + clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[dev_num]; + char device_buffer[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); + printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL context: %d\n", err); + fflush(stdout); + } + queue = clCreateCommandQueue(context, device, 0, &err); + + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Command Queue: %d\n", err); + fflush(stdout); + } + + free(platforms); + free(devices); + + program = build_program_from_source(context, device, clblast_dequant); + + // Prepare dequantize kernels + kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); + fflush(stdout); + }; + kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); + fflush(stdout); + }; + cl_initialized = true; +} + +static void ggml_cl_sgemm_wrapper(const CBLAS_ORDER order, const CBLAS_TRANSPOSE trans_a, const CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { + cl_int err = 0; + + cl_event events[4]; + events[0] = NULL; + events[1] = NULL; + events[2] = NULL; + events[3] = NULL; + + if (!cl_initialized) { + ggml_cl_init(); + } + + bool dequant = (btype == 2 || btype == 3); + cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; + + size_t global = n * k, local = 16, qb_size; + cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; + + // Prepare buffers + cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, m*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer A: %d\n", err); + fflush(stdout); + } + if (dequant) { + qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32; + cl_buffer_qb = clCreateBuffer(context, CL_MEM_READ_ONLY, qb_size, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer QB: %d\n", err); + fflush(stdout); + } + } + cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer B: %d\n", err); + fflush(stdout); + } + cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer C: %d\n", err); + fflush(stdout); + } + + if (dequant) { + err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); + err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); + if(err < 0) { + printf("Error setting OpenCL kernel args: %d\n", err); + fflush(stdout); + } + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); + } else { + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); + if (dequant) { + err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); + if(err < 0) { + printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); + fflush(stdout); + } + } + clWaitForEvents(dequant ? 4 : 3, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + clReleaseEvent(events[2]); + if (dequant) { + clReleaseEvent(events[3]); + } + + // Call the SGEMM routine. + CLBlastStatusCode status = CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, events); + + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); + + // Wait for completion + if (status == CLBlastSuccess) { + clWaitForEvents(2, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + } + + clReleaseMemObject(cl_buffer_a); + if (dequant) { + clReleaseMemObject(cl_buffer_qb); + } + clReleaseMemObject(cl_buffer_b); + clReleaseMemObject(cl_buffer_c); +} #endif #undef MIN @@ -3705,6 +3890,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize cuBLAS #if defined(GGML_USE_CUBLAS) ggml_init_cublas(); + #elif defined(GGML_USE_CLBLAST) + ggml_cl_init(); #endif is_first_call = false; @@ -7464,84 +7651,6 @@ static bool ggml_compute_forward_mul_mat_use_blas( return false; } -#ifdef GGML_USE_CLBLAST -static bool ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { - cl_int err = 0; - - if (!cl_initialized) { - cl_uint num_platforms; - clGetPlatformIDs(0, NULL, &num_platforms); - cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); - clGetPlatformIDs(num_platforms, platforms, NULL); - platform = platforms[0]; - cl_uint num_devices; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); - cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); - device = devices[0]; - context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } - queue = clCreateCommandQueue(context, device, 0, &err); - event = NULL; - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } - - free(platforms); - free(devices); - cl_initialized = true; - } - - // Prepare buffers - cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_WRITE, m*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer A: %d\n", err); - fflush(stdout); - } - cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer B: %d\n", err); - fflush(stdout); - } - cl_mem cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer C: %d\n", err); - fflush(stdout); - } - - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, m*k*sizeof(float), host_a, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, n*k*sizeof(float), host_b, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); - - // Call the SGEMM routine. - CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, &event); - - // Wait for completion - if (status == CLBlastSuccess) { - clWaitForEvents(1, &event); - clReleaseEvent(event); - } - - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); - - clReleaseMemObject(cl_buffer_a); - clReleaseMemObject(cl_buffer_b); - clReleaseMemObject(cl_buffer_c); -} -#endif #endif static void ggml_compute_forward_mul_mat_f32( @@ -7663,14 +7772,14 @@ static void ggml_compute_forward_mul_mat_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT -#ifdef GGML_USE_CLBLAST ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, - 0.0f, d, ne01); + 0.0f, d, ne01, + params->type); #else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, @@ -7892,20 +8001,26 @@ static void ggml_compute_forward_mul_mat_f16_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT -#ifdef GGML_USE_CLBLAST ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, - 0.0f, d, ne01); + 0.0f, d, ne01, + params->type); #else + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -8135,6 +8250,7 @@ static void ggml_compute_forward_mul_mat_q_f32( dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); CUDA_CHECK(cudaGetLastError()); #else +#ifndef GGML_USE_CLBLAST { size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { @@ -8143,6 +8259,9 @@ static void ggml_compute_forward_mul_mat_q_f32( } } const float * x = wdata; +#else + const void* x = src0->data + i03*nb03 + i02*nb02; +#endif #endif @@ -8160,14 +8279,14 @@ static void ggml_compute_forward_mul_mat_q_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT -#ifdef GGML_USE_CLBLAST ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, - 0.0f, d, ne01); + 0.0f, d, ne01, + type); #else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, diff --git a/ggml_blas_adapter.c b/ggml_blas_adapter.c deleted file mode 100644 index 50faf3807bf02..0000000000000 --- a/ggml_blas_adapter.c +++ /dev/null @@ -1,153 +0,0 @@ -//this is a drop-in for all CLBlast related code, to keep the main ggml.c unmodified -// we will imitate the function definition from OpenBLAS instead, replaced as necessary. - -//windows binaries for clblast obtained from https://github.com/CNugteren/CLBlast (apache license) -//windows binaries for opencl obtained from https://github.com/KhronosGroup/OpenCL-SDK (apache license) - -#if GGML_USE_OPENBLAS -#include -#include -#include - -#if GGML_USE_CLBLAST - -#define CL_TARGET_OPENCL_VERSION 110 -#include - -cl_platform_id platform; -cl_device_id device; -cl_context context; -cl_command_queue queue; -bool cl_initialized = false; -size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; -cl_mem cl_buffer_a, cl_buffer_b, cl_buffer_c; - -static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { - cl_int err = 0; - - cl_event events[2]; - events[0] = NULL; - events[1] = NULL; - - if (!cl_initialized) { - char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); - char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); - int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); - int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); - printf("\nInitializing CLBlast (First Run)..."); - printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); - cl_uint num_platforms; - clGetPlatformIDs(0, NULL, &num_platforms); - cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); - clGetPlatformIDs(num_platforms, platforms, NULL); - platform = platforms[plat_num]; - char platform_buffer[1024]; - clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); - cl_uint num_devices; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); - cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); - device = devices[dev_num]; - char device_buffer[1024]; - clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); - printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); - context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } - queue = clCreateCommandQueue(context, device, 0, &err); - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } - - free(platforms); - free(devices); - - cl_size_a = m * k * sizeof(float); - cl_size_b = n * k * sizeof(float); - cl_size_c = m * n * sizeof(float); - - // Prepare buffers - cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer A: %d\n", err); - fflush(stdout); - } - cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer B: %d\n", err); - fflush(stdout); - } - cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer C: %d\n", err); - fflush(stdout); - } - - cl_initialized = true; - } - - // Resize buffers if too small - if (m * k * sizeof(float) > cl_size_a) { - clReleaseMemObject(cl_buffer_a); - cl_size_a = m * k * sizeof(float); - cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, NULL); - } - if (n * k * sizeof(float) > cl_size_b) { - clReleaseMemObject(cl_buffer_b); - cl_size_b = n * k * sizeof(float); - cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, NULL); - } - if (m * n * sizeof(float) > cl_size_c) { - clReleaseMemObject(cl_buffer_c); - cl_size_c = m * n * sizeof(float); - cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, NULL); - } - - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, cl_size_a, host_a, 0, NULL, events); - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, cl_size_b, host_b, 0, NULL, events + 1); - // buffer c is not required for this use case - // clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, cl_size_c, host_c, 0, NULL, NULL); - - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - - // Call the SGEMM routine. - CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, events); - - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); - - // Wait for completion - if (status == CLBlastSuccess) { - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - } -} - -#endif -#endif - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) -#if GGML_USE_CLBLAST -#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ -ggml_cl_sgemm_wrapper(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ -}) -#else -#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ -cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ -}) -#endif -#endif diff --git a/ggml_clblast_dequant.cl b/ggml_clblast_dequant.cl new file mode 100644 index 0000000000000..47d39f8a303e1 --- /dev/null +++ b/ggml_clblast_dequant.cl @@ -0,0 +1,46 @@ +#define MULTILINE_QUOTE(...) #__VA_ARGS__ +const char * clblast_dequant = MULTILINE_QUOTE( + +struct __attribute__ ((packed)) block_q4_0 +{ + float d; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) { + uint i, l; + i = get_global_id(0) / 32; + l = get_local_id(0); + + float d = blocks[i].d; + + uchar vi = blocks[i].qs[l]; + + uint index = i*32 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct __attribute__ ((packed)) block_q4_1 +{ + float d; + float m; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) { + uint i, l; + i = get_global_id(0) / 32; + l = get_local_id(0); + + float d = blocks[i].d; + float m = blocks[i].m; + + uchar vi = blocks[i].qs[l]; + + uint index = i*32 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +); From 1b16b8c90dc3f555db70df3f72af00d677f84ffe Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 23 Apr 2023 09:59:45 +0200 Subject: [PATCH 04/19] Move CLBlast implementation to separate file Add buffer reuse code (adapted from slaren's cuda implementation) --- Makefile | 11 ++- ggml-opencl.cpp | 239 ++++++++++++++++++++++++++++++++++++++++++++++++ ggml-opencl.h | 31 +++++++ ggml.c | 207 +---------------------------------------- 4 files changed, 281 insertions(+), 207 deletions(-) create mode 100644 ggml-opencl.cpp create mode 100644 ggml-opencl.h diff --git a/Makefile b/Makefile index 34824d6907277..ca945fa990a3c 100644 --- a/Makefile +++ b/Makefile @@ -34,8 +34,8 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CFLAGS = -I. -O3 -DNODEBUG -std=c11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNODEBUG -std=c++11 -fPIC LDFLAGS = # warnings @@ -105,8 +105,8 @@ ifdef LLAMA_OPENBLAS LDFLAGS += -lopenblas endif ifdef LLAMA_CUBLAS - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include - LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 + CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include + LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native @@ -116,6 +116,9 @@ endif ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST LDFLAGS += -lclblast -lOpenCL + OBJS += ggml-opencl.o +ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h + $(CXX) $(CXXFLAGS) -c $< -o $@ endif ifdef LLAMA_GPROF CFLAGS += -pg diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp new file mode 100644 index 0000000000000..fa462ff8e7e2e --- /dev/null +++ b/ggml-opencl.cpp @@ -0,0 +1,239 @@ +#include "ggml-opencl.h" + +#include +#include +#include + +#include + +struct scoped_spin_lock { + std::atomic_flag& lock; + scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { + while (lock.test_and_set(std::memory_order_acquire)) { + ; // spin + } + } + ~scoped_spin_lock() { + lock.clear(std::memory_order_release); + } + scoped_spin_lock(const scoped_spin_lock&) = delete; + scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; +}; + +struct cl_buffer { + cl_mem mem; + size_t size = 0; +}; + +cl_platform_id platform; +cl_device_id device; +cl_context context; +cl_command_queue queue; +cl_program program; +cl_kernel kernel_q4_0, kernel_q4_1; +bool cl_initialized = false; +size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; + +static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS]; +static std::atomic_flag g_cl_pool_lock = ATOMIC_FLAG_INIT; + +cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size) { + scoped_spin_lock lock(g_cl_pool_lock); + + for (int i = 0; i < MAX_CL_BUFFERS; ++i) { + cl_buffer& b = g_cl_buffer_pool[i]; + if (b.size >= size && b.size != 0) { + cl_mem mem = b.mem; + *actual_size = b.size; + b.size = 0; + return mem; + } + } + cl_int err; + cl_mem mem = clCreateBuffer(context, 0, size, NULL, &err); + *actual_size = size; + CL_CHECK(err, "clCreateBuffer"); + return mem; +} + +void ggml_cl_pool_free(cl_mem mem, size_t size) { + scoped_spin_lock lock(g_cl_pool_lock); + + for (int i = 0; i < MAX_CL_BUFFERS; ++i) { + cl_buffer& b = g_cl_buffer_pool[i]; + if (b.size == 0) { + b.mem = mem; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: cl buffer pool full, increase MAX_CL_BUFFERS\n"); + clReleaseMemObject(mem); +} + +cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { + cl_program program; + char *program_log; + size_t program_size, log_size; + int err; + + program_size = strlen(program_buffer); + + program = clCreateProgramWithSource(ctx, 1, + (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + fprintf(stderr, "OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL); + if(err < 0) { + + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + log_size + 1, program_log, NULL); + printf("%s\n", program_log); + free(program_log); + exit(1); + } + + return program; +} + +void ggml_cl_init() { + cl_int err = 0; + char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); + char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); + int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); + int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); + printf("\nInitializing CLBlast (First Run)..."); + printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[plat_num]; + char platform_buffer[1024]; + clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[dev_num]; + char device_buffer[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); + printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL context: %d\n", err); + fflush(stdout); + } + queue = clCreateCommandQueue(context, device, 0, &err); + + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Command Queue: %d\n", err); + fflush(stdout); + } + + free(platforms); + free(devices); + + program = build_program_from_source(context, device, clblast_dequant); + + // Prepare dequantize kernels + kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); + fflush(stdout); + }; + kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); + fflush(stdout); + }; + cl_initialized = true; +} + +void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { + cl_int err = 0; + + cl_event events[4]; + events[0] = NULL; + events[1] = NULL; + events[2] = NULL; + events[3] = NULL; + + bool dequant = (btype == 2 || btype == 3); + cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; + + size_t global = n * k, local = 16, qb_size; + cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; + + size_t buf_size_a, buf_size_qb, buf_size_b, buf_size_c; + + // Prepare buffers + cl_buffer_a = ggml_cl_pool_malloc(m * k * sizeof(float), &buf_size_a); + if (dequant) { + qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32; + cl_buffer_qb = ggml_cl_pool_malloc(qb_size, &buf_size_qb); + } + cl_buffer_b = ggml_cl_pool_malloc(n*k*sizeof(float), &buf_size_b); + cl_buffer_c = ggml_cl_pool_malloc(m*n*sizeof(float), &buf_size_c); + + if (dequant) { + err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); + err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); + if(err < 0) { + printf("Error setting OpenCL kernel args: %d\n", err); + fflush(stdout); + } + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); + } else { + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); + if (dequant) { + err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); + if(err < 0) { + printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); + fflush(stdout); + } + } + clWaitForEvents(dequant ? 4 : 3, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + clReleaseEvent(events[2]); + if (dequant) { + clReleaseEvent(events[3]); + } + + // Call the SGEMM routine. + CLBlastStatusCode status = CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, events); + + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); + + // Wait for completion + clWaitForEvents(2, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + + ggml_cl_pool_free(cl_buffer_a, buf_size_a); + if (dequant) { + ggml_cl_pool_free(cl_buffer_qb, buf_size_qb); + } + ggml_cl_pool_free(cl_buffer_b, buf_size_b); + ggml_cl_pool_free(cl_buffer_c, buf_size_c); +} diff --git a/ggml-opencl.h b/ggml-opencl.h new file mode 100644 index 0000000000000..2486845d84dd6 --- /dev/null +++ b/ggml-opencl.h @@ -0,0 +1,31 @@ +#pragma once + +#define CL_TARGET_OPENCL_VERSION 110 +#include +#define MAX_CL_BUFFERS 16 + +#ifdef __cplusplus +extern "C" { +#endif + +// Buffer reuse code adapted from cuda implementation by slaren +#define CL_CHECK(err, name) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size); +void ggml_cl_pool_free(cl_mem mem, size_t size); + +cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer); +void ggml_cl_init(); + +void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); + +#ifdef __cplusplus +} +#endif diff --git a/ggml.c b/ggml.c index ad4526298cccf..3daff3fe9df73 100644 --- a/ggml.c +++ b/ggml.c @@ -150,206 +150,7 @@ inline static void* ggml_aligned_malloc(size_t size) { #elif defined(GGML_USE_CUBLAS) #include "ggml-cuda.h" #elif defined(GGML_USE_CLBLAST) -#define CL_TARGET_OPENCL_VERSION 110 -#include -#include -#include - -cl_platform_id platform; -cl_device_id device; -cl_context context; -cl_command_queue queue; -cl_program program; -cl_kernel kernel_q4_0, kernel_q4_1; -bool cl_initialized = false; -size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; - -cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { - cl_program program; - char *program_log; - size_t program_size, log_size; - int err; - - program_size = strlen(program_buffer); - - program = clCreateProgramWithSource(ctx, 1, - (const char**)&program_buffer, &program_size, &err); - if(err < 0) { - perror("OpenCL error creating program"); - exit(1); - } - - err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL); - if(err < 0) { - - clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, - 0, NULL, &log_size); - program_log = (char*) malloc(log_size + 1); - program_log[log_size] = '\0'; - clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, - log_size + 1, program_log, NULL); - printf("%s\n", program_log); - free(program_log); - exit(1); - } - - return program; -} - -static void ggml_cl_init() { - cl_int err = 0; - char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); - char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); - int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); - int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); - printf("\nInitializing CLBlast (First Run)..."); - printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); - cl_uint num_platforms; - clGetPlatformIDs(0, NULL, &num_platforms); - cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); - clGetPlatformIDs(num_platforms, platforms, NULL); - platform = platforms[plat_num]; - char platform_buffer[1024]; - clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); - cl_uint num_devices; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); - cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); - device = devices[dev_num]; - char device_buffer[1024]; - clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); - printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); - context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } - queue = clCreateCommandQueue(context, device, 0, &err); - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } - - free(platforms); - free(devices); - - program = build_program_from_source(context, device, clblast_dequant); - - // Prepare dequantize kernels - kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); - fflush(stdout); - }; - kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); - fflush(stdout); - }; - cl_initialized = true; -} - -static void ggml_cl_sgemm_wrapper(const CBLAS_ORDER order, const CBLAS_TRANSPOSE trans_a, const CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { - cl_int err = 0; - - cl_event events[4]; - events[0] = NULL; - events[1] = NULL; - events[2] = NULL; - events[3] = NULL; - - if (!cl_initialized) { - ggml_cl_init(); - } - - bool dequant = (btype == 2 || btype == 3); - cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; - - size_t global = n * k, local = 16, qb_size; - cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; - - // Prepare buffers - cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, m*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer A: %d\n", err); - fflush(stdout); - } - if (dequant) { - qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32; - cl_buffer_qb = clCreateBuffer(context, CL_MEM_READ_ONLY, qb_size, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer QB: %d\n", err); - fflush(stdout); - } - } - cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer B: %d\n", err); - fflush(stdout); - } - cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer C: %d\n", err); - fflush(stdout); - } - - if (dequant) { - err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); - err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); - if(err < 0) { - printf("Error setting OpenCL kernel args: %d\n", err); - fflush(stdout); - } - clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); - } else { - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); - } - - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); - clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); - if (dequant) { - err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); - if(err < 0) { - printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); - fflush(stdout); - } - } - clWaitForEvents(dequant ? 4 : 3, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - clReleaseEvent(events[2]); - if (dequant) { - clReleaseEvent(events[3]); - } - - // Call the SGEMM routine. - CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, events); - - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); - - // Wait for completion - if (status == CLBlastSuccess) { - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - } - - clReleaseMemObject(cl_buffer_a); - if (dequant) { - clReleaseMemObject(cl_buffer_qb); - } - clReleaseMemObject(cl_buffer_b); - clReleaseMemObject(cl_buffer_c); -} +#include "ggml-opencl.h" #endif #undef MIN @@ -7774,7 +7575,7 @@ static void ggml_compute_forward_mul_mat_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #elif defined(GGML_USE_CLBLAST) // zT = y * xT - ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, @@ -8008,7 +7809,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT - ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, @@ -8281,7 +8082,7 @@ static void ggml_compute_forward_mul_mat_q_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #elif defined(GGML_USE_CLBLAST) // zT = y * xT - ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, From 309af7fce92c071cd6ef872cb8ec1e802fee641f Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 07:16:43 +0200 Subject: [PATCH 05/19] Add q4_2 and q4_3 CLBlast support, improve code --- ggml-opencl.cpp | 73 ++++++++++++++++++++++------------------- ggml_clblast_dequant.cl | 68 +++++++++++++++++++++++++++++--------- 2 files changed, 93 insertions(+), 48 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index fa462ff8e7e2e..904e32ac9f9b1 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -30,8 +30,7 @@ cl_device_id device; cl_context context; cl_command_queue queue; cl_program program; -cl_kernel kernel_q4_0, kernel_q4_1; -bool cl_initialized = false; +cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS]; @@ -127,16 +126,9 @@ void ggml_cl_init() { clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clCreateContext"); queue = clCreateCommandQueue(context, device, 0, &err); - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clCreateCommandQueue"); free(platforms); free(devices); @@ -145,16 +137,13 @@ void ggml_cl_init() { // Prepare dequantize kernels kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); - fflush(stdout); - }; + CL_CHECK(err, "clCreateKernel"); kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); - fflush(stdout); - }; - cl_initialized = true; + CL_CHECK(err, "clCreateKernel"); + kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err); + CL_CHECK(err, "clCreateKernel"); + kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err); + CL_CHECK(err, "clCreateKernel"); } void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { @@ -166,10 +155,36 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra events[2] = NULL; events[3] = NULL; - bool dequant = (btype == 2 || btype == 3); - cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; + cl_kernel kernel; + size_t global, local, qb_size; + bool dequant = btype >= 2 && btype < 6; + if (dequant) { + global = n * k; + + switch (btype) { + case 2: + kernel = kernel_q4_0; + local = 16; + qb_size = global * (sizeof(float) + local) / 32; + break; + case 3: + kernel = kernel_q4_1; + local = 16; + qb_size = global * (sizeof(float) * 2 + local) / 32; + break; + case 4: + kernel = kernel_q4_2; + local = 8; + qb_size = global * (sizeof(short) + local) / 16; + break; + case 5: + kernel = kernel_q4_3; + local = 8; + qb_size = global * (sizeof(short) * 2 + local) / 16; + break; + } + } - size_t global = n * k, local = 16, qb_size; cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; size_t buf_size_a, buf_size_qb, buf_size_b, buf_size_c; @@ -177,7 +192,6 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra // Prepare buffers cl_buffer_a = ggml_cl_pool_malloc(m * k * sizeof(float), &buf_size_a); if (dequant) { - qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32; cl_buffer_qb = ggml_cl_pool_malloc(qb_size, &buf_size_qb); } cl_buffer_b = ggml_cl_pool_malloc(n*k*sizeof(float), &buf_size_b); @@ -186,23 +200,16 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra if (dequant) { err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); - if(err < 0) { - printf("Error setting OpenCL kernel args: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clSetKernelArg"); clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); } else { clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); } clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); - clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); if (dequant) { err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); - if(err < 0) { - printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clEnqueueNDRangeKernel"); } clWaitForEvents(dequant ? 4 : 3, events); clReleaseEvent(events[0]); diff --git a/ggml_clblast_dequant.cl b/ggml_clblast_dequant.cl index 47d39f8a303e1..99474fdb3aa33 100644 --- a/ggml_clblast_dequant.cl +++ b/ggml_clblast_dequant.cl @@ -1,27 +1,26 @@ #define MULTILINE_QUOTE(...) #__VA_ARGS__ const char * clblast_dequant = MULTILINE_QUOTE( -struct __attribute__ ((packed)) block_q4_0 +struct block_q4_0 { float d; uchar qs[16]; }; __kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) { - uint i, l; - i = get_global_id(0) / 32; - l = get_local_id(0); + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); - float d = blocks[i].d; + const float d = blocks[i].d; - uchar vi = blocks[i].qs[l]; + const uchar vi = blocks[i].qs[l]; - uint index = i*32 + l*2; + const uint index = i*32 + l*2; result[index + 0] = ((vi & 0xf) - 8)*d; result[index + 1] = ((vi >> 4) - 8)*d; } -struct __attribute__ ((packed)) block_q4_1 +struct block_q4_1 { float d; float m; @@ -29,16 +28,55 @@ struct __attribute__ ((packed)) block_q4_1 }; __kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) { - uint i, l; - i = get_global_id(0) / 32; - l = get_local_id(0); + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); - float d = blocks[i].d; - float m = blocks[i].m; + const float d = blocks[i].d; + const float m = blocks[i].m; - uchar vi = blocks[i].qs[l]; + const uchar vi = blocks[i].qs[l]; - uint index = i*32 + l*2; + const uint index = i*32 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +struct block_q4_2 +{ + ushort d; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (const half*) &blocks[i].d);; + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct block_q4_3 +{ + ushort d; + ushort m; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (const half*) &(blocks[i].d)); + const float m = vload_half(0, (const half*) &(blocks[i].m)); + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; result[index + 0] = (vi & 0xf) * d + m; result[index + 1] = (vi >> 4) * d + m; } From f469d9afa0aa507391e882d82ce9e18257eff26e Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 15:15:23 +0200 Subject: [PATCH 06/19] Double CLBlast speed by disabling OpenBLAS thread workaround Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com> Co-authored-by: slaren <2141330+slaren@users.noreply.github.com> --- ggml.c | 12 ++++++++++++ ggml.h | 3 ++- llama.cpp | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 3daff3fe9df73..cd2ec11bab4b2 100644 --- a/ggml.c +++ b/ggml.c @@ -12354,6 +12354,18 @@ int ggml_cpu_has_cublas(void) { #endif } +int ggml_cpu_has_clblast(void) { +#if defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_gpublas(void) { + return ggml_cpu_has_cublas() || ggml_cpu_has_clblast(); +} + int ggml_cpu_has_sse3(void) { #if defined(__SSE3__) return 1; diff --git a/ggml.h b/ggml.h index 2758907818182..0e8b1ba5e2f1a 100644 --- a/ggml.h +++ b/ggml.h @@ -852,10 +852,11 @@ extern "C" { GGML_API int ggml_cpu_has_wasm_simd (void); GGML_API int ggml_cpu_has_blas (void); GGML_API int ggml_cpu_has_cublas (void); + GGML_API int ggml_cpu_has_clblast (void); + GGML_API int ggml_cpu_has_gpublas (void); GGML_API int ggml_cpu_has_sse3 (void); GGML_API int ggml_cpu_has_vsx (void); - // // Internal types and functions exposed for tests and benchmarks // diff --git a/llama.cpp b/llama.cpp index 28d27916a049d..5cb7068e9ef7f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1076,7 +1076,7 @@ static bool llama_eval_internal( // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance ggml_cgraph gf = {}; - gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads; + gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); From 8603c25e3ce074d4717584a2c2ba60622e67b287 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 15:53:48 +0200 Subject: [PATCH 07/19] Fix device selection env variable names --- ggml-opencl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 904e32ac9f9b1..e7f40bbefa18c 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -104,10 +104,10 @@ cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const cha void ggml_cl_init() { cl_int err = 0; - char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); - char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); - int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); - int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); + char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM"); + char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE"); + int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM)); + int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE)); printf("\nInitializing CLBlast (First Run)..."); printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); cl_uint num_platforms; From 18cc05bde463ceb3c29d160c32ac88ae596e0c89 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 16:13:43 +0200 Subject: [PATCH 08/19] Fix cast in opencl kernels --- ggml_clblast_dequant.cl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml_clblast_dequant.cl b/ggml_clblast_dequant.cl index 99474fdb3aa33..191b2e57500ad 100644 --- a/ggml_clblast_dequant.cl +++ b/ggml_clblast_dequant.cl @@ -51,7 +51,7 @@ __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global f const uint i = get_global_id(0) / 16; const uint l = get_local_id(0); - const float d = vload_half(0, (const half*) &blocks[i].d);; + const float d = vload_half(0, (__global half*) &blocks[i].d);; const uchar vi = blocks[i].qs[l]; @@ -71,8 +71,8 @@ __kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global f const uint i = get_global_id(0) / 16; const uint l = get_local_id(0); - const float d = vload_half(0, (const half*) &(blocks[i].d)); - const float m = vload_half(0, (const half*) &(blocks[i].m)); + const float d = vload_half(0, (__global half*) &(blocks[i].d)); + const float m = vload_half(0, (__global half*) &(blocks[i].m)); const uchar vi = blocks[i].qs[l]; From ae73887fb9543cee671883b8466b119a77bb27cc Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 21:22:41 +0200 Subject: [PATCH 09/19] Add CLBlast to CMakeLists.txt --- CMakeLists.txt | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 11ebe9eb66fae..79dfd070b9097 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ endif() option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -168,6 +169,21 @@ if (LLAMA_CUBLAS) endif() endif() +if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() +endif() + if (LLAMA_ALL_WARNINGS) if (NOT MSVC) set(c_flags @@ -307,7 +323,8 @@ endif() add_library(ggml OBJECT ggml.c ggml.h - ${GGML_CUDA_SOURCES}) + ${GGML_CUDA_SOURCES} + ${GGML_OPENCL_SOURCES}) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump From daa5df51f773225e92b0b28ca5f0788c7482fb87 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 22:08:51 +0200 Subject: [PATCH 10/19] Replace buffer pool with static buffers a, b, qb, c Fix compile warnings --- ggml-opencl.cpp | 138 ++++++++++++++++-------------------------------- ggml-opencl.h | 2 +- ggml.c | 8 ++- 3 files changed, 50 insertions(+), 98 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index e7f40bbefa18c..3ae56fd287dff 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -4,26 +4,9 @@ #include #include -#include - -struct scoped_spin_lock { - std::atomic_flag& lock; - scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { - while (lock.test_and_set(std::memory_order_acquire)) { - ; // spin - } - } - ~scoped_spin_lock() { - lock.clear(std::memory_order_release); - } - scoped_spin_lock(const scoped_spin_lock&) = delete; - scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; -}; +#include "ggml.h" -struct cl_buffer { - cl_mem mem; - size_t size = 0; -}; +#include cl_platform_id platform; cl_device_id device; @@ -31,44 +14,8 @@ cl_context context; cl_command_queue queue; cl_program program; cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; -size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; - -static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS]; -static std::atomic_flag g_cl_pool_lock = ATOMIC_FLAG_INIT; - -cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size) { - scoped_spin_lock lock(g_cl_pool_lock); - - for (int i = 0; i < MAX_CL_BUFFERS; ++i) { - cl_buffer& b = g_cl_buffer_pool[i]; - if (b.size >= size && b.size != 0) { - cl_mem mem = b.mem; - *actual_size = b.size; - b.size = 0; - return mem; - } - } - cl_int err; - cl_mem mem = clCreateBuffer(context, 0, size, NULL, &err); - *actual_size = size; - CL_CHECK(err, "clCreateBuffer"); - return mem; -} - -void ggml_cl_pool_free(cl_mem mem, size_t size) { - scoped_spin_lock lock(g_cl_pool_lock); - - for (int i = 0; i < MAX_CL_BUFFERS; ++i) { - cl_buffer& b = g_cl_buffer_pool[i]; - if (b.size == 0) { - b.mem = mem; - b.size = size; - return; - } - } - fprintf(stderr, "WARNING: cl buffer pool full, increase MAX_CL_BUFFERS\n"); - clReleaseMemObject(mem); -} +cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; +size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { cl_program program; @@ -102,7 +49,7 @@ cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const cha return program; } -void ggml_cl_init() { +void ggml_cl_init(void) { cl_int err = 0; char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM"); char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE"); @@ -146,6 +93,21 @@ void ggml_cl_init() { CL_CHECK(err, "clCreateKernel"); } +void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { + if (req_size <= *cur_size) { + return; + } + + // Reallocate buffer with enough space + if (*cur_size > 0) { + clReleaseMemObject(*buf); + } + cl_int err; + *buf = clCreateBuffer(context, flags, req_size, NULL, &err); + *cur_size = req_size; + CL_CHECK(err, "clCreateBuffer"); +} + void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { cl_int err = 0; @@ -156,8 +118,8 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra events[3] = NULL; cl_kernel kernel; - size_t global, local, qb_size; - bool dequant = btype >= 2 && btype < 6; + size_t global, local, size_qb; + const bool dequant = btype >= 2 && btype < 6; if (dequant) { global = n * k; @@ -165,48 +127,48 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra case 2: kernel = kernel_q4_0; local = 16; - qb_size = global * (sizeof(float) + local) / 32; + size_qb = global * (sizeof(float) + local) / 32; break; case 3: kernel = kernel_q4_1; local = 16; - qb_size = global * (sizeof(float) * 2 + local) / 32; + size_qb = global * (sizeof(float) * 2 + local) / 32; break; case 4: kernel = kernel_q4_2; local = 8; - qb_size = global * (sizeof(short) + local) / 16; + size_qb = global * (sizeof(short) + local) / 16; break; case 5: kernel = kernel_q4_3; local = 8; - qb_size = global * (sizeof(short) * 2 + local) / 16; + size_qb = global * (sizeof(short) * 2 + local) / 16; break; } } - cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; - - size_t buf_size_a, buf_size_qb, buf_size_b, buf_size_c; + const size_t size_a = m * k * sizeof(float); + const size_t size_b = n * k * sizeof(float); + const size_t size_c = m * n * sizeof(float); // Prepare buffers - cl_buffer_a = ggml_cl_pool_malloc(m * k * sizeof(float), &buf_size_a); + ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a); if (dequant) { - cl_buffer_qb = ggml_cl_pool_malloc(qb_size, &buf_size_qb); + ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb); } - cl_buffer_b = ggml_cl_pool_malloc(n*k*sizeof(float), &buf_size_b); - cl_buffer_c = ggml_cl_pool_malloc(m*n*sizeof(float), &buf_size_c); + ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b); + ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c); if (dequant) { err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); CL_CHECK(err, "clSetKernelArg"); - clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, events + 1); } else { - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, events + 1); } - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, events); if (dequant) { err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); CL_CHECK(err, "clEnqueueNDRangeKernel"); @@ -219,28 +181,20 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra clReleaseEvent(events[3]); } - // Call the SGEMM routine. - CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, events); + CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, events); - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, events, events + 1); // Wait for completion clWaitForEvents(2, events); clReleaseEvent(events[0]); clReleaseEvent(events[1]); - - ggml_cl_pool_free(cl_buffer_a, buf_size_a); - if (dequant) { - ggml_cl_pool_free(cl_buffer_qb, buf_size_qb); - } - ggml_cl_pool_free(cl_buffer_b, buf_size_b); - ggml_cl_pool_free(cl_buffer_c, buf_size_c); } diff --git a/ggml-opencl.h b/ggml-opencl.h index 2486845d84dd6..9d7f911737baf 100644 --- a/ggml-opencl.h +++ b/ggml-opencl.h @@ -22,7 +22,7 @@ cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size); void ggml_cl_pool_free(cl_mem mem, size_t size); cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer); -void ggml_cl_init(); +void ggml_cl_init(void); void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); diff --git a/ggml.c b/ggml.c index cd2ec11bab4b2..fc06acf1e7f3a 100644 --- a/ggml.c +++ b/ggml.c @@ -8031,7 +8031,7 @@ static void ggml_compute_forward_mul_mat_q_f32( else { GGML_ASSERT(false); } -#else +#elif !defined(GGML_USE_CLBLAST) float * const wdata = params->wdata; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; #endif @@ -8050,8 +8050,9 @@ static void ggml_compute_forward_mul_mat_q_f32( dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); CUDA_CHECK(cudaGetLastError()); +#elif defined(GGML_USE_CLBLAST) + const void* x = (char *) src0->data + i03*nb03 + i02*nb02; #else -#ifndef GGML_USE_CLBLAST { size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { @@ -8060,9 +8061,6 @@ static void ggml_compute_forward_mul_mat_q_f32( } } const float * x = wdata; -#else - const void* x = src0->data + i03*nb03 + i02*nb02; -#endif #endif From 36bfb3c1585abc9f57fcbaf28ba60e5bd5180482 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 25 Apr 2023 18:43:31 +0200 Subject: [PATCH 11/19] Fix typos, use GGML_TYPE defines, improve code --- Makefile | 4 ++-- ggml-opencl.cpp | 14 +++++--------- ggml.c | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index ca945fa990a3c..9b9c8ee7a877f 100644 --- a/Makefile +++ b/Makefile @@ -34,8 +34,8 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNODEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNODEBUG -std=c++11 -fPIC +CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = # warnings diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 3ae56fd287dff..8a87edb881f95 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -111,11 +111,7 @@ void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_me void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { cl_int err = 0; - cl_event events[4]; - events[0] = NULL; - events[1] = NULL; - events[2] = NULL; - events[3] = NULL; + cl_event events[4] = { NULL }; cl_kernel kernel; size_t global, local, size_qb; @@ -124,22 +120,22 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra global = n * k; switch (btype) { - case 2: + case GGML_TYPE_Q4_0: kernel = kernel_q4_0; local = 16; size_qb = global * (sizeof(float) + local) / 32; break; - case 3: + case GGML_TYPE_Q4_1: kernel = kernel_q4_1; local = 16; size_qb = global * (sizeof(float) * 2 + local) / 32; break; - case 4: + case GGML_TYPE_Q4_2: kernel = kernel_q4_2; local = 8; size_qb = global * (sizeof(short) + local) / 16; break; - case 5: + case GGML_TYPE_Q4_3: kernel = kernel_q4_3; local = 8; size_qb = global * (sizeof(short) * 2 + local) / 16; diff --git a/ggml.c b/ggml.c index fc06acf1e7f3a..4bc86110cb16c 100644 --- a/ggml.c +++ b/ggml.c @@ -7580,7 +7580,7 @@ static void ggml_compute_forward_mul_mat_f32( 1.0f, y, ne10, x, ne10, 0.0f, d, ne01, - params->type); + GGML_TYPE_F32); #else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, @@ -7814,7 +7814,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( 1.0f, y, ne10, x, ne10, 0.0f, d, ne01, - params->type); + GGML_TYPE_F32); #else const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); From 137071003c12deb8b046890edb48612a582bd686 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 25 Apr 2023 19:40:54 +0200 Subject: [PATCH 12/19] Improve btype dequant kernel selection code, add error if type is unsupported --- ggml-opencl.cpp | 61 +++++++++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 8a87edb881f95..aa426fe3f81fc 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -114,33 +114,40 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra cl_event events[4] = { NULL }; cl_kernel kernel; - size_t global, local, size_qb; - const bool dequant = btype >= 2 && btype < 6; - if (dequant) { - global = n * k; - - switch (btype) { - case GGML_TYPE_Q4_0: - kernel = kernel_q4_0; - local = 16; - size_qb = global * (sizeof(float) + local) / 32; - break; - case GGML_TYPE_Q4_1: - kernel = kernel_q4_1; - local = 16; - size_qb = global * (sizeof(float) * 2 + local) / 32; - break; - case GGML_TYPE_Q4_2: - kernel = kernel_q4_2; - local = 8; - size_qb = global * (sizeof(short) + local) / 16; - break; - case GGML_TYPE_Q4_3: - kernel = kernel_q4_3; - local = 8; - size_qb = global * (sizeof(short) * 2 + local) / 16; - break; - } + size_t global = n * k, local, size_qb; + bool dequant; + + switch (btype) { + case GGML_TYPE_F32: + dequant = false; + break; + case GGML_TYPE_Q4_0: + dequant = true; + kernel = kernel_q4_0; + local = 16; + size_qb = global * (sizeof(float) + local) / 32; + break; + case GGML_TYPE_Q4_1: + dequant = true; + kernel = kernel_q4_1; + local = 16; + size_qb = global * (sizeof(float) * 2 + local) / 32; + break; + case GGML_TYPE_Q4_2: + dequant = true; + kernel = kernel_q4_2; + local = 8; + size_qb = global * (sizeof(short) + local) / 16; + break; + case GGML_TYPE_Q4_3: + dequant = true; + kernel = kernel_q4_3; + local = 8; + size_qb = global * (sizeof(short) * 2 + local) / 16; + break; + default: + fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); + abort(); } const size_t size_a = m * k * sizeof(float); From 2b0c6a56f9ea7c690520ade1897bffd09dfb33de Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Wed, 26 Apr 2023 07:48:04 +0200 Subject: [PATCH 13/19] Improve code quality * Move internal stuff out of header * Use internal enums instead of CLBlast enums * Remove leftover C++ includes and defines * Make event use easier to read Co-authored-by: Henri Vasserman --- CMakeLists.txt | 2 +- Makefile | 2 +- ggml-opencl.cpp => ggml-opencl.c | 66 ++++++++++++++++++++------------ ggml-opencl.h | 29 ++++++-------- ggml.c | 6 +-- 5 files changed, 58 insertions(+), 47 deletions(-) rename ggml-opencl.cpp => ggml-opencl.c (75%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 79dfd070b9097..5fdbeddfca443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -174,7 +174,7 @@ if (LLAMA_CLBLAST) if (CLBlast_FOUND) message(STATUS "CLBlast found") - set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h) + set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h) add_compile_definitions(GGML_USE_CLBLAST) diff --git a/Makefile b/Makefile index 9b9c8ee7a877f..b056a01d00fcc 100644 --- a/Makefile +++ b/Makefile @@ -117,7 +117,7 @@ ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST LDFLAGS += -lclblast -lOpenCL OBJS += ggml-opencl.o -ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h +ggml-opencl.o: ggml-opencl.c ggml-opencl.h $(CXX) $(CXXFLAGS) -c $< -o $@ endif ifdef LLAMA_GPROF diff --git a/ggml-opencl.cpp b/ggml-opencl.c similarity index 75% rename from ggml-opencl.cpp rename to ggml-opencl.c index aa426fe3f81fc..08e8df81132fd 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.c @@ -1,13 +1,24 @@ #include "ggml-opencl.h" -#include -#include -#include +#define CL_TARGET_OPENCL_VERSION 110 +#include + +#include +#include #include "ggml.h" #include +#define CL_CHECK(err, name) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + cl_platform_id platform; cl_device_id device; cl_context context; @@ -74,7 +85,7 @@ void ggml_cl_init(void) { printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); CL_CHECK(err, "clCreateContext"); - queue = clCreateCommandQueue(context, device, 0, &err); + queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err); CL_CHECK(err, "clCreateCommandQueue"); free(platforms); @@ -93,7 +104,7 @@ void ggml_cl_init(void) { CL_CHECK(err, "clCreateKernel"); } -void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { +static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { if (req_size <= *cur_size) { return; } @@ -108,11 +119,14 @@ void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_me CL_CHECK(err, "clCreateBuffer"); } -void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { +void ggml_cl_sgemm_wrapper( + const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, + const int m, const int n, const int k, + const float alpha, const void *host_a, const int lda, + const float *host_b, const int ldb, const float beta, + float *host_c, const int ldc, const int btype) { cl_int err = 0; - cl_event events[4] = { NULL }; - cl_kernel kernel; size_t global = n * k, local, size_qb; bool dequant; @@ -162,42 +176,46 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b); ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c); + cl_event ev_a, ev_qb, ev_b; + if (dequant) { err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); CL_CHECK(err, "clSetKernelArg"); - clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, events + 1); + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb); } else { - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, events + 1); + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b); } - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); if (dequant) { - err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); + err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); CL_CHECK(err, "clEnqueueNDRangeKernel"); } - clWaitForEvents(dequant ? 4 : 3, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - clReleaseEvent(events[2]); + clWaitForEvents(1, &ev_a); + clWaitForEvents(1, &ev_b); + clReleaseEvent(ev_a); + clReleaseEvent(ev_b); if (dequant) { - clReleaseEvent(events[3]); + clReleaseEvent(ev_qb); } - CLBlastSgemm(order, - trans_a, trans_b, + cl_event ev_sgemm; + CLBlastSgemm((CLBlastLayout)order, + (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b, m, n, k, alpha, cl_buffer_a, 0, lda, cl_buffer_b, 0, ldb, beta, cl_buffer_c, 0, ldc, - &queue, events); + &queue, &ev_sgemm); - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, events, events + 1); + cl_event ev_c; + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c); // Wait for completion - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); + clWaitForEvents(1, &ev_c); + clReleaseEvent(ev_sgemm); + clReleaseEvent(ev_c); } diff --git a/ggml-opencl.h b/ggml-opencl.h index 9d7f911737baf..7bcc603ef8432 100644 --- a/ggml-opencl.h +++ b/ggml-opencl.h @@ -1,30 +1,23 @@ #pragma once -#define CL_TARGET_OPENCL_VERSION 110 -#include -#define MAX_CL_BUFFERS 16 - #ifdef __cplusplus extern "C" { #endif -// Buffer reuse code adapted from cuda implementation by slaren -#define CL_CHECK(err, name) \ - do { \ - cl_int err_ = (err); \ - if (err_ != CL_SUCCESS) { \ - fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \ - exit(1); \ - } \ - } while (0) +void ggml_cl_init(void); -cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size); -void ggml_cl_pool_free(cl_mem mem, size_t size); +enum ggml_blas_order { + GGML_BLAS_ORDER_ROW_MAJOR = 101, + GGML_BLAS_ORDER_COLUMN_MAJOR = 102, +}; -cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer); -void ggml_cl_init(void); +enum ggml_blas_op { + GGML_BLAS_OP_N = 111, + GGML_BLAS_OP_T = 112, + GGML_BLAS_OP_C = 113, +}; -void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); +void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); #ifdef __cplusplus } diff --git a/ggml.c b/ggml.c index 4bc86110cb16c..1576792b0de9f 100644 --- a/ggml.c +++ b/ggml.c @@ -7575,7 +7575,7 @@ static void ggml_compute_forward_mul_mat_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #elif defined(GGML_USE_CLBLAST) // zT = y * xT - ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, @@ -7809,7 +7809,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT - ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, @@ -8080,7 +8080,7 @@ static void ggml_compute_forward_mul_mat_q_f32( CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); #elif defined(GGML_USE_CLBLAST) // zT = y * xT - ggml_cl_sgemm_wrapper(CLBlastLayoutRowMajor, CLBlastTransposeNo, CLBlastTransposeYes, + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, From b7464582810d0d723c71ac40fcde1a1d05c45587 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Wed, 26 Apr 2023 18:38:31 +0200 Subject: [PATCH 14/19] Use c compiler for opencl files --- Makefile | 2 +- ggml-opencl.c | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index b056a01d00fcc..e7f0b1b365221 100644 --- a/Makefile +++ b/Makefile @@ -118,7 +118,7 @@ ifdef LLAMA_CLBLAST LDFLAGS += -lclblast -lOpenCL OBJS += ggml-opencl.o ggml-opencl.o: ggml-opencl.c ggml-opencl.h - $(CXX) $(CXXFLAGS) -c $< -o $@ + $(CC) $(CFLAGS) -c $< -o $@ endif ifdef LLAMA_GPROF CFLAGS += -pg diff --git a/ggml-opencl.c b/ggml-opencl.c index 08e8df81132fd..e4ee289c3e374 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -29,35 +29,35 @@ cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { - cl_program program; + cl_program p; char *program_log; size_t program_size, log_size; int err; program_size = strlen(program_buffer); - program = clCreateProgramWithSource(ctx, 1, + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); if(err < 0) { fprintf(stderr, "OpenCL error creating program"); exit(1); } - err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL); + err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL); if(err < 0) { - clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); program_log = (char*) malloc(log_size + 1); program_log[log_size] = '\0'; - clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); printf("%s\n", program_log); free(program_log); exit(1); } - return program; + return p; } void ggml_cl_init(void) { From ce97a807cb6b26183a189d0c73a6e43faa7fee9d Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Wed, 26 Apr 2023 18:39:04 +0200 Subject: [PATCH 15/19] Simplify code, fix include --- ggml-opencl.c | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ggml-opencl.c b/ggml-opencl.c index e4ee289c3e374..335ecf95dbd44 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -8,7 +8,7 @@ #include "ggml.h" -#include +#include "ggml_clblast_dequant.cl" #define CL_CHECK(err, name) \ do { \ @@ -190,15 +190,13 @@ void ggml_cl_sgemm_wrapper( clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); if (dequant) { err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); + clReleaseEvent(ev_qb); CL_CHECK(err, "clEnqueueNDRangeKernel"); } clWaitForEvents(1, &ev_a); clWaitForEvents(1, &ev_b); clReleaseEvent(ev_a); clReleaseEvent(ev_b); - if (dequant) { - clReleaseEvent(ev_qb); - } cl_event ev_sgemm; CLBlastSgemm((CLBlastLayout)order, From 4a35ec9df5f2c516bf9785329a3d11be06b03a3c Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Wed, 26 Apr 2023 19:56:58 +0200 Subject: [PATCH 16/19] First check error, then release event --- ggml-opencl.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-opencl.c b/ggml-opencl.c index 335ecf95dbd44..bbd72597a9014 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -190,8 +190,8 @@ void ggml_cl_sgemm_wrapper( clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); if (dequant) { err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); - clReleaseEvent(ev_qb); CL_CHECK(err, "clEnqueueNDRangeKernel"); + clReleaseEvent(ev_qb); } clWaitForEvents(1, &ev_a); clWaitForEvents(1, &ev_b); From fafebff53cd360c21009ffd1a848c49ab7b36a68 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Thu, 27 Apr 2023 15:25:09 +0200 Subject: [PATCH 17/19] Make globals static, fix indentation --- ggml-opencl.c | 75 +++++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/ggml-opencl.c b/ggml-opencl.c index bbd72597a9014..7deb27bbf5a02 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -19,45 +19,42 @@ } \ } while (0) -cl_platform_id platform; -cl_device_id device; -cl_context context; -cl_command_queue queue; -cl_program program; -cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; -cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; -size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; - -cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { - cl_program p; - char *program_log; - size_t program_size, log_size; - int err; - - program_size = strlen(program_buffer); - - p = clCreateProgramWithSource(ctx, 1, - (const char**)&program_buffer, &program_size, &err); - if(err < 0) { - fprintf(stderr, "OpenCL error creating program"); - exit(1); - } - - err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL); - if(err < 0) { - - clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, - 0, NULL, &log_size); - program_log = (char*) malloc(log_size + 1); - program_log[log_size] = '\0'; - clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, - log_size + 1, program_log, NULL); - printf("%s\n", program_log); - free(program_log); - exit(1); - } - - return p; +static cl_platform_id platform; +static cl_device_id device; +static cl_context context; +static cl_command_queue queue; +static cl_program program; +static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; +static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; +static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; + +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { + cl_program p; + char *program_log; + size_t program_size, log_size; + int err; + + program_size = strlen(program_buffer); + + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + fprintf(stderr, "OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL); + if(err < 0) { + + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); + printf("%s\n", program_log); + free(program_log); + exit(1); + } + + return p; } void ggml_cl_init(void) { From 96346fb2a4dccbd4525717e9526869dbf50c1e04 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Thu, 27 Apr 2023 15:27:22 +0200 Subject: [PATCH 18/19] Rename dequant kernels file to conform with other file names --- ggml_clblast_dequant.cl => ggml-opencl-dequant.cl | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ggml_clblast_dequant.cl => ggml-opencl-dequant.cl (100%) diff --git a/ggml_clblast_dequant.cl b/ggml-opencl-dequant.cl similarity index 100% rename from ggml_clblast_dequant.cl rename to ggml-opencl-dequant.cl From bbfba5f740a61e3543b773c199c1552fda51d149 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Thu, 27 Apr 2023 15:30:30 +0200 Subject: [PATCH 19/19] Fix import cl file name --- ggml-opencl.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-opencl.c b/ggml-opencl.c index 7deb27bbf5a02..1d68f19ee1e78 100644 --- a/ggml-opencl.c +++ b/ggml-opencl.c @@ -8,7 +8,7 @@ #include "ggml.h" -#include "ggml_clblast_dequant.cl" +#include "ggml-opencl-dequant.cl" #define CL_CHECK(err, name) \ do { \