Skip to content

Commit 65d51f4

Browse files
committed
nix: Added and enabled rocmWmma support along with ROWMMA_PATH option in CMake
1 parent 9e3a5df commit 65d51f4

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

.devops/nix/package.nix

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
useMpi ? false,
3333
useRocm ? config.rocmSupport,
3434
rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets,
35+
rocmUseWmma ? true,
3536
enableCurl ? true,
3637
useVulkan ? false,
3738
buildAllCudaFaQuants ? false,
@@ -92,13 +93,16 @@ let
9293
libcublas
9394
];
9495

95-
rocmBuildInputs = with rocmPackages; [
96-
clr
97-
hipblas
98-
rocblas
99-
llvm.lld
100-
llvm.bintools
101-
];
96+
rocmBuildInputs =
97+
with rocmPackages;
98+
[
99+
clr
100+
hipblas
101+
rocblas
102+
llvm.lld
103+
llvm.bintools
104+
]
105+
++ optionals rocmUseWmma [ rocmPackages.rocwmma ];
102106

103107
vulkanBuildInputs = [
104108
vulkan-headers
@@ -198,6 +202,10 @@ effectiveStdenv.mkDerivation (finalAttrs: {
198202
(cmakeFeature "AMDGPU_TARGETS" rocmGpuTargets)
199203
(cmakeBool "GGML_CUDA_FA_ALL_QUANTS" buildAllCudaFaQuants)
200204
]
205+
++ optionals rocmUseWmma [
206+
(cmakeBool "GGML_HIP_ROCWMMA_FATTN" rocmUseWmma)
207+
(cmakeFeature "GGML_HIP_ROCWMMA_PATH" "${rocmPackages.rocwmma}")
208+
]
201209
++ optionals useMetalKit [
202210
(lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1")
203211
(cmakeBool "GGML_METAL_EMBED_LIBRARY" (!precompileMetalShaders))

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ endif()
3939
find_package(hip REQUIRED)
4040
find_package(hipblas REQUIRED)
4141
find_package(rocblas REQUIRED)
42-
if (GGML_HIP_ROCWMMA_FATTN)
42+
if (GGML_HIP_ROCWMMA_FATTN AND NOT GGML_HIP_ROCWMMA_PATH)
4343
CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
4444
if (NOT ${FOUND_ROCWMMA})
4545
message(FATAL_ERROR "rocwmma has not been found")
@@ -111,6 +111,9 @@ endif()
111111

112112
if (GGML_HIP_ROCWMMA_FATTN)
113113
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
114+
if (GGML_HIP_ROCWMMA_PATH)
115+
target_include_directories(ggml-hip PRIVATE ${GGML_HIP_ROCWMMA_PATH}/include)
116+
endif()
114117
endif()
115118

116119
if (NOT GGML_HIP_MMQ_MFMA)

0 commit comments

Comments
 (0)