49 changes: 44 additions & 5 deletions .github/workflows/pr-test.yml
@@ -89,7 +89,7 @@ jobs:
cd test/srt
python3 run_suite.py --suite per-commit-2-gpu

unittest-test-backend-4-gpu:
unit-test-backend-4-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
needs: [unit-test-frontend, unit-test-backend-2-gpu]
@@ -108,7 +108,7 @@ jobs:
cd test/srt
python3 run_suite.py --suite per-commit-4-gpu

unittest-test-backend-8-gpu:
unit-test-backend-8-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
needs: [unit-test-frontend, unit-test-backend-2-gpu]
@@ -306,12 +306,51 @@ jobs:
cd test/srt
python3 test_moe_eval_accuracy_large.py

unit-test-deepep-4-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
runs-on: 4-gpu-runner
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install dependencies
run: |
bash scripts/ci_install_deepep.sh

- name: Run test
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite per-commit-4-gpu-deepep

unit-test-deepep-8-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
runs-on: 8-gpu-runner
needs: [
unit-test-deepep-4-gpu,
]
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install dependencies
run: |
bash scripts/ci_install_deepep.sh

- name: Run test
timeout-minutes: 20
run: |
cd test/srt
python3 run_suite.py --suite per-commit-8-gpu-deepep

finish:
if: always()
needs: [
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unittest-test-backend-8-gpu,
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
accuracy-test-1-gpu, accuracy-test-2-gpu,
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-4-gpu,
unit-test-backend-8-gpu, performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
accuracy-test-1-gpu, accuracy-test-2-gpu, unit-test-deepep-4-gpu, unit-test-deepep-8-gpu,
]
runs-on: ubuntu-latest
steps:
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -6,6 +6,7 @@

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)

@@ -1058,7 +1059,7 @@ def ep_gather(
input_index: torch.Tensor,
output_tensor: torch.Tensor,
):
BLOCK_D = 1024 # block size of quantization
BLOCK_D = 1024 if not is_in_ci() else 128 # block size of quantization
Contributor review comment (medium):

Consider defining BLOCK_D based on is_in_ci() as a constant at the file level to improve readability and maintainability.

_BLOCK_D_CI = 128
_BLOCK_D_DEFAULT = 1024
BLOCK_D = _BLOCK_D_CI if is_in_ci() else _BLOCK_D_DEFAULT
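
For illustration only, a minimal sketch of how that suggestion could sit at module level in kernels.py, reusing the is_in_ci import that this diff already adds. The constant names come from the suggestion above; this sketch is not part of the merged change.

from sglang.utils import is_in_ci

# Block size of quantization; this PR switches to a smaller value when running in CI.
_BLOCK_D_CI = 128
_BLOCK_D_DEFAULT = 1024
BLOCK_D = _BLOCK_D_CI if is_in_ci() else _BLOCK_D_DEFAULT

ep_gather would then read the module-level BLOCK_D instead of recomputing it on every call.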

num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
75 changes: 75 additions & 0 deletions scripts/ci_install_deepep.sh
@@ -0,0 +1,75 @@
#!/bin/bash
# Install the dependency in CI.
set -euxo pipefail

bash scripts/ci_install_dependency.sh

if python3 -c "import deep_ep" >/dev/null 2>&1; then
echo "deep_ep is already installed or importable. Skipping installation."
exit 0
fi

export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
export NVSHMEM_DIR=/opt/nvshmem/install
export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
export PATH="${NVSHMEM_DIR}/bin:$PATH"
export CUDA_HOME=/usr/local/cuda

# Install system dependencies
apt install -y curl wget git sudo libibverbs-dev rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake

# Install GDRCopy
rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
mkdir -p /opt/nvshmem
cd /opt/gdrcopy
git clone https://github.com/NVIDIA/gdrcopy.git .
git checkout v2.4.4
apt update
apt install -y nvidia-dkms-535
apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
apt install -y check libsubunit0 libsubunit-dev
cd packages
CUDA=/usr/local/cuda ./build-deb-packages.sh
dpkg -i gdrdrv-dkms_*.deb
dpkg -i libgdrapi_*.deb
dpkg -i gdrcopy-tests_*.deb
dpkg -i gdrcopy_*.deb

if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
fi
apt-get update && apt-get install -y libfabric-dev

# Clone DeepEP
rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout eef7ab50fa5cf0ab1dd3fce4c6493c90bdf290ac

# Install NVSHMEM
cd /opt/nvshmem
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
tar -xf nvshmem_src_3.2.5-1.txz
rm nvshmem && mv nvshmem_src nvshmem
cd nvshmem
git apply /root/.cache/deepep/third-party/nvshmem.patch
NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
cd build
make -j$(nproc) install

# Install DeepEP
cd /root/.cache/deepep && git checkout eef7ab50fa5cf0ab1dd3fce4c6493c90bdf290ac && python3 setup.py install

# Verify configuration
echo "=== NCCL Configuration ==="
nvidia-smi topo -m
nvidia-smi nvlink -s
echo "=== Verify GDRCOPY ==="
gdrcopy_copybw
echo "=== Verify NVSHMEM ==="
nvshmem-info -a
11 changes: 6 additions & 5 deletions test/srt/run_suite.py
@@ -170,21 +170,22 @@ class TestFile:
TestFile("test_pp_single_node.py", 150),
TestFile("test_multi_instance_release_memory_occupation.py", 64),
],
"per-commit-4-gpu-deepep": [
TestFile("test_deepep_small.py", 531),
],
"per-commit-4-gpu-amd": [
TestFile("test_pp_single_node.py", 150),
],
"per-commit-8-gpu": [
# Disabled deepep tests temporarily because it takes too much time.
# TODO: re-enable them after reducing the test time with compilation cache and smaller models.
# TestFile("test_deepep_intranode.py", 50),
# TestFile("test_deepep_low_latency.py", 50),
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
# Disabled because it hangs on the CI.
# TestFile("test_moe_ep.py", 181),
TestFile("test_disaggregation.py", 270),
TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_full_deepseek_v3.py", 463),
],
"per-commit-8-gpu-deepep": [
TestFile("test_deepep_large.py", 485),
],
"per-commit-8-gpu-amd": [
TestFile("test_full_deepseek_v3.py", 250),
],
146 changes: 146 additions & 0 deletions test/srt/test_deepep_large.py
@@ -0,0 +1,146 @@
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestDeepseek(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"8",
"--enable-dp-attention",
"--dp",
"8",
"--moe-dense-tp-size",
"1",
"--enable-dp-lm-head",
"--enable-deepep-moe",
"--enable-two-batch-overlap",
"--ep-num-redundant-experts",
"32",
"--ep-dispatch-algorithm",
"dynamic",
"--eplb-algorithm",
"deepseek",
"--cuda-graph-bs",
"256",
"--max-running-requests",
"2048",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=8,
data_path=None,
num_questions=1250,
parallel=1250,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Eval accuracy of GSM8K: {metrics=}")

self.assertGreater(metrics["accuracy"], 0.93)
self.assertGreater(metrics["output_throughput"], 3800)


class TestDeepseekMTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"8",
"--enable-dp-attention",
"--dp",
"8",
"--moe-dense-tp-size",
"1",
"--enable-dp-lm-head",
"--enable-deepep-moe",
"--enable-two-batch-overlap",
"--ep-num-redundant-experts",
"32",
"--ep-dispatch-algorithm",
"dynamic",
"--eplb-algorithm",
"deepseek",
"--cuda-graph-bs",
"64", # TODO: increase it to 128 when TBO is supported in draft_extend
"--max-running-requests",
"512",
"--speculative-algorithm",
"NEXTN",
"--speculative-num-steps",
"1",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"2",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=8,
data_path=None,
num_questions=1250,
parallel=1250,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Eval accuracy of GSM8K: {metrics=}")

self.assertGreater(metrics["accuracy"], 0.93)

server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(
f"###test_gsm8k:\n"
f"accuracy={metrics['accuracy']=:.3f}\n"
f"{avg_spec_accept_length=:.3f}\n"
)
self.assertGreater(avg_spec_accept_length, 1.9)


if __name__ == "__main__":
unittest.main()