Skip to content

Commit 6c843ce

Browse files
committed
Actually, attempt 3
1 parent dc1e4d5 commit 6c843ce

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

ggml/src/ggml-cuda/unary.cu

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -385,17 +385,30 @@ static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alp
385 385
return;
386 386
}
387 387

388-
const float xi = sizeof(x[i]) == sizeof(half) ? __half2float(x[i]) : (float) x[i];
389-
const float gate_pos = (xi > 0.0f);
388+
float xi;
389+
#if __CUDA_ARCH__ >= 530
390+
if constexpr (std::is_same<T, half>::value) {
391+
xi = __half2float(x[i]);
392+
} else
393+
#endif
394+
{
395+
xi = (float)x[i];
396+
}
390 397

398+
const float gate_pos = (xi > 0.0f);
391 399
const float y_pos = alpha_p * xi * xi + beta * xi;
392-
393 400
const float min_v_eps = fminf(xi, eps);
394 401
const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
395-
396 402
const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
397 403

398-
dst[i] = (T) (sizeof(dst[i]) == sizeof(float)) ? out : __float2half(out));
404+
#if __CUDA_ARCH__ >= 530
405+
if constexpr (std::is_same<T, half>::value) {
406+
dst[i] = __float2half(out);
407+
} else
408+
#endif
409+
{
410+
dst[i] = (T)out;
411+
}
399 412
}
400 413

401 414
template <typename T>

0 commit comments

Comments
 (0)