diff --git a/docker/Dockerfile b/docker/Dockerfile
index 22a09368b2..69023b61cf 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -263,7 +263,18 @@ RUN if [ -n "$MOFED_VERSION" ] ; then \
 RUN if [ -n "$CUDA_VERSION" ] ; then \
     pip${PYTHON_VERSION} install --upgrade --no-cache-dir ninja==1.11.1 && \
     pip${PYTHON_VERSION} install --upgrade --no-cache-dir --force-reinstall packaging==22.0 && \
-    MAX_JOBS=1 pip${PYTHON_VERSION} install --no-cache-dir --no-build-isolation flash-attn==2.6.3; \
+    MAJOR_CUDA_VERSION=$(echo $CUDA_VERSION | cut -d. -f1) && \
+    MINOR_CUDA_VERSION=$(echo $CUDA_VERSION | cut -d. -f2) && \
+    CUDA_STRING="cu${MAJOR_CUDA_VERSION}${MINOR_CUDA_VERSION}" && \
+    PYTORCH_MAJOR=$(echo $PYTORCH_VERSION | cut -d. -f1) && \
+    PYTORCH_MINOR=$(echo $PYTORCH_VERSION | cut -d. -f2) && \
+    PYTHON_MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1) && \
+    PYTHON_MINOR=$(echo $PYTHON_VERSION | cut -d. -f2) && \
+    WHEEL_URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4/flash_attn-2.7.4.post1+${CUDA_STRING}torch${PYTORCH_MAJOR}.${PYTORCH_MINOR}cxx11abiTRUE-cp${PYTHON_MAJOR}${PYTHON_MINOR}-cp${PYTHON_MAJOR}${PYTHON_MINOR}-linux_x86_64.whl" && \
+    echo "Installing Flash Attention from: $WHEEL_URL" && \
+    pip${PYTHON_VERSION} install --no-cache-dir $WHEEL_URL || \
+    (echo "Pre-built wheel not found, falling back to source installation" && \
+    MAX_JOBS=1 pip${PYTHON_VERSION} install --no-cache-dir --no-build-isolation flash-attn==2.7.4.post1); \
     cd .. ; \
     fi
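
The wheel URL built by the new logic can be reproduced outside the image as a quick sanity check before a build. The sketch below is not part of the Dockerfile: it assumes hypothetical example build args (CUDA_VERSION=12.4.1, PYTORCH_VERSION=2.5.1, PYTHON_VERSION=3.11, none of which come from the diff) and only mirrors the string construction above; whether a matching asset actually exists on the flash-attention release page is a separate question, which the curl probe answers.

    # Standalone sketch: rebuild the wheel URL for assumed example versions and
    # probe whether the corresponding release asset is reachable.
    CUDA_VERSION=12.4.1        # assumed example build arg
    PYTORCH_VERSION=2.5.1      # assumed example build arg
    PYTHON_VERSION=3.11        # assumed example build arg
    CUDA_STRING="cu$(echo $CUDA_VERSION | cut -d. -f1)$(echo $CUDA_VERSION | cut -d. -f2)"
    TORCH_STRING="torch$(echo $PYTORCH_VERSION | cut -d. -f1).$(echo $PYTORCH_VERSION | cut -d. -f2)"
    PY_TAG="cp$(echo $PYTHON_VERSION | cut -d. -f1)$(echo $PYTHON_VERSION | cut -d. -f2)"
    WHEEL_URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4/flash_attn-2.7.4.post1+${CUDA_STRING}${TORCH_STRING}cxx11abiTRUE-${PY_TAG}-${PY_TAG}-linux_x86_64.whl"
    echo "$WHEEL_URL"
    # -f: fail on HTTP errors, -s: silent, -I: HEAD request only, -L: follow redirects
    curl -fsIL "$WHEEL_URL" >/dev/null && echo "wheel available" \
        || echo "wheel missing; the MAX_JOBS=1 source-build fallback would run"

With the example values above this expands to the 2.7.4.post1 cu124/torch2.5/cp311 wheel name; if the HEAD request fails, the Dockerfile's `||` branch falls back to building flash-attn from source, exactly as in the original recipe.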