Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
}

static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
#ifdef FAST_FP16_AVAILABLE
acc += v*u;
#else
const float2 tmpv = __half22float2(v);
const float2 tmpu = __half22float2(u);
float2 tmpacc = __half22float2(acc);
tmpacc.x += tmpv.x * tmpu.x;
tmpacc.y += tmpv.y * tmpu.y;
acc = make_half2(tmpacc.x, tmpacc.y);
#endif // FAST_FP16_AVAILABLE
}

// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes>
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (nbytes == 4) {
*(int *) dst = *(const int *) src;
} else if constexpr (nbytes == 8) {
*(int2 *) dst = *(const int2 *) src;
} else if constexpr (nbytes == 16) {
*(int4 *) dst = *(const int4 *) src;
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
if constexpr (alignment != 0) {
static_assert(nbytes % alignment == 0, "bad alignment");
}
constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;

#pragma unroll
for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
if constexpr (nb_per_cpy == 1) {
((char *) dst)[i] = ((const char *) src)[i];
} else if constexpr (nb_per_cpy == 2) {
((short *) dst)[i] = ((const short *) src)[i];
} else if constexpr (nb_per_cpy == 4) {
((int *) dst)[i] = ((const int *) src)[i];
} else if constexpr (nb_per_cpy == 8) {
((int2 *) dst)[i] = ((const int2 *) src)[i];
} else if constexpr (nb_per_cpy == 16) {
((int4 *) dst)[i] = ((const int4 *) src)[i];
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
}
}
}

Expand Down
Loading
Loading