
Commit b7408f7

[ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959)
This PR adds support for efficient attention and flash attention in ORTModule:

- Use ATen to call efficient attention, which requires PyTorch 2.2.0 dev or newer. Set ORTMODULE_USE_EFFICIENT_ATTENTION=1 to enable.
- Integrate Triton flash attention, which requires triton==2.0.0.dev20221202 and an A100 or H100 GPU. Set ORTMODULE_USE_FLASH_ATTENTION=1 to enable.
- Add a Python transformer tool that matches sub-graphs by config, so new transformers can be written quickly.

The current transformers support an attention mask for both efficient attention and flash attention, and dropout for efficient attention only. Supporting more training scenarios (such as the causal mask in GPT-2) will require additional transformers. The feature is guarded by system environment variables, so it won't affect any current behavior unless enabled. Since it requires specific PyTorch/Triton versions, related tests are not added for now.
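As a usage illustration (a minimal sketch, not part of this commit: ORTModule, the environment variables above, and torch.nn.functional.scaled_dot_product_attention are the only assumed APIs, and which sub-graphs actually get rewritten depends on how the model exports):

import os

# Opt in before constructing ORTModule; both features default to off.
os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"  # needs PyTorch 2.2.0 dev or newer
# os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"    # needs triton==2.0.0.dev20221202, A100/H100

import torch
from onnxruntime.training.ortmodule import ORTModule

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # Exports to an attention sub-graph the new transformers can match and fuse.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

model = ORTModule(ToyAttention().cuda())
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
model(q, k, v).sum().backward()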
1 parent 37873be · commit b7408f7

File tree

26 files changed (+2037 −93 lines)


cmake/onnxruntime_python.cmake

Lines changed: 7 additions & 0 deletions

@@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING)
   file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*"
   )
+  file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
+  )
   file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
   )
@@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING)
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
   COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING)
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs}
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
+  COMMAND ${CMAKE_COMMAND} -E copy
+    ${onnxruntime_python_ortmodule_graph_optimizers_srcs}
+    $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
   COMMAND ${CMAKE_COMMAND} -E copy
     ${onnxruntime_python_ort_triton_srcs}
     $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/

onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc

Lines changed: 4 additions & 2 deletions

@@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const {
   aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size,
                                              dlpack_outputs.get());
   for (size_t i = 0; i < output_size; ++i) {
-    ORT_RETURN_IF_ERROR(
-        p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
+    if (dlpack_outputs[i]) {
+      ORT_RETURN_IF_ERROR(
+          p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(dlpack_outputs[i])));
+    }
   }
 
   return Status::OK();

onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h

Lines changed: 8 additions & 8 deletions

@@ -10,7 +10,7 @@ namespace onnxruntime {
 namespace contrib {
 namespace aten_ops {
 
-typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index);
+typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
 typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
                                         DLManagedTensor** dlpack_inputs, size_t output_size,
                                         DLManagedTensor** dlpack_outputs);
@@ -22,17 +22,17 @@ class ATenOperatorExecutor {
     return instance;
   }
 
-  void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
-    ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
-    p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
+  void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
+    ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
+    p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
     p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
   }
 
   bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }
 
-  bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) {
-    ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
-    return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
+  bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
+    ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
+    return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
   }
 
   void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
@@ -43,7 +43,7 @@ class ATenOperatorExecutor {
   }
 
  private:
-  IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
+  IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
   ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
 };

onnxruntime/core/framework/fallback_cpu_capability.cc

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 #include "onnx/defs/data_type_utils.h"
 
 #include "core/framework/op_kernel.h"
+#include "core/framework/utils.h"
 
 using namespace ONNX_NAMESPACE::Utils;
 
@@ -77,7 +78,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
     ORT_THROW_IF_ERROR(node->ForEachWithIndex(
         node->OutputDefs(),
         [&](const NodeArg& node_arg, size_t out_index) {
-          if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) {
+          if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) {
             cpu_output_args.insert(&node_arg);
             auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
             for (auto& consumer_node : consumer_nodes) {

onnxruntime/core/framework/utils.cc

Lines changed: 26 additions & 1 deletion

@@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
       overload_name = attrs.at("overload_name").s();
     }
 
-    return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index);
+    return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
+  }
+#else
+  ORT_UNUSED_PARAMETER(node);
+#endif
+
+  return false;
+}
+
+bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) {
+  if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) {
+    return true;
+  }
+
+#ifdef ENABLE_ATEN
+  if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
+      node.Domain() == kPytorchAtenDomain) {
+    const auto& attrs = node.GetAttributes();
+    ORT_ENFORCE(utils::HasString(attrs.at("operator")));
+    std::string op_name = attrs.at("operator").s();
+    std::string overload_name = "";
+    if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
+      overload_name = attrs.at("overload_name").s();
+    }
+
+    return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
   }
 #else
   ORT_UNUSED_PARAMETER(node);

onnxruntime/core/framework/utils.h

Lines changed: 1 addition & 0 deletions

@@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
                                bool sync_subgraph_fetches = false);
 
 bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
+bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index);
 
 template <typename T>
 constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {

onnxruntime/core/optimizer/transformer_memcpy.cc

Lines changed: 4 additions & 4 deletions

@@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
     if (!arg->Exists())
       continue;
 
-    if (kci && kci->kernel_def->IsOutputOnCpu(i))
+    if (utils::IsOutputOnCpu(node, kci, i))
       non_provider_output_defs_.insert(arg);
     else
       provider_output_defs_.insert(arg);
@@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
       if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
     }
     if (arg_output_index != -1) {
-      if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it);
+      if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it);
     }
   }
 }
@@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
   // normally initializers are only inputs, but things may change with ops like assign
   ORT_THROW_IF_ERROR(Node::ForEachWithIndex(
       p_node->OutputDefs(),
-      [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
-        if (kci->kernel_def->IsOutputOnCpu(index)) {
+      [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) {
+        if (utils::IsOutputOnCpu(*p_node, kci, index)) {
          ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end());
         }
         return Status::OK();

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 5 additions & 5 deletions

@@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) {
 
 #ifdef ENABLE_ATEN
   m.def("register_aten_op_executor",
-        [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
-          size_t is_tensor_argument_address_int, aten_op_executor_address_int;
+        [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
+          size_t is_cpu_argument_address_int, aten_op_executor_address_int;
           ORT_THROW_IF_ERROR(
-              ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
+              ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
           ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
-          void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
+          void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
           void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
-          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
+          contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
         });
 #endif
 }

onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
     from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
 
     _C.register_aten_op_executor(
-        str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
+        str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
     )
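The registration above ships raw C function pointers across the Python boundary as stringified integers, which the receiving side (see onnxruntime_pybind_state.cc above) parses back and reinterpret_casts. A standalone sketch of the same pattern using ctypes, for illustration only (assumes Linux with libm.so.6; none of these names are ORT APIs):

import ctypes

libm = ctypes.CDLL("libm.so.6")                      # any shared library with a known symbol
addr = ctypes.cast(libm.cos, ctypes.c_void_p).value  # function pointer -> integer
addr_str = str(addr)                                 # stringified, like is_cpu_argument_address()

# Receiving side: parse the string and call through the recovered pointer.
proto = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
cos_again = proto(int(addr_str))
assert abs(cos_again(0.0) - 1.0) < 1e-12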

onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc

Lines changed: 30 additions & 8 deletions

@@ -154,11 +154,32 @@ class ATenOperatorCache {
   std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
 };
 
-// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not.
-bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) {
-  const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
-  TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
-  return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
+const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
+    {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};
+
+const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
+    {"_efficient_attention_forward", {2, 3}}};
+
+// Backend uses this function to check if an argument is CPU input or not.
+bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
+  if (is_input) {
+    // If the argument is non-tensor type, it's CPU argument.
+    const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
+    TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
+    if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
+      return true;
+    }
+  }
+
+  std::string full_name = std::string(op_name);
+  std::string overload_name_str = std::string(overload_name);
+  if (overload_name_str != "") {
+    full_name += ("." + overload_name_str);
+  }
+
+  const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
+  return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
+         cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
 }
 
 void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
@@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
   size_t output_index = 0;
   for (const auto& ret : torch::jit::pop(stack, output_size)) {
     const auto& tensor = ret.toTensor();
-    dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous());
+    dlpack_outputs[output_index++] =
+        tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
   }
 }
 
-size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
+size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
 size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
+  m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
   m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
 }
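For clarity, a Python paraphrase of the IsCpuArgument lookup above (illustration only; is_tensor_type stands in for the ATenOperatorCache elem_kinds check, and the names here are not ORT APIs):

CPU_TENSOR_INPUTS = {
    "_efficient_attention_forward": {4, 5, 11, 12},
    "_efficient_attention_backward": {6, 7, 12, 13},
}
CPU_TENSOR_OUTPUTS = {"_efficient_attention_forward": {2, 3}}

def is_cpu_argument(op_name, overload_name, index, is_input, is_tensor_type=True):
    # Non-tensor inputs (scalars, ints, etc.) are always CPU arguments.
    if is_input and not is_tensor_type:
        return True
    # Otherwise only the listed tensor arguments of the efficient-attention
    # ops are pinned to CPU.
    full_name = f"{op_name}.{overload_name}" if overload_name else op_name
    table = CPU_TENSOR_INPUTS if is_input else CPU_TENSOR_OUTPUTS
    return index in table.get(full_name, set())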
