
Commit 44b4b07

ArthurZucker, Isotr0py, hmellor, DarkLight1337, and mgoin authored and committed
[Model]: Add transformers backend support (#11330)
Adds support for `transformers` as a backend. Following huggingface/transformers#35235, a bunch of models should already be supported, and we are ramping up support for more. Thanks @Isotr0py for the TP support, and @hmellor for his help as well!

This includes:

- `trust_remote_code=True` support: any model on the Hub, if it implements attention the correct way, can be natively supported!
- tensor parallel support

Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 78e7ad1 commit 44b4b07

File tree: 11 files changed (+528, −9 lines)

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -349,6 +349,7 @@ steps:
 - vllm/
 - tests/models
 commands:
+- pytest -v -s models/test_transformers.py
 - pytest -v -s models/test_registry.py
 - pytest -v -s models/test_initialization.py

@@ -485,6 +486,7 @@ steps:
 - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
 - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
 # Avoid importing model tests that cause CUDA reinitialization error
+- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
 - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
 - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
 - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'

docs/source/models/supported_models.md

Lines changed: 76 additions & 0 deletions
@@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
 Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
 Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.
 
+### Transformers fallback
+
+After the merge of <gh-pr:11330>, `vllm` can fall back to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
+
+To check if the backend is `transformers`, you can simply do this:
+
+```python
+from vllm import LLM
+llm = LLM(model=..., task="generate")  # Name or path of your model
+llm.apply_model(lambda model: print(model.__class__))
+```
+
+If it is `TransformersModel`, then it means it's based on `transformers`!
+
+#### Supported features
+
+##### LoRA and quantization
+
+Neither is supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
+
+Usually, `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need to work a bit to either use this API (for now this would result in some weights not being marked as loaded) or replace the modules accordingly.
+
+Hints as to how this would look:
+
+```python
+class TransformersModel(nn.Module, SupportsLoRA):
+    def __init__(self, *args, **kwargs):
+        ...
+        self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
+```
+
+The blocker is that you need to specify the supported LoRA layers, when ideally we would want to load whatever is inside the checkpoint!
+
+##### Remote code
+
+This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True`, and that correctly implements attention, can be used in production!
+
+```python
+from vllm import LLM
+llm = LLM(model=..., task="generate", trust_remote_code=True)  # Name or path of your model
+llm.apply_model(lambda model: print(model.__class__))
+```
+
+A model just needs the following two things:
+
+```python
+from transformers import PreTrainedModel
+from torch import nn
+
+class MyAttention(nn.Module):
+
+    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
+        ...
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            **kwargs,
+        )
+        ...
+
+class MyModel(PreTrainedModel):
+    _supports_attention_backend = True
+```
+
+Here is what happens in the background:
+
+1. The config is loaded.
+2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
+3. The `TransformersModel` backend is used. See `vllm/model_executor/models/transformers.py`, which leverages `self.config._attn_implementation = "vllm"`, hence the need to use `ALL_ATTENTION_FUNCTIONS`.
+
+That's it!
+
 ### ModelScope
 
 To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
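
For reference, here is a minimal end-to-end sketch that combines the check from the docs above with the new `model_impl` option added elsewhere in this commit. The model name is only an example, and passing `model_impl` directly as an `LLM` keyword argument is an assumption based on how `EngineArgs` forwards it in this diff.

```python
from vllm import LLM

# Hedged sketch: force the Transformers backend via the new `model_impl`
# option and confirm which class vLLM resolved, as described above.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",  # example model (assumed)
          task="generate",
          model_impl="transformers")
llm.apply_model(lambda model: print(model.__class__))  # expect TransformersModel
```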

requirements-common.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.48.2  # Required for Bamba.
+transformers >= 4.48.2  # Required for Bamba model and Transformers backend.
 tokenizers >= 0.19.1  # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'

tests/models/registry.py

Lines changed: 5 additions & 0 deletions
@@ -281,12 +281,17 @@ def check_available_online(
                          speculative_model="ibm-fms/llama-160m-accelerator"),  # noqa: E501
 }
 
+_FALLBACK_MODEL = {
+    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+}
+
 _EXAMPLE_MODELS = {
     **_TEXT_GENERATION_EXAMPLE_MODELS,
     **_EMBEDDING_EXAMPLE_MODELS,
     **_CROSS_ENCODER_EXAMPLE_MODELS,
     **_MULTIMODAL_EXAMPLE_MODELS,
     **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
+    **_FALLBACK_MODEL,
 }
tests/models/test_oot_registration.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,9 @@ def test_plugin(dummy_opt_path):
     os.environ["VLLM_PLUGINS"] = ""
     with pytest.raises(Exception) as excinfo:
         LLM(model=dummy_opt_path, load_format="dummy")
-    assert "are not supported for now" in str(excinfo.value)
+    error_msg = "has no vLLM implementation and " \
+        "the Transformers implementation is not compatible with vLLM."
+    assert (error_msg in str(excinfo.value))
 
 
 @fork_new_process_for_each_test

tests/models/test_transformers.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+"""Test the functionality of the Transformers backend.
+
+Run `pytest tests/models/test_transformers.py`.
+"""
+from contextlib import nullcontext
+from typing import Type
+
+import pytest
+
+from ..conftest import HfRunner, VllmRunner
+from ..utils import multi_gpu_test
+from .utils import check_logprobs_close
+
+
+def check_implementation(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    example_prompts: list[str],
+    model: str,
+    **kwargs,
+):
+    max_tokens = 32
+    num_logprobs = 5
+
+    with vllm_runner(model, **kwargs) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with hf_runner(model) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize(
+    "model,model_impl",
+    [
+        ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
+        ("openai-community/gpt2", "transformers"),
+        ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
+        ("meta-llama/Llama-3.2-1B-Instruct", "auto"),
+    ])  # trust_remote_code=True by default
+def test_models(hf_runner, vllm_runner, example_prompts, model,
+                model_impl) -> None:
+
+    maybe_raises = nullcontext()
+    if model == "openai-community/gpt2" and model_impl == "transformers":
+        # Model is not backend compatible
+        maybe_raises = pytest.raises(
+            ValueError,
+            match="The Transformers implementation.*not compatible with vLLM")
+
+    with maybe_raises:
+        check_implementation(hf_runner,
+                             vllm_runner,
+                             example_prompts,
+                             model,
+                             model_impl=model_impl)
+
+
+@multi_gpu_test(num_gpus=2)
+def test_distributed(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+):
+    kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
+    check_implementation(hf_runner, vllm_runner, example_prompts,
+                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)

vllm/config.py

Lines changed: 14 additions & 0 deletions
@@ -83,6 +83,12 @@ def compute_hash(self) -> str:
         ...
 
 
+class ModelImpl(str, enum.Enum):
+    AUTO = "auto"
+    VLLM = "vllm"
+    TRANSFORMERS = "transformers"
+
+
 class ModelConfig:
     """Configuration for the model.
 
@@ -167,6 +173,12 @@ class ModelConfig:
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
+        model_impl: Which implementation of the model to use:
+            "auto" will try to use the vLLM implementation if it exists and
+            fall back to the Transformers implementation if no vLLM
+            implementation is available.
+            "vllm" will use the vLLM model implementation.
+            "transformers" will use the Transformers model implementation.
         override_generation_config: Override the generation config with the
             given config.
     """
@@ -230,6 +242,7 @@ def __init__(
         generation_config: Optional[str] = None,
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[Dict[str, Any]] = None,
+        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -241,6 +254,7 @@ def __init__(
         self.code_revision = code_revision
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.model_impl = model_impl
 
         if hf_overrides is None:
             hf_overrides = {}
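
Because `ModelImpl` subclasses `str`, enum members compare equal to their plain string values, which is what lets `model_impl` accept either a string or a `ModelImpl`. A small self-contained illustration of that pattern (the class is redefined here so the snippet runs without importing vLLM):

```python
import enum

# Standalone copy of the str-enum pattern used by ModelImpl in this diff.
class ModelImpl(str, enum.Enum):
    AUTO = "auto"
    VLLM = "vllm"
    TRANSFORMERS = "transformers"

assert ModelImpl("transformers") is ModelImpl.TRANSFORMERS  # parse a CLI string
assert ModelImpl.AUTO == "auto"  # plain str comparison also works
```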

vllm/engine/arg_utils.py

Lines changed: 18 additions & 4 deletions
@@ -13,10 +13,10 @@
 from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                          DecodingConfig, DeviceConfig, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                         VllmConfig)
+                         ModelConfig, ModelImpl, ObservabilityConfig,
+                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
+                         SchedulerConfig, SpeculativeConfig, TaskOption,
+                         TokenizerPoolConfig, VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -204,6 +204,7 @@ class EngineArgs:
     generation_config: Optional[str] = None
     override_generation_config: Optional[Dict[str, Any]] = None
     enable_sleep_mode: bool = False
+    model_impl: str = "auto"
 
     calculate_kv_scales: Optional[bool] = None
 
@@ -383,6 +384,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'qualified names that can be passed with the `logits_processors` '
             'extra completion argument. Defaults to None, which allows no '
             'processors.')
+        parser.add_argument(
+            '--model-impl',
+            type=str,
+            default=EngineArgs.model_impl,
+            choices=[f.value for f in ModelImpl],
+            help='Which implementation of the model to use.\n\n'
+            '* "auto" will try to use the vLLM implementation if it exists '
+            'and fall back to the Transformers implementation if no vLLM '
+            'implementation is available.\n'
+            '* "vllm" will use the vLLM model implementation.\n'
+            '* "transformers" will use the Transformers model '
+            'implementation.\n')
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
@@ -1038,6 +1051,7 @@ def create_model_config(self) -> ModelConfig:
             generation_config=self.generation_config,
             override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
+            model_impl=self.model_impl,
         )
 
     def create_load_config(self) -> LoadConfig:
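
Since the new flag is a plain `EngineArgs` field, it can also be set programmatically; a hedged sketch (only the `model_impl` field name comes from this diff, the model name is an example):

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent of `--model-impl transformers` on the command line; the value
# defaults to "auto" and is forwarded to ModelConfig by create_model_config().
args = EngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
                  model_impl="transformers")
print(args.model_impl)  # "transformers"
```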

vllm/model_executor/model_loader/utils.py

Lines changed: 59 additions & 2 deletions
@@ -2,17 +2,22 @@
 """Utilities for selecting and loading models."""
 import contextlib
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import transformers
 from torch import nn
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, ModelImpl
+from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
 
+logger = init_logger(__name__)
+
 
 @contextlib.contextmanager
 def set_default_torch_dtype(dtype: torch.dtype):
@@ -23,6 +28,50 @@ def set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)
 
 
+def is_transformers_impl_compatible(
+        arch: str,
+        module: Optional[transformers.PreTrainedModel] = None) -> bool:
+    mod = module or getattr(transformers, arch, None)
+    if mod is None:
+        return False
+    if hasattr(mod, "supports_backend"):
+        return mod.is_backend_compatible()
+    else:
+        return mod._supports_flex_attn
+
+
+def resolve_transformers_fallback(model_config: ModelConfig,
+                                  architectures: list[str]):
+    for i, arch in enumerate(architectures):
+        if arch == "TransformersModel":
+            continue
+        custom_module = None
+        auto_map = getattr(model_config.hf_config, "auto_map", None)
+        if auto_map is not None and "AutoModel" in auto_map:
+            custom_module = get_class_from_dynamic_module(
+                model_config.hf_config.auto_map["AutoModel"],
+                model_config.model)
+        # TODO(Isotr0py): Further clean up these raises.
+        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
+        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+            if not is_transformers_impl_compatible(arch, custom_module):
+                raise ValueError(
+                    f"The Transformers implementation of {arch} is not "
+                    "compatible with vLLM.")
+            architectures[i] = "TransformersModel"
+        if model_config.model_impl == ModelImpl.AUTO:
+            if not is_transformers_impl_compatible(arch, custom_module):
+                raise ValueError(
+                    f"{arch} has no vLLM implementation and the Transformers "
+                    "implementation is not compatible with vLLM.")
+            logger.warning(
+                "%s has no vLLM implementation, falling back to Transformers "
+                "implementation. Some features may not be supported and "
+                "performance may not be optimal.", arch)
+            architectures[i] = "TransformersModel"
+    return architectures
+
+
 def get_model_architecture(
         model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
     architectures = getattr(model_config.hf_config, "architectures", [])
@@ -38,6 +87,14 @@ def get_model_architecture(
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
+    vllm_supported_archs = ModelRegistry.get_supported_archs()
+    is_vllm_supported = any(arch in vllm_supported_archs
+                            for arch in architectures)
+    if (not is_vllm_supported
+            or model_config.model_impl == ModelImpl.TRANSFORMERS):
+        architectures = resolve_transformers_fallback(model_config,
+                                                      architectures)
+
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
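
The tests added above pin down the user-visible behaviour of this resolution logic; a hedged sketch of what that looks like from the Python API (the model and the expected message come from `tests/models/test_transformers.py` in this diff, everything else is assumed):

```python
from vllm import LLM

# Per the tests above, forcing model_impl="transformers" on a model whose
# Transformers implementation is not backend-compatible is expected to raise
# a ValueError during engine initialization.
try:
    LLM(model="openai-community/gpt2", model_impl="transformers")
except ValueError as err:
    print(err)  # "The Transformers implementation ... not compatible with vLLM"
```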
