Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)

# Usually, SIMD instructions in kernels compile fine, and we detect at run time whether they are supported.
# But on some platforms, even compiling these SIMD-specific instructions can fail,
# so we provide options to disable compiling them.
option(onnxruntime_DISABLE_SSE4 "Disable compiling kernel with SSE4 instructions" OFF)
option(onnxruntime_DISABLE_AVX "Disable compiling kernel with AVX instructions" OFF)
option(onnxruntime_DISABLE_AVX2 "Disable compiling kernel with AVX2 instructions" OFF)
option(onnxruntime_DISABLE_AVX512 "Disable compiling kernel with AVX512 instructions" OFF)
option(onnxruntime_DISABLE_AMX "Disable compiling kernel with AMX instructions" OFF)

cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
Expand Down
107 changes: 90 additions & 17 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/saturation_check.cpp
)

# Forward each enabled onnxruntime_DISABLE_<ISA> build option to the MLAS
# sources as a private ORT_DISABLE_<ISA> compile definition.
foreach(mlas_simd_isa IN ITEMS SSE4 AVX AVX2 AVX512 AMX)
  if (onnxruntime_DISABLE_${mlas_simd_isa})
    target_compile_definitions(onnxruntime_mlas PRIVATE ORT_DISABLE_${mlas_simd_isa})
  endif()
endforeach()

target_sources(onnxruntime_mlas PRIVATE
${MLAS_INC_DIR}/mlas_float16.h
${MLAS_INC_DIR}/mlas_gemm_postprocessor.h
Expand Down Expand Up @@ -191,17 +211,18 @@ function(setup_mlas_source_for_windows)
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2")

set(mlas_platform_srcs_sse41
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
)

target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/dgemm.cpp
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
Expand Down Expand Up @@ -244,6 +265,24 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)

# Add the per-instruction-set source lists to onnxruntime_mlas only when the
# matching onnxruntime_DISABLE_<ISA> option is not set.
# NOTE(review): ${mlas_platform_srcs_avx} and ${mlas_platform_srcs_avx2} also
# appear in an unconditional target_sources() call earlier in this function;
# confirm they are removed there, otherwise the AVX/AVX2 options have no effect
# on the Windows build.

# SSE4.1 kernels.
if(NOT onnxruntime_DISABLE_SSE4)
target_sources(onnxruntime_mlas PRIVATE
${mlas_platform_srcs_sse41}
)
endif()

# AVX kernels.
if(NOT onnxruntime_DISABLE_AVX)
target_sources(onnxruntime_mlas PRIVATE
${mlas_platform_srcs_avx}
)
endif()

# AVX2 kernels (this list is compiled with /arch:AVX2).
if(NOT onnxruntime_DISABLE_AVX2)
target_sources(onnxruntime_mlas PRIVATE
${mlas_platform_srcs_avx2}
)
endif()

if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
set_source_files_properties(${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm PROPERTIES COMPILE_FLAGS "-DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER")
endif()
Expand All @@ -262,10 +301,14 @@ function(setup_mlas_source_for_windows)
else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm
${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm
)
if (NOT onnxruntime_DISABLE_SSE4)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
)
endif()
endif()
endfunction()

Expand Down Expand Up @@ -436,7 +479,7 @@ else()
${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
${MLAS_SRC_DIR}/spool_kernel_neon.cpp
)

# Conditionally add the SVE implementation if compiler supports it
if (onnxruntime_USE_SVE)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
Expand All @@ -450,8 +493,8 @@ else()
endif()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")

if (NOT APPLE)
set(mlas_platform_srcs
Expand Down Expand Up @@ -588,9 +631,14 @@ else()

set(mlas_platform_srcs
${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
)

# Append the AVX sources to the list instead of assigning it: a plain
# set(mlas_platform_srcs ...) here would replace the SSE2 sources that were
# stored in mlas_platform_srcs just above, dropping them from the build
# whenever AVX is enabled.
if (NOT onnxruntime_DISABLE_AVX)
  list(APPEND mlas_platform_srcs
    ${mlas_platform_srcs_avx}
  )
endif()

# In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own
# implementation to avoid external dependency.
if(ANDROID)
Expand Down Expand Up @@ -715,26 +763,51 @@ endif()
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/dgemm.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
${mlas_platform_srcs_avx512f}
${mlas_platform_srcs_avx512core}
${mlas_platform_srcs_avx512vnni}
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
# Extend mlas_platform_srcs with the sources of every SIMD instruction set
# that has not been disabled through an onnxruntime_DISABLE_<ISA> option.

# SSE4.1 kernels.
if (NOT onnxruntime_DISABLE_SSE4)
  list(APPEND mlas_platform_srcs ${mlas_platform_srcs_sse41})
endif()

# AVX kernels.
if (NOT onnxruntime_DISABLE_AVX)
  list(APPEND mlas_platform_srcs ${mlas_platform_srcs_avx})
endif()

# AVX2 kernels, plus the AVX2 quantized GEMM kernel.
if (NOT onnxruntime_DISABLE_AVX2)
  list(APPEND mlas_platform_srcs
    ${mlas_platform_srcs_avx2}
    ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
  )
endif()

# AVX512 kernels (F, Core and VNNI source lists).
if (NOT onnxruntime_DISABLE_AVX512)
  list(APPEND mlas_platform_srcs
    ${mlas_platform_srcs_avx512f}
    ${mlas_platform_srcs_avx512core}
    ${mlas_platform_srcs_avx512vnni}
  )
endif()

if (NOT onnxruntime_ORT_MINIMAL_BUILD AND NOT onnxruntime_DISABLE_AVX512)
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
if(NOT APPLE)
if(NOT APPLE AND NOT onnxruntime_DISABLE_AMX)
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S
)
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/mlas/lib/convsym.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ struct MLAS_CONV_SYM_DISPATCH {

#if defined(MLAS_TARGET_AMD64)

#if !defined(ORT_DISABLE_AVX2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do these macros come from?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to follow the way enabling/disabling certain things is done elsewhere in the code (for instance, ORT_ENABLE_STREAM). The macros would then be defined by passing them through the compiler flags, e.g. '-DCMAKE_CXX_FLAGS=-DORT_DISABLE_AVX2'.

Would there be a preferred way of doing this? Should it be done through a CMake option()?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be preferable to make them CMake options. They should be documented too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? I don't want to overcrowd the cmake options. That would be 5 more options.
If you're positive, I will add the options and documentation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they should be documented. It's fine to add additional CMake options. These are additional build configuration options, so it seems like a reasonable place for them.

const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2 = {
MlasConvSymKernelAvx2,
MlasConvSymDepthwiseKernelAvx2,
Expand Down Expand Up @@ -194,8 +195,9 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni = {
4, // KernelDepthwiseOutputCount
false, // FixupInputZeroPoint
};
#endif // !defined(ORT_DISABLE_AVX2)

#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_DISABLE_AVX512)

const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core = {
MlasConvSymKernelAvx512Core,
Expand Down Expand Up @@ -229,7 +231,7 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni = {
false, // FixupInputZeroPoint
};

#endif // ORT_MINIMAL_BUILD
#endif // !defined(ORT_MINIMAL_BUILD) && !defined(ORT_DISABLE_AVX512)

#elif defined(MLAS_TARGET_ARM64)
const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchNeon = {
Expand Down
51 changes: 33 additions & 18 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ Return Value:
__cpuid(1, Cpuid1[0], Cpuid1[1], Cpuid1[2], Cpuid1[3]);
#endif

#if defined(_MSC_VER)
#if defined(_MSC_VER) && !defined(ORT_DISABLE_SSE4)

//
// Check if the processor supports SSE 4.1 instructions.
Expand All @@ -328,7 +328,7 @@ Return Value:
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41;
}

#endif
#endif // defined(_MSC_VER) && !defined(ORT_DISABLE_SSE4)

//
// Check if the processor supports the AVX and OSXSAVE features.
Expand All @@ -348,10 +348,13 @@ Return Value:

if ((xcr0 & 0x6) == 0x6) {

#if !defined(ORT_DISABLE_AVX)
this->GemmFloatKernel = MlasGemmFloatKernelAvx;
#endif // !defined(ORT_DISABLE_AVX)

#if defined(MLAS_TARGET_AMD64)

#if !defined(ORT_DISABLE_AVX)
this->KernelM1Routine = MlasSgemmKernelM1Avx;
this->KernelM1TransposeBRoutine = MlasSgemmKernelM1TransposeBAvx;
this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Avx;
Expand All @@ -368,7 +371,7 @@ Return Value:
this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelAvx;
this->ReduceMinimumMaximumF32Kernel = MlasReduceMinimumMaximumF32KernelAvx;
this->GemmU8U8Kernel = nullptr;

#endif // !defined(ORT_DISABLE_AVX)
//
// Check if the processor supports AVX2/FMA3 features.
//
Expand All @@ -381,7 +384,7 @@ Return Value:
#endif

if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) {

#if !defined(ORT_DISABLE_AVX2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ifdef block also includes FMA3 kernels. Should those also be controlled by the macro to disable AVX2?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I've re-grouped the FMA3 kernels together and moved them outside the block.

this->Avx2Supported_ = true;

this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2;
Expand All @@ -390,6 +393,17 @@ Return Value:
this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchAvx2;
this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx2;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx2;
this->QLinearAddS8Kernel = MlasQLinearAddS8KernelAvx2;
this->QLinearAddU8Kernel = MlasQLinearAddU8KernelAvx2;
this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t, int8_t>;
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;
this->RopeDispatch = &MlasRopeDispatchAvx2;
#endif // !defined(ORT_DISABLE_AVX2)

this->GemmFloatKernel = MlasGemmFloatKernelFma3;
this->GemmDoubleKernel = MlasGemmDoubleKernelFma3;
Expand All @@ -401,18 +415,7 @@ Return Value:
this->LogisticKernelRoutine = MlasComputeLogisticF32KernelFma3;
this->TanhKernelRoutine = MlasComputeTanhF32KernelFma3;
this->ErfKernelRoutine = MlasErfKernelFma3;
this->QLinearAddS8Kernel = MlasQLinearAddS8KernelAvx2;
this->QLinearAddU8Kernel = MlasQLinearAddU8KernelAvx2;
this->ConvDepthwiseU8S8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t, int8_t>;
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernelAvx2<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;
this->RopeDispatch = &MlasRopeDispatchAvx2;


//
// Check if the processor supports Hybrid core architecture.
Expand All @@ -433,6 +436,7 @@ Return Value:
__cpuid_count(7, 1, Cpuid7_1[0], Cpuid7_1[1], Cpuid7_1[2], Cpuid7_1[3]);
#endif

#if !defined(ORT_DISABLE_AVX2)
if ((Cpuid7_1[0] & 0x10) != 0) {

this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2;
Expand All @@ -441,9 +445,11 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
}
#endif // !defined(ORT_DISABLE_AVX2)

#if !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_DISABLE_AVX512)
//
// Check if the processor supports AVX512F features and the
// operating system supports saving AVX512F state.
Expand Down Expand Up @@ -499,7 +505,10 @@ Return Value:
}
}
}
#endif // !defined(ORT_DISABLE_AVX512)


#if !defined(ORT_DISABLE_AVX2)
//
// Check if the processor supports AVX-VNNI-INT8
//
Expand All @@ -510,18 +519,22 @@ Return Value:
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchAvx2Vnni;
this->GemmS8U8Kernel = MlasGemmS8U8KernelAvx2Vnni;
}
#endif // !defined(ORT_DISABLE_AVX2)

#ifndef __APPLE__
#if !defined(__APPLE__)
#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
#if !defined(ORT_DISABLE_AVX)
//
// Check if the processor supports AVX NE CONVERT.
//
if ((Cpuid7_1[3] & (0b1 << 5)) != 0) {
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx;
}
#endif // !defined(ORT_DISABLE_AVX)
#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))


#if !defined(ORT_DISABLE_AMX)
//
// Check if the processor supports AMX-TILE and AMX-INT8
// features.
Expand All @@ -534,14 +547,16 @@ Return Value:
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx;
}
}
#endif // __APPLE__
#endif // !defined(ORT_DISABLE_AMX)
#endif // !defined(__APPLE__)

#endif // ORT_MINIMAL_BUILD

}

#endif // MLAS_TARGET_AMD64


}
}

Expand Down Expand Up @@ -797,4 +812,4 @@ thread_local size_t ThreadedBufSize = 0;
thread_local std::unique_ptr<uint8_t, decltype(&_aligned_free)> ThreadedBufHolder(nullptr, &_aligned_free);
#else
thread_local std::unique_ptr<uint8_t, decltype(&free)> ThreadedBufHolder(nullptr, &free);
#endif
#endif