diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index bf9adbaefabcc..a9a78668b4810 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -387,6 +387,9 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/*" ) + file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS + "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*" + ) file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py" ) @@ -741,6 +744,7 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/graph_optimizers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ort_triton/kernel COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils @@ -794,6 +798,9 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_torch_cpp_ext_fused_ops_srcs} $/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_ortmodule_graph_optimizers_srcs} + $/onnxruntime/training/ortmodule/graph_optimizers/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ort_triton_srcs} $/onnxruntime/training/ort_triton/ diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc index 945c3aebce579..d0abf58922f88 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -32,8 +32,10 @@ Status ATen::Compute(OpKernelContext* p_ctx) const { aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, dlpack_outputs.get()); for (size_t i = 0; i < output_size; ++i) { - ORT_RETURN_IF_ERROR( - p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + if (dlpack_outputs[i]) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index be9650d96b004..d72868cd8fa9f 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index); +typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class 
ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); + p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { - ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); + bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index 3d971e6aa29a2..ef68b88187e08 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -9,6 +9,7 @@ #include "onnx/defs/data_type_utils.h" #include "core/framework/op_kernel.h" +#include "core/framework/utils.h" using namespace ONNX_NAMESPACE::Utils; @@ -77,7 +78,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe ORT_THROW_IF_ERROR(node->ForEachWithIndex( node->OutputDefs(), [&](const NodeArg& node_arg, size_t out_index) { - if (kernel_info->kernel_def->IsOutputOnCpu(out_index)) { + if (utils::IsOutputOnCpu(*node, kernel_info, out_index)) { cpu_output_args.insert(&node_arg); auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index d63881ab4ff04..23fe5e1cd3d96 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1025,7 +1025,32 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index); + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + } +#else + ORT_UNUSED_PARAMETER(node); +#endif + + return false; +} + +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) { + if (p_kci && p_kci->kernel_def->IsOutputOnCpu(index)) { + return true; + } + +#ifdef ENABLE_ATEN + if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && + node.Domain() == kPytorchAtenDomain) { + const auto& attrs = node.GetAttributes(); + 
ORT_ENFORCE(utils::HasString(attrs.at("operator"))); + std::string op_name = attrs.at("operator").s(); + std::string overload_name = ""; + if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) { + overload_name = attrs.at("overload_name").s(); + } + + return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index ea6a629f87cb8..f0b1b9109d405 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -121,6 +121,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet bool sync_subgraph_fetches = false); bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); +bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index); template constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index ed3e35706b688..0d7ab70eba613 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -249,7 +249,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg if (!arg->Exists()) continue; - if (kci && kci->kernel_def->IsOutputOnCpu(i)) + if (utils::IsOutputOnCpu(node, kci, i)) non_provider_output_defs_.insert(arg); else provider_output_defs_.insert(arg); @@ -308,7 +308,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } if (arg_output_index != -1) { - if (!kci || !kci->kernel_def->IsOutputOnCpu(arg_output_index)) provider_output_nodes_[arg].insert(&it); + if (!kci || !utils::IsOutputOnCpu(it, kci, arg_output_index)) provider_output_nodes_[arg].insert(&it); } } } @@ -404,8 +404,8 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker // normally initializers are only inputs, but things may change with ops like assign ORT_THROW_IF_ERROR(Node::ForEachWithIndex( p_node->OutputDefs(), - [kci, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { - if (kci->kernel_def->IsOutputOnCpu(index)) { + [kci, &p_node, &dup_replacements](const onnxruntime::NodeArg& arg, size_t index) { + if (utils::IsOutputOnCpu(*p_node, kci, index)) { ORT_ENFORCE(dup_replacements.find(&arg) == dup_replacements.end()); } return Status::OK(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a72f563601512..90271b5458399 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1214,14 +1214,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_tensor_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_cpu_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); + 
ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); + void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 9dee6564509d5..8bf7cbf80eb37 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 182f2368f5b47..903a394a06ef3 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -154,11 +154,32 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -// Backend uses this function to check if an argument is CPU input (non-tensor argument) or not. -bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index) { - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); - TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; +const std::unordered_map> kCpuTensorInputsMap = { + {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; + +const std::unordered_map> kCpuTensorOutputsMap = { + {"_efficient_attention_forward", {2, 3}}}; + +// Backend uses this function to check if an argument is CPU input or not. +bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + if (is_input) { + // If the argument is non-tensor type, it's CPU argument. + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); + TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); + if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { + return true; + } + } + + std::string full_name = std::string(op_name); + std::string overload_name_str = std::string(overload_name); + if (overload_name_str != "") { + full_name += ("." + overload_name_str); + } + + const auto& cpu_tensors_map = is_input ? 
kCpuTensorInputsMap : kCpuTensorOutputsMap; + return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && + cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -196,14 +217,15 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; } } -size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } +size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); + m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of CPU argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 7d5716b85db30..329fba5aa670a 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address +from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index c90acfdb7bb78..80d937fa163e6 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4180,6 +4180,7 @@ Return true if all elements are true and false otherwise. .Attr("func_name", "Function name of the Python Triton kernel.", AttributeProto::STRING, std::string("")) .Attr("onnx_key", "The hash key for the ONNX graph.", AttributeProto::INT, static_cast(0)) .Attr("onnx_string", "The onnx string of the triton kernel.", AttributeProto::STRING, std::string("")) + .AllowUncheckedAttributes() .Input(0, "inputs", "Input tensors. If to call an existing Python Triton kernel, " "the input count and order should match the arguments of the function. 
If to compute an ONNX graph, " diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index 97318ea2e53ae..c1b99e4859dbd 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -3,15 +3,28 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out -from ._slice_scel import slice_scel, slice_scel_backward, transform_slice_scel +import os -__all__ = [ +from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 +from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401 + +_all_kernels = [ "triton_gemm", "triton_gemm_out", "triton_matmul", "triton_matmul_out", "slice_scel", "slice_scel_backward", - "transform_slice_scel", ] + +_all_optimizers = [ + "optimize_graph_for_slice_scel", +] + +if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: + from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401 + + _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) + _all_optimizers.append("optimize_graph_for_flash_attention") + +__all__ = _all_kernels + _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py new file mode 100644 index 0000000000000..40398b33d8f04 --- /dev/null +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py @@ -0,0 +1,1244 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". 
I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math +from typing import List, Tuple + +import torch +import triton +import triton.language as tl +from onnx import GraphProto, NodeProto, TensorProto, helper + +from onnxruntime.training.ortmodule import register_graph_optimizer +from onnxruntime.training.ortmodule.graph_optimizers.utils import GraphMatcher, check_attribute_value, update_graph + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? 
+ # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != "none": + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. 
+ qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! 
+ if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != "none": + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == "none": + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == "matrix": + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # 
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != "none": + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def flash_attn_forward(q, k, v, bias=None, **kwargs): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + causal = kwargs.get("causal", 0) == 1 + 
softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse + + +def flash_attn_backward(do, q, k, v, o, lse, bias=None, **kwargs): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + + causal = kwargs.get("causal", 0) == 1 + softmax_scale = kwargs.get("softmax_scale", 1.0 / math.sqrt(d)) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=128, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)") + bias = bias.expand(batch, nheads, seqlen_q, 
seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + do, + dq_accum, + dk, + dv, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + return dq, dk, dv + + +def _make_flash_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + scale: float, +): + logsumexp = helper.make_tensor_value_info("logsumexp_" + str(idx), TensorProto.FLOAT, []) + fwd_node = helper.make_node( + "TritonOp", + [q, k, v, bias], + [y, logsumexp.name], + "TritonOp_Flash_Attn_Fwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_forward", + causal=0, + softmax_scale=scale, + ) + bwd_node = helper.make_node( + "TritonOp", + [dy, q, k, v, y, logsumexp.name, bias], + [dq, dk, dv], + "TritonOp_Flash_Attn_Bwd_" + str(idx), + None, + "com.microsoft", + func_name="flash_attn_backward", + causal=0, + softmax_scale=scale, + ) + return [fwd_node, bwd_node], [logsumexp] + + +# Without causal mask, without Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. 
+ scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + return nodes, nodes_to_add, new_value_infos + + +# llama2+peft, k doesn't require grad. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(4, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Sum", False, [(16, 0, 0)]), # 18 + ("Transpose", False, [(18, 0, 0)]), # 19 +] + + +def _optimize_for_pattern_1(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. 
+ scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 1 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[19].input[0] = nodes[18].input[1] + v_grad = nodes[19].output[0] + nodes[19].output[0] = nodes[18].output[0] + nodes[18].input[1] = nodes[18].output[0] + nodes[18].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[17].input[0], + trans_q_grad_tensor.name, + "", + nodes[16].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:18] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# llama2+peft, k requires grad. +_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 1)]), # 1 + ("Div", False, [(0, 0, 0)]), # 2 + ("Add", False, [(2, 0, 0)]), # 3 + ("Softmax", False, [(3, 0, 0)]), # 4 + ("MatMul", False, [(4, 0, 0)]), # 5 + ("Transpose", True, [(5, 0, 1)]), # 6 + ("Identity", False, [(6, 0, 0)]), # 7 + ("YieldOp", False, [(7, 0, -1)]), # 8 + ("Transpose", False, [(5, 0, 0)]), # 9 + ("FusedMatMul", False, [(6, 0, 1)]), # 10 + ("SoftmaxGrad_13", False, [(10, 0, 0), (4, 0, 1)]), # 11 + ("Identity", False, [(11, 0, 0)]), # 12 + ("Div", False, [(12, 0, 0)]), # 13 + ("Identity", False, [(13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(14, 0, 1)]), # 16 + ("Transpose", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(4, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Sum", False, [(18, 0, 0)]), # 20 + ("Transpose", False, [(20, 0, 0)]), # 21 +] + + +def _aptimize_for_pattern_2(matcher: GraphProto, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. 
+ scale_value = matcher.get_constant_value(nodes[2].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 1, 3, 2]) + and scale_value is not None + and check_attribute_value(nodes[6], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + and matcher.get_consumer_count(nodes[14].output[0]) == 2 + ): + return [], [], [] + + dtype, _ = matcher.get_type_and_shape(nodes[0].input[0]) + assert dtype is not None + trans_q_tensor = helper.make_tensor_value_info("trans_q_" + str(idx), dtype, None) + trans_q_grad_tensor = helper.make_tensor_value_info("trans_q_grad_" + str(idx), dtype, None) + trans_k_tensor = helper.make_tensor_value_info("trans_k_" + str(idx), dtype, None) + trans_k_grad_tensor = helper.make_tensor_value_info("trans_k_grad_" + str(idx), dtype, None) + trans_q = helper.make_node( + "Transpose", [nodes[0].input[0]], [trans_q_tensor.name], "Trans_Q_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_q_grad = helper.make_node( + "Transpose", [trans_q_grad_tensor.name], [nodes[15].output[0]], "Trans_Q_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k = helper.make_node( + "Transpose", [nodes[1].input[0]], [trans_k_tensor.name], "Trans_K_" + str(idx), perm=[0, 2, 1, 3] + ) + trans_k_grad = helper.make_node( + "Transpose", [trans_k_grad_tensor.name], [nodes[17].output[0]], "Trans_K_Grad_" + str(idx), perm=[0, 2, 1, 3] + ) + nodes[21].input[0] = nodes[20].input[1] + v_grad = nodes[21].output[0] + nodes[21].output[0] = nodes[20].output[0] + nodes[20].input[1] = nodes[20].output[0] + nodes[20].output[0] = v_grad + nodes_to_add, new_value_infos = _make_flash_attention_nodes( + idx, + trans_q_tensor.name, + trans_k_tensor.name, + nodes[6].input[0], + nodes[9].output[0], + nodes[19].input[0], + trans_q_grad_tensor.name, + trans_k_grad_tensor.name, + nodes[18].output[0], + nodes[3].input[1], + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ) + nodes_to_remove = nodes[:6] + nodes[9:20] + nodes_to_add.extend([trans_q, trans_q_grad, trans_k, trans_k_grad]) + new_value_infos.extend([trans_q_tensor, trans_q_grad_tensor, trans_k_tensor, trans_k_grad_tensor]) + return nodes_to_remove, nodes_to_add, new_value_infos + + +# TODO: add pattern to support attention with causal mask, such as GPT2 in HuggingFace. 
+_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _optimize_for_pattern_2), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_flash_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py index 8edcc9b63ef4f..fb7ddc68900c9 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_slice_scel.py @@ -11,7 +11,7 @@ import triton.language as tl from onnx import TensorProto, helper -from onnxruntime.training.ortmodule import register_graph_transformer +from onnxruntime.training.ortmodule import register_graph_optimizer from .._utils import get_attribute, to_numpy_array @@ -246,8 +246,8 @@ def _get_shape_related_nodes(graph, start_arg, sub_graph_nodes): args.append(output) -@register_graph_transformer(devices="cuda") -def transform_slice_scel(graph): +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_slice_scel(graph): remove_nodes = [] triton_nodes = [] value_infos = [] diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 59cf05bb082fc..fbf1b7c2bac42 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -124,7 +124,8 @@ def _are_deterministic_algorithms_enabled(): return ORTMODULE_IS_DETERMINISTIC -from .graph_transformer_registry import register_graph_transformer # noqa: E402, F401 +from .graph_optimizer_registry import register_graph_optimizer # noqa: E402, F401 +from .graph_optimizers import * # noqa: E402, F403 from .options import DebugOptions, LogLevel # noqa: E402, F401 # ORTModule must be loaded only after all validation passes diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3953d342f1897..e0f11e5aa407e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -21,7 +21,7 @@ from ._logger import ORTModuleInitPhase, SuppressLogs, TrackTime from ._runtime_inspector import Phase from ._utils import save_tuning_results, set_tuning_results -from .graph_transformer_registry import GraphTransformerRegistry +from .graph_optimizer_registry import GraphOptimizerRegistry from .options import DebugOptions, _SkipCheck @@ -369,7 +369,7 @@ def _build_graph(self, graph_transformer_config): device_type = self._device.type if device_type == "cuda" and self.is_rocm_pytorch: device_type = "rocm" - GraphTransformerRegistry.transform_all( + GraphOptimizerRegistry.optimize_all( type(self._flattened_module._original_module).__name__, device_type, self._onnx_models.optimized_model.graph ) diff --git
a/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py new file mode 100644 index 0000000000000..897ecac148bfb --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizer_registry.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Callable + +from onnx.onnx_ml_pb2 import GraphProto + + +class GraphOptimizerRegistry: + _OPTIMIZER_FUNCS = {} # noqa: RUF012 + + @classmethod + def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): + modules = [] + if target_modules == "all": + modules.append("all") + else: + modules = target_modules.split("|") + for module in modules: + if module in cls._OPTIMIZER_FUNCS: + cls._OPTIMIZER_FUNCS[module].append((fn, devices, priority)) + else: + cls._OPTIMIZER_FUNCS[module] = [(fn, devices, priority)] + + @classmethod + def optimize_all(cls, module_name: str, device: str, graph: GraphProto): + optimizers_to_apply = [] + if "all" in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS["all"]) + if module_name in cls._OPTIMIZER_FUNCS: + optimizers_to_apply.extend(cls._OPTIMIZER_FUNCS[module_name]) + optimizers_to_apply = [x for x in optimizers_to_apply if x[1] == "all" or device in x[1]] + optimizers_to_apply.sort(key=lambda x: x[2], reverse=True) + for fn, _, _ in optimizers_to_apply: + fn(graph) + + +# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. +# devices can be multiple device types separated by "|" or "all" means apply to all devices. +def register_graph_optimizer(target_modules: str = "all", devices: str = "all", priority: int = 0): + def graph_optimizer_wrapper(fn): + GraphOptimizerRegistry.register(target_modules, devices, priority, fn) + return fn + + return graph_optimizer_wrapper diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py new file mode 100644 index 0000000000000..d215e12f8137a --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import os + +_all_optimizers = [] + +if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: + from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 + + _all_optimizers.append("optimize_graph_for_aten_efficient_attention") + +__all__ = _all_optimizers # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py new file mode 100644 index 0000000000000..94bd41293b427 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -0,0 +1,414 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs keep changing. The current implementation +is well tested on version 2.2.0.dev20231010+cu121 and is expected to work with official releases from 2.2.0 onwards. +It may fail to run if you are using an older version of PyTorch. + +PyTorch also has an API for flash attention (which currently doesn't support an arbitrary attention mask or Dropout); +we can add support for it in the future if needed. +""" + +from typing import List, Tuple + +from onnx import GraphProto, NodeProto, TensorProto, helper + +from ..graph_optimizer_registry import register_graph_optimizer +from .utils import GraphMatcher, check_attribute_value, make_constant_node, update_graph + + +def _make_efficient_attention_nodes( + idx: int, + q: str, + k: str, + v: str, + y: str, + dy: str, + dq: str, + dk: str, + dv: str, + bias: str, + expand_bias: bool, + scale: float, + dropout_ratio: float, + causal: bool, +): + nodes_to_add = [] + scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) + dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) + causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) + int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) + true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) + false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) + seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) + offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset] + if expand_bias: + shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) + shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) + shape_2 = helper.make_node("Shape", [q], ["shape_2_" + str(idx)], start=1, end=2) + shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) + concat = helper.make_node( + "Concat", + ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + ["concated_shape_" + str(idx)], + axis=0, + ) + expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + nodes_to_add.extend([shape_0, shape_1,
shape_2, shape_3, concat, expand]) + bias = "expanded_bias_" + str(idx) + fwd_node = helper.make_node( + "ATen", + [ + q, + k, + v, + bias, + "", + "", + "", + dropout_ratio_node.output[0], + causal_node.output[0], + true_node.output[0], + scale_node.output[0], + "", + "", + ], + [y, logsumexp.name, seed.name, offset.name], + "efficient_attention_forward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_forward", + ) + bwd_node = helper.make_node( + "ATen", + [ + dy, + q, + k, + v, + bias, + y, + "", + "", + int_zero_node.output[0], + int_zero_node.output[0], + logsumexp.name, + dropout_ratio_node.output[0], + seed.name, + offset.name, + causal_node.output[0], + false_node.output[0], + scale_node.output[0], + "", + ], + [dq, dk, dv, ""], + "efficient_attention_backward_" + str(idx), + None, + "org.pytorch.aten", + operator="_efficient_attention_backward", + ) + nodes_to_add.extend( + [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] + ) + return nodes_to_add, new_value_infos + + +# Without causal mask, with Dropout. For example, BERT model in HuggingFace. +_PATTERN_0: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("Dropout", False, [(5, 0, 0)]), # 6 + ("MatMul", False, [(6, 0, 0)]), # 7 + ("Transpose", True, [(7, 0, 1)]), # 8 + ("Transpose", False, [(7, 0, 0)]), # 9 + ("FusedMatMul", False, [(8, 0, 1)]), # 10 + ("DropoutGrad", False, [(10, 0, 0), (6, 1, 1)]), # 11 + ("SoftmaxGrad_13", False, [(11, 0, 0), (5, 0, 1)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("Div", False, [(13, 0, 0)]), # 14 + ("Identity", False, [(14, 0, 0)]), # 15 + ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]), # 16 + ("FusedMatMul", False, [(1, 0, 0), (15, 0, 1)]), # 17 + ("FusedMatMul", False, [(6, 0, 0)]), # 18 + ("Transpose", True, [(18, 0, 1)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 + ("Transpose", False, [(17, 0, 0)]), # 21 + ("Transpose", False, [(18, 0, 0)]), # 22 +] + + +def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + ratio_value = matcher.get_constant_value(nodes[6].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and ratio_value is not None + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[9], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[8].input[0], + nodes[9].output[0], + nodes[19].input[0], + nodes[20].output[0], + nodes[21].output[0], + nodes[22].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + ratio_value, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Without causal mask, without Dropout. 
For example, BERT model and disabling attention dropout in HuggingFace. +_PATTERN_1: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Transpose", True, [(0, 0, 0)]), # 1 + ("Transpose", True, [(0, 0, 1)]), # 2 + ("Div", False, [(0, 0, 0)]), # 3 + ("Add", False, [(3, 0, 0)]), # 4 + ("Softmax", False, [(4, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("Identity", False, [(10, 0, 0)]), # 11 + ("Div", False, [(11, 0, 0)]), # 12 + ("Identity", False, [(12, 0, 0)]), # 13 + ("FusedMatMul", False, [(2, 0, 1), (13, 0, 0)]), # 14 + ("FusedMatMul", False, [(1, 0, 0), (13, 0, 1)]), # 15 + ("FusedMatMul", False, [(5, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(14, 0, 0)]), # 18 + ("Transpose", False, [(15, 0, 0)]), # 19 + ("Transpose", False, [(16, 0, 0)]), # 20 +] + + +def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value = matcher.get_constant_value(nodes[3].input[1]) + if not ( + check_attribute_value(nodes[1], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[2], "perm", [0, 2, 3, 1]) + and scale_value is not None + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) + ): + return [], [], [] + + _, add_input_shape_0 = matcher.get_type_and_shape(nodes[4].input[0]) + _, add_input_shape_1 = matcher.get_type_and_shape(nodes[4].input[1]) + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[1].input[0], + nodes[2].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[19].output[0], + nodes[20].output[0], + nodes[4].input[1], + add_input_shape_0 != add_input_shape_1, + 1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value), + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# No causal mask, no attention mask, without Dropout. 
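+# In this variant Q and K are pre-scaled by separate Mul nodes (entries 1 and 2 below) instead of a single Div on the +# MatMul output, so the handler reads the scale from those Mul constants and requires both scales to be equal.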
+_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Softmax", False, [(0, 0, 0)]), # 7 + ("Cast", False, [(7, 0, 0)]), # 8 + ("MatMul", False, [(8, 0, 0)]), # 9 + ("Transpose", True, [(9, 0, 1)]), # 10 + ("Transpose", False, [(9, 0, 0)]), # 11 + ("FusedMatMul", False, [(10, 0, 1)]), # 12 + ("Cast", False, [(12, 0, 0)]), # 13 + ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 + ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 + ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 + ("Mul", False, [(15, 0, 0)]), # 17 + ("Mul", False, [(16, 0, 0)]), # 18 + ("Identity", False, [(17, 0, 0)]), # 19 + ("Identity", False, [(18, 0, 0)]), # 20 + ("Cast", False, [(19, 0, 0)]), # 21 + ("Cast", False, [(20, 0, 0)]), # 22 + ("Transpose", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("FusedMatMul", False, [(8, 0, 0)]), # 25 + ("Transpose", True, [(25, 0, 1)]), # 26 + ("Transpose", False, [(25, 0, 0)]), # 27 +] + + +def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[8], "to", 10) + and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[10].input[0], + nodes[11].output[0], + nodes[26].input[0], + nodes[23].output[0], + nodes[24].output[0], + nodes[27].output[0], + "", + False, + scale_value_1, + 0.0, + False, + ) + return nodes, nodes_to_add, new_value_infos + + +# Has causal mask, no attention mask, without Dropout. 
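+# The causal mask appears in the graph as the Add fed by the Cast/Slice/Slice/Unsqueeze/Gather/Shape chain below +# (entries 7-13); the handler removes that chain and passes causal=True instead, so that causal masking is handled +# inside the fused ATen op.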
+_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ + ("MatMul", False, []), # 0 + ("Mul", True, [(0, 0, 0)]), # 1 + ("Mul", True, [(0, 0, 1)]), # 2 + ("Cast", True, [(1, 0, 0)]), # 3 + ("Cast", True, [(2, 0, 0)]), # 4 + ("Transpose", True, [(3, 0, 0)]), # 5 + ("Transpose", True, [(4, 0, 0)]), # 6 + ("Add", False, [(0, 0, 0)]), # 7 + ("Cast", True, [(7, 0, 1)]), # 8 + ("Slice", True, [(8, 0, 0)]), # 9 + ("Slice", True, [(9, 0, 0)]), # 10 + ("Unsqueeze", True, [(9, 0, 2)]), # 11 + ("Gather", True, [(11, 0, 0)]), # 12 + ("Shape", True, [(12, 0, 0)]), # 13 + ("Softmax", False, [(7, 0, 0)]), # 14 + ("Cast", False, [(14, 0, 0)]), # 15 + ("MatMul", False, [(15, 0, 0)]), # 16 + ("Transpose", True, [(16, 0, 1)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(17, 0, 1)]), # 19 + ("Cast", False, [(19, 0, 0)]), # 20 + ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 + ("Identity", False, [(21, 0, 0)]), # 22 + ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23 + ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 + ("Mul", False, [(23, 0, 0)]), # 25 + ("Mul", False, [(24, 0, 0)]), # 26 + ("Identity", False, [(25, 0, 0)]), # 27 + ("Identity", False, [(26, 0, 0)]), # 28 + ("Cast", False, [(27, 0, 0)]), # 29 + ("Cast", False, [(28, 0, 0)]), # 30 + ("Transpose", False, [(29, 0, 0)]), # 31 + ("Transpose", False, [(30, 0, 0)]), # 32 + ("FusedMatMul", False, [(15, 0, 0)]), # 33 + ("Transpose", True, [(33, 0, 1)]), # 34 + ("Transpose", False, [(33, 0, 0)]), # 35 +] + + +def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): + # Check forward only as the backward is expected to be consistent if it's built correctly. + scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) + scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 + scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) + scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 + if not ( + check_attribute_value(nodes[3], "to", 1) + and check_attribute_value(nodes[4], "to", 1) + and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[15], "to", 10) + and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + and scale_value_1 == scale_value_2 + ): + return [], [], [] + + nodes_to_add, new_value_infos = _make_efficient_attention_nodes( + idx, + nodes[5].input[0], + nodes[6].input[0], + nodes[17].input[0], + nodes[18].output[0], + nodes[34].input[0], + nodes[31].output[0], + nodes[32].output[0], + nodes[35].output[0], + "", + False, + scale_value_1, + 0.0, + True, + ) + return nodes, nodes_to_add, new_value_infos + + +_PATTERNS = [ + (_PATTERN_0, _optimize_for_pattern_0), + (_PATTERN_1, _optimize_for_pattern_1), + (_PATTERN_2, _optimize_for_pattern_2), + (_PATTERN_3, _optimize_for_pattern_3), +] + + +@register_graph_optimizer(devices="cuda") +def optimize_graph_for_aten_efficient_attention(graph: GraphProto): + nodes_to_remove = [] + nodes_to_add = [] + new_value_infos = [] + matcher = GraphMatcher(graph) + idx = 0 + for pattern_tuple in _PATTERNS: + for nodes in matcher.match_pattern(pattern_tuple[0]): + remove_nodes, add_nodes, add_value_infos = pattern_tuple[1](matcher, idx, nodes) + if len(add_nodes) > 0: + nodes_to_remove.extend(remove_nodes) + nodes_to_add.extend(add_nodes) + new_value_infos.extend(add_value_infos) + idx += 
1 + update_graph(graph, nodes_to_remove, nodes_to_add, new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py new file mode 100644 index 0000000000000..e6e5ce56773e1 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/utils.py @@ -0,0 +1,178 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import itertools +from typing import Any, Dict, List, Sequence, Tuple + +import numpy as np +from onnx import GraphProto, NodeProto, TensorProto, helper, numpy_helper + + +def _get_attribute(node: NodeProto, attr_name: str, default_value: Any = None) -> Any: + """Get attribute value from node by attribute key.""" + found = [attr for attr in node.attribute if attr.name == attr_name] + if found: + return helper.get_attribute_value(found[0]) + return default_value + + +def _to_numpy_array(node: Any) -> np.ndarray: + """Convert Constant node or TensorProto to Python value.""" + tensor = node + if isinstance(node, NodeProto): + tensor = _get_attribute(node, "value") + assert isinstance(tensor, TensorProto) + return numpy_helper.to_array(tensor).tolist() + + +class GraphMatcher: + """Sub-graph matcher with given pattern. + + GraphMatcher takes an ONNX graph to initialize. It tries to match sub-graphs to a given pattern and yield + matched sub-graphs (a list of matched nodes for each sub-graph) one by one. + + Pattern is described by a list. Each entry of the list is a Tuple: + + Tuple[str, bool, List[Tuple[int, int, int]]], e.g., ("FusedMatMul", False, [(2, 0, 1), (15, 0, 0)]) + + * The first string is the Op type, e.g., "FusedMatMul". + * The second bool indicates whether this node is a producer or a consumer of the source node. + * The third element is a list describing the edges between this node and other nodes. Each edge is a tuple of + 3 integers: the first is the index of the target node in the list, the second is the output index of the edge, + and the third is the input index of the edge. + + For each entry, GraphMatcher uses the first edge to look up the target node, and then checks that the sub-graph + also matches the rest of the edge infos. + + Note that when looking up the target node, only the first matched node is taken as the target node. For example, + if a source node has multiple "MatMul" consumer nodes consuming the same output, only the first "MatMul" node + will be returned. Avoid using such an ambiguous edge as the first edge info for node lookup; use another edge to + avoid the ambiguity if possible.
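+ + For example (illustrative only), with the two-entry pattern + [("MatMul", False, []), ("Transpose", True, [(0, 0, 1)])] + matching starts from a "MatMul" anchor node, and the second entry looks up the producer of that MatMul's + second input: node index 0 refers back to the MatMul, input index 1 selects its second input, and output + index 0 requires that input to be the first output of a "Transpose" node.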
+ """ + + def __init__(self, graph: GraphProto): + self._graph: GraphProto = graph + self._op_type_to_nodes: Dict[str, List[NodeProto]] = {} + self._consumer_count: Dict[str, int] = {} + for node in graph.node: + if node.op_type not in self._op_type_to_nodes: + self._op_type_to_nodes[node.op_type] = [] + self._op_type_to_nodes[node.op_type].append(node) + for input in node.input: + self._consumer_count[input] = self._consumer_count.get(input, 0) + 1 + + def _get_producer(self, arg: str, op_type: str, output_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (output_idx >= 0 and len(node.output) > output_idx and node.output[output_idx] == arg) or ( + output_idx == -1 and arg in node.output + ): + return node + return None + + def _get_consumer(self, arg: str, op_type: str, input_idx: int): + for node in self._op_type_to_nodes.get(op_type, []): + if (input_idx >= 0 and len(node.input) > input_idx and node.input[input_idx] == arg) or ( + input_idx == -1 and arg in node.input + ): + return node + return None + + def get_consumer_count(self, arg: str): + return self._consumer_count.get(arg, 0) + + def get_constant_value(self, arg: str): + node_or_initializer = None + if "Constant" in self._op_type_to_nodes: + for node in self._op_type_to_nodes["Constant"]: + if arg in node.output: + node_or_initializer = node + break + if node_or_initializer is None: + for initializer in self._graph.initializer: + if arg == initializer.name: + node_or_initializer = initializer + break + if node_or_initializer is None: + return None + return _to_numpy_array(node_or_initializer) + + def get_type_and_shape(self, arg: str): + value_infos = [ + value_info + for value_info in itertools.chain(self._graph.input, self._graph.value_info) + if value_info.name == arg + ] + if len(value_infos) > 0 and value_infos[0].type.tensor_type.HasField("shape"): + shape = [] + for dim in value_infos[0].type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + return value_infos[0].type.tensor_type.elem_type, shape + initializers = [initializer for initializer in self._graph.initializer if initializer.name == arg] + if len(initializers) > 0: + return initializers[0].data_type, initializers[0].dims + return None, None + + def _match_pattern(self, node: NodeProto, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + nodes = [node] + for i in range(1, len(pattern)): + next_op_type = pattern[i][0] + is_producer = pattern[i][1] + node_idx, output_idx, input_idx = pattern[i][2][0] + next_node = ( + self._get_producer(nodes[node_idx].input[input_idx], next_op_type, output_idx) + if is_producer + else self._get_consumer(nodes[node_idx].output[output_idx], next_op_type, input_idx) + ) + if next_node is None: + return [] + for j in range(1, len(pattern[i][2])): + node_idx, output_idx, input_idx = pattern[i][2][j] + assert output_idx >= 0 and input_idx >= 0 + if (not is_producer and nodes[node_idx].output[output_idx] != next_node.input[input_idx]) or ( + is_producer and next_node.output[output_idx] != nodes[node_idx].input[input_idx] + ): + return [] + nodes.append(next_node) + return nodes + + def match_pattern(self, pattern: List[Tuple[str, bool, List[Tuple[int, int, int]]]]): + for node in self._op_type_to_nodes.get(pattern[0][0], []): + result = self._match_pattern(node, pattern) + if len(result) == len(pattern): + yield result + + +def check_attribute_value(node: NodeProto, attr_name: str, expected_value: Any): + """Check if the attribute of given node 
has expected value.""" + value = _get_attribute(node, attr_name) + return value == expected_value + + +def make_constant_node(name: str, dtype: TensorProto.DataType, dims: Sequence[int], vals: Any): + """Create a constant node with given constant tensor (data type, shape, and data).""" + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=dtype, dims=dims, vals=vals), + ) + + +def update_graph( + graph: GraphProto, + nodes_to_remove: List[NodeProto], + nodes_to_add: List[NodeProto], + new_value_infos: List[TensorProto] = [], # noqa: B006 +): + """Update an ONNX graph by removing some nodes, and adding some new nodes and value infos.""" + nodes = [node for node in graph.node if node not in nodes_to_remove] + nodes.extend(nodes_to_add) + graph.ClearField("node") + graph.node.extend(nodes) + if len(new_value_infos) > 0: + graph.value_info.extend(new_value_infos) diff --git a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py b/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py deleted file mode 100644 index 70056179c140e..0000000000000 --- a/orttraining/orttraining/python/training/ortmodule/graph_transformer_registry.py +++ /dev/null @@ -1,47 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from typing import Callable - -from onnx.onnx_ml_pb2 import GraphProto - - -class GraphTransformerRegistry: - _TRANSFORMER_FUNCS = {} # noqa: RUF012 - - @classmethod - def register(cls, target_modules: str, devices: str, priority: int, fn: Callable[[GraphProto], None]): - modules = [] - if target_modules == "all": - modules.append("all") - else: - modules = target_modules.split("|") - for module in modules: - if module in cls._TRANSFORMER_FUNCS: - cls._TRANSFORMER_FUNCS[module].append((fn, devices, priority)) - else: - cls._TRANSFORMER_FUNCS[module] = [(fn, devices, priority)] - - @classmethod - def transform_all(cls, module_name: str, device: str, graph: GraphProto): - transformers_to_apply = [] - if "all" in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS["all"]) - if module_name in cls._TRANSFORMER_FUNCS: - transformers_to_apply.extend(cls._TRANSFORMER_FUNCS[module_name]) - transformers_to_apply = [x for x in transformers_to_apply if x[1] == "all" or device in x[1]] - transformers_to_apply.sort(key=lambda x: x[2], reverse=True) - for fn, _, _ in transformers_to_apply: - fn(graph) - - -# target_modules can be multiple module names separated by "|", or "all" means apply to all modules. -# devices can be multiple device types separated by "|" or "all" means apply to all devices. 
-def register_graph_transformer(target_modules: str = "all", devices: str = "all", priority: int = 0): - def graph_transformer_wrapper(fn): - GraphTransformerRegistry.register(target_modules, devices, priority, fn) - return fn - - return graph_transformer_wrapper diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc index 28f4ff665f797..c230a0c9a3b1d 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.cc @@ -17,8 +17,8 @@ InlinedHashSet TritonOp::GetBoolOutputs(size_t output_size) const { InlinedHashSet bool_outputs; for (size_t i = 0; i < output_size; ++i) { ORT_ENFORCE(i < Node().OutputDefs().size(), "Output index out of range."); - if (Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == - ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { + if (Node().OutputDefs()[i]->Exists() && Node().OutputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) { bool_outputs.insert(i); } } @@ -37,13 +37,15 @@ Status TritonOp::Compute(OpKernelContext* context) const { InlinedHashSet bool_outputs = GetBoolOutputs(output_size); auto& executor = training::framework::triton::TritonOpExecutor::Instance(); if (func_name_ != "") { - executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs); + executor.ExecuteByFuncName(func_name_, inputs, outputs, bool_outputs, kwargs_); } else { executor.ExecuteByOnnx(onnx_key_, onnx_string_, inputs, outputs, bool_outputs); } ORT_ENFORCE(output_size == outputs.size()); for (size_t i = 0; i < output_size; ++i) { - ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + if (Node().OutputDefs()[i]->Exists()) { + ORT_THROW_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast(i), outputs[i])); + } } return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index 25e7b1f15ff6b..f226db76f7ed7 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -5,6 +5,8 @@ #pragma once +#include "core/common/inlined_containers.h" + #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" #endif @@ -18,6 +20,19 @@ class TritonOp final : public OpKernel { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &func_name_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_key", &onnx_key_)); ORT_THROW_IF_ERROR(info.GetAttr("onnx_string", &onnx_string_)); + for (const auto& attr : info.node().GetAttributes()) { + if (attr.first.rfind("_", 0) == 0 || attr.first == "func_name" || attr.first == "onnx_key" || + attr.first == "onnx_string") { + continue; + } + // Support int64 and float only for now, skip other types. 
+ if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { + kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } + } } Status Compute(OpKernelContext* context) const override; @@ -28,6 +43,7 @@ class TritonOp final : public OpKernel { std::string func_name_; int64_t onnx_key_; std::string onnx_string_; + InlinedHashMap> kwargs_; }; bool IsTritonOpExecutorInitialized(); diff --git a/pyproject.toml b/pyproject.toml index 89011a7944ab6..97515cb9fa62b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,3 +92,4 @@ unfixable = [ "tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix "onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix "onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix +"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo. diff --git a/setup.py b/setup.py index b71836e0ee6e4..f6308c56d0590 100644 --- a/setup.py +++ b/setup.py @@ -466,6 +466,7 @@ def finalize_options(self): "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator", "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops", + "onnxruntime.training.ortmodule.graph_optimizers", "onnxruntime.training.ort_triton", "onnxruntime.training.ort_triton.kernel", "onnxruntime.training.utils",