diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index ac93dc18370..2378695e21e 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -89,7 +89,7 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu
 
-  unittest-test-backend-4-gpu:
+  unit-test-backend-4-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
     needs: [unit-test-frontend, unit-test-backend-2-gpu]
@@ -108,7 +108,7 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite per-commit-4-gpu
 
-  unittest-test-backend-8-gpu:
+  unit-test-backend-8-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
     needs: [unit-test-frontend, unit-test-backend-2-gpu]
@@ -306,12 +306,51 @@ jobs:
           cd test/srt
           python3 test_moe_eval_accuracy_large.py
 
+  unit-test-deepep-4-gpu:
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
+    runs-on: 4-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_deepep.sh
+
+      - name: Run test
+        timeout-minutes: 20
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite per-commit-4-gpu-deepep
+
+  unit-test-deepep-8-gpu:
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
+    runs-on: 8-gpu-runner
+    needs: [
+      unit-test-deepep-4-gpu,
+    ]
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_deepep.sh
+
+      - name: Run test
+        timeout-minutes: 20
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite per-commit-8-gpu-deepep
+
   finish:
     if: always()
     needs: [
-      unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unittest-test-backend-8-gpu,
-      performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
-      accuracy-test-1-gpu, accuracy-test-2-gpu,
+      unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-4-gpu,
+      unit-test-backend-8-gpu, performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
+      accuracy-test-1-gpu, accuracy-test-2-gpu, unit-test-deepep-4-gpu, unit-test-deepep-8-gpu,
     ]
     runs-on: ubuntu-latest
     steps:
diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 1d661931cf7..7f9bdc7486a 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -6,6 +6,7 @@
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -1058,7 +1059,7 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024  # block size of quantization
+    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
diff --git a/scripts/ci_install_deepep.sh b/scripts/ci_install_deepep.sh
new file mode 100755
index 00000000000..aa4dab097bb
--- /dev/null
+++ b/scripts/ci_install_deepep.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+# Install the dependency in CI.
+set -euxo pipefail
+
+bash scripts/ci_install_dependency.sh
+
+if python3 -c "import deep_ep" >/dev/null 2>&1; then
+    echo "deep_ep is already installed or importable. Skipping installation."
+    exit 0
+fi
+
+export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/
+export NVSHMEM_DIR=/opt/nvshmem/install
+export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH"
+export PATH="${NVSHMEM_DIR}/bin:$PATH"
+export CUDA_HOME=/usr/local/cuda
+
+# Install system dependencies
+apt install -y curl wget git sudo libibverbs-dev rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake
+
+# Install GDRCopy
+rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy
+mkdir -p /opt/nvshmem
+cd /opt/gdrcopy
+git clone https://github.com/NVIDIA/gdrcopy.git .
+git checkout v2.4.4
+apt update
+apt install -y nvidia-dkms-535
+apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms
+apt install -y check libsubunit0 libsubunit-dev
+cd packages
+CUDA=/usr/local/cuda ./build-deb-packages.sh
+dpkg -i gdrdrv-dkms_*.deb
+dpkg -i libgdrapi_*.deb
+dpkg -i gdrcopy-tests_*.deb
+dpkg -i gdrcopy_*.deb
+
+if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then
+    ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so
+fi
+apt-get update && apt-get install -y libfabric-dev
+
+# Clone DeepEP
+rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout eef7ab50fa5cf0ab1dd3fce4c6493c90bdf290ac
+
+# Install NVSHMEM
+cd /opt/nvshmem
+wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
+tar -xf nvshmem_src_3.2.5-1.txz
+rm -rf nvshmem && mv nvshmem_src nvshmem
+cd nvshmem
+git apply /root/.cache/deepep/third-party/nvshmem.patch
+NVSHMEM_SHMEM_SUPPORT=0 \
+NVSHMEM_UCX_SUPPORT=0 \
+NVSHMEM_USE_NCCL=0 \
+NVSHMEM_MPI_SUPPORT=0 \
+NVSHMEM_IBGDA_SUPPORT=1 \
+NVSHMEM_PMIX_SUPPORT=0 \
+NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
+NVSHMEM_USE_GDRCOPY=1 \
+cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/opt/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90
+cd build
+make -j$(nproc) install
+
+# Install DeepEP
+cd /root/.cache/deepep && python3 setup.py install
+
+# Verify configuration
+echo "=== NCCL Configuration ==="
+nvidia-smi topo -m
+nvidia-smi nvlink -s
+echo "=== Verify GDRCOPY ==="
+gdrcopy_copybw
+echo "=== Verify NVSHMEM ==="
+nvshmem-info -a
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index a3c8e1a8dfe..90e4f009404 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -170,21 +170,22 @@ class TestFile:
         TestFile("test_pp_single_node.py", 150),
         TestFile("test_multi_instance_release_memory_occupation.py", 64),
     ],
+    "per-commit-4-gpu-deepep": [
+        TestFile("test_deepep_small.py", 531),
+    ],
     "per-commit-4-gpu-amd": [
         TestFile("test_pp_single_node.py", 150),
     ],
     "per-commit-8-gpu": [
-        # Disabled deepep tests temporarily because it takes too much time.
-        # TODO: re-enable them after reducing the test time with compilation cache and smaller models.
-        # TestFile("test_deepep_intranode.py", 50),
-        # TestFile("test_deepep_low_latency.py", 50),
-        # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),  # Disabled because it hangs on the CI.
         # TestFile("test_moe_ep.py", 181),
         TestFile("test_disaggregation.py", 270),
         TestFile("test_disaggregation_different_tp.py", 155),
         TestFile("test_full_deepseek_v3.py", 463),
     ],
+    "per-commit-8-gpu-deepep": [
+        TestFile("test_deepep_large.py", 485),
+    ],
     "per-commit-8-gpu-amd": [
         TestFile("test_full_deepseek_v3.py", 250),
     ],
diff --git a/test/srt/test_deepep_large.py b/test/srt/test_deepep_large.py
new file mode 100644
index 00000000000..8afb2896f8f
--- /dev/null
+++ b/test/srt/test_deepep_large.py
@@ -0,0 +1,145 @@
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestDeepseek(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "8",
+                "--enable-dp-attention",
+                "--dp",
+                "8",
+                "--moe-dense-tp-size",
+                "1",
+                "--enable-dp-lm-head",
+                "--enable-deepep-moe",
+                "--enable-two-batch-overlap",
+                "--ep-num-redundant-experts",
+                "32",
+                "--ep-dispatch-algorithm",
+                "dynamic",
+                "--eplb-algorithm",
+                "deepseek",
+                "--cuda-graph-bs",
+                "256",
+                "--max-running-requests",
+                "2048",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=8,
+            data_path=None,
+            num_questions=1250,
+            parallel=1250,
+            max_new_tokens=512,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Eval accuracy of GSM8K: {metrics=}")
+
+        self.assertGreater(metrics["accuracy"], 0.93)
+
+
+class TestDeepseekMTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "8",
+                "--enable-dp-attention",
+                "--dp",
+                "8",
+                "--moe-dense-tp-size",
+                "1",
+                "--enable-dp-lm-head",
+                "--enable-deepep-moe",
+                "--enable-two-batch-overlap",
+                "--ep-num-redundant-experts",
+                "32",
+                "--ep-dispatch-algorithm",
+                "dynamic",
+                "--eplb-algorithm",
+                "deepseek",
+                "--cuda-graph-bs",
+                "64",  # TODO: increase it to 128 when TBO is supported in draft_extend
+                "--max-running-requests",
+                "512",
+                "--speculative-algorithm",
+                "NEXTN",
+                "--speculative-num-steps",
+                "1",
+                "--speculative-eagle-topk",
+                "1",
+                "--speculative-num-draft-tokens",
+                "2",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=8,
+            data_path=None,
+            num_questions=1250,
+            parallel=1250,
+            max_new_tokens=512,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Eval accuracy of GSM8K: {metrics=}")
+
+        self.assertGreater(metrics["accuracy"], 0.93)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(
+            f"###test_gsm8k:\n"
+            f"accuracy={metrics['accuracy']=:.3f}\n"
+            f"{avg_spec_accept_length=:.3f}\n"
+        )
+        self.assertGreater(avg_spec_accept_length, 1.9)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_deepep_small.py b/test/srt/test_deepep_small.py
new file mode 100644
index 00000000000..a60f8296c67
--- /dev/null
+++ b/test/srt/test_deepep_small.py
@@ -0,0 +1,384 @@
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestPureDP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-dp-attention",
+                "--dp",
+                "4",
+                "--enable-deepep-moe",
+                "--cuda-graph-max-bs",
+                "128",
+                "--max-running-requests",
+                "128",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+class TestHybridDPTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-dp-attention",
+                "--dp",
+                "2",
+                "--enable-deepep-moe",
+                "--cuda-graph-max-bs",
+                "128",
+                "--max-running-requests",
+                "128",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+class TestTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-deepep-moe",
+                "--cuda-graph-max-bs",
+                "128",
+                "--max-running-requests",
+                "128",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+@unittest.skip("covered in test_deepep_large.py")
+class TestNoGatherdBuffer(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-dp-attention",
+                "--dp",
+                "4",
+                "--moe-dense-tp-size",
+                "1",
+                "--enable-dp-lm-head",
+                "--enable-deepep-moe",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "128",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+class TestTBO(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-dp-attention",
+                "--dp",
+                "4",
+                "--moe-dense-tp-size",
+                "1",
+                "--enable-deepep-moe",
+                "--enable-two-batch-overlap",
+                "--cuda-graph-max-bs",
+                "128",
+                "--max-running-requests",
+                "128",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
+@unittest.skip("covered in TestMTPWithTBO")
+class TestMTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "4",
+                "--enable-dp-attention",
+                "--dp",
+                "2",
+                "--enable-dp-lm-head",
+                "--enable-deepep-moe",
+                "--speculative-algo",
+                "NEXTN",
+                "--speculative-draft",
+                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+                "--speculative-num-steps",
+                "2",
+                "--speculative-eagle-topk",
+                "3",
+                "--speculative-num-draft-tokens",
+                "3",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(
+            f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
+            f"accuracy={metrics['accuracy']=:.3f}\n"
+            f"{avg_spec_accept_length=:.3f}\n"
+        )
+        self.assertGreater(avg_spec_accept_length, 2.1)
+
+
+class TestMTPWithTBO(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        import os
+
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--tp-size",
+                "4",
+                "--enable-dp-attention",
+                "--dp-size",
+                "4",
+                "--enable-two-batch-overlap",
+                "--enable-deepep-moe",
+                "--trust-remote-code",
+                "--speculative-algorithm",
+                "NEXTN",
+                "--speculative-num-steps",
+                "2",
+                "--speculative-eagle-topk",
+                "3",
+                "--speculative-num-draft-tokens",
+                "3",
+                "--speculative-draft",
+                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+                "--chunked-prefill-size",
+                "256",
+                "--cuda-graph-max-bs",
+                "32",
+                "--max-running-requests",
+                "32",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(
+            f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
+            f"accuracy={metrics['accuracy']=:.3f}\n"
+            f"{avg_spec_accept_length=:.3f}\n"
+        )
+        self.assertGreater(avg_spec_accept_length, 2.1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index 085dc206bb5..af50dc7803c 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -137,86 +137,5 @@ def test_gsm8k(self):
         self.assertGreater(avg_spec_accept_length, 2.5)
 
 
-# TODO: enable this test later
-# class TestDPAttentionDP2TP2DeepseekV3MTPTBO(CustomTestCase):
-#     @classmethod
-#     def setUpClass(cls):
-#         import os
-
-#         # print debug log for tbo
-#         os.environ["SGLANG_TBO_DEBUG"] = "1"
-#         cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
-#         cls.base_url = DEFAULT_URL_FOR_TEST
-#         other_args = [
-#             "--trust-remote-code",
-#             "--disable-radix",
-#             "--speculative-algorithm",
-#             "EAGLE",
-#             "--speculative-num-steps",
-#             "2",
-#             "--speculative-eagle-topk",
-#             "4",
-#             "--speculative-num-draft-tokens",
-#             "4",
-#             "--speculative-draft",
-#             DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
-#             "--tp-size",
-#             "2",
-#             "--enable-dp-attention",
-#             "--dp-size",
-#             "2",
-#             "--enable-two-batch-overlap",
-#             "--enable-deepep-moe",
-#             "--deepep-mode",
-#             "low_latency",
-#             "--chunked-prefill-size",
-#             "256",
-#             "--cuda-graph-max-bs",
-#             "32",
-#             "--max-running-requests",
-#             "32",
-#         ]
-#         if not is_in_amd_ci():
-#             other_args += ["--mem-frac", "0.7"]
-#         cls.process = popen_launch_server(
-#             cls.model,
-#             cls.base_url,
-#             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-#             other_args=other_args,
-#         )
-
-#     @classmethod
-#     def tearDownClass(cls):
-#         kill_process_tree(cls.process.pid)
-
-#     def test_gsm8k(self):
-#         requests.get(self.base_url + "/flush_cache")
-
-#         args = SimpleNamespace(
-#             num_shots=5,
-#             data_path=None,
-#             num_questions=200,
-#             max_new_tokens=512,
-#             parallel=128,
-#             host="http://127.0.0.1",
-#             port=int(self.base_url.split(":")[-1]),
-#         )
-#         metrics = run_eval_few_shot_gsm8k(args)
-#         print(metrics)
-
-#         self.assertGreater(metrics["accuracy"], 0.60)
-
-#         server_info = requests.get(self.base_url + "/get_server_info")
-#         avg_spec_accept_length = server_info.json()["internal_states"][0][
-#             "avg_spec_accept_length"
-#         ]
-#         print(
-#             f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
-#             f"accuracy={metrics['accuracy']=:.3f}\n"
-#             f"{avg_spec_accept_length=:.3f}\n"
-#         )
-#         self.assertGreater(avg_spec_accept_length, 2.3)
-
-
 if __name__ == "__main__":
     unittest.main()