[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR adds support for efficient attention and flash attention in
ORTModule, including:
- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev
or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable (see the
usage sketch after this list).
- Integrate Triton flash attention, which requires
triton==2.0.0.dev20221202 and an A100 or H100 GPU.
Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
- A Python graph transformer tool that matches sub-graphs by config,
making it quick to write new transformers.
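
A minimal usage sketch of the two switches. The environment variable names are taken from this PR; the ORTModule import path is the standard one, and the toy attention module and shapes are illustrative assumptions, not part of this change:

```python
import os

# Enable the ATen efficient attention path (requires PyTorch 2.2.0 dev or newer).
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"
# Alternatively, enable the Triton flash attention path
# (requires triton==2.0.0.dev20221202 and an A100/H100 GPU).
# os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule


class TinyAttentionBlock(torch.nn.Module):
    """Placeholder module whose exported ONNX graph contains an attention sub-graph."""

    def __init__(self, hidden=64, heads=4):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return self.proj(out).sum()


# Wrapping with ORTModule exports the module to ONNX; when the flag above is
# set, the corresponding graph transformer is expected to rewrite matching
# attention sub-graphs. Whether this toy pattern is matched depends on the
# exported graph, so treat it only as a shape of the workflow.
model = ORTModule(TinyAttentionBlock().cuda())
x = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
loss = model(x)
loss.backward()
```
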
The current transformers support attention masks for both efficient
attention and flash attention, and dropout for efficient attention only.
To support more training scenarios (such as the causal mask in GPT2),
more transformers need to be added.
The feature is guarded by system environment variables, so it won't affect
any current behavior if not enabled. Since it requires specific
PyTorch/Triton versions, related tests are not added for now.