13 changes: 13 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -115,3 +115,16 @@ MlasConv(
MLAS_THREADPOOL* ThreadPool
);
}


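// Portable size_t multiply-overflow check: prefer the compiler builtin and
// fall back to a manual division test (SIZE_MAX comes from <cstdint>).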
inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) {
#if defined(__has_builtin)
# if __has_builtin(__builtin_mul_overflow)
return __builtin_mul_overflow(a, b, out);
# endif
#endif
    // Fall back to the manual check if the builtin is not available.
if (b != 0 && a > SIZE_MAX / b) return true;
if (out) *out = a * b;
return false;
}
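
For illustration, a minimal usage sketch of the helper above. The alloc_packed wrapper and its name are hypothetical, not part of this change; it just shows the intended pattern of checking the product before allocating (assuming mlasi_kleidiai.h is included):

#include <cstddef>  // std::byte, size_t
#include <cstdint>  // SIZE_MAX
#include <vector>

// Hypothetical caller: returns an empty buffer if stride * batches would wrap.
std::vector<std::byte> alloc_packed(size_t stride, size_t batches) {
    size_t total = 0;
    if (mul_overflow_size_t_builtin(stride, batches, &total)) {
        return {};  // overflow detected; caller falls back to the non-packed path
    }
    return std::vector<std::byte>(total);
}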
115 changes: 77 additions & 38 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
Expand Up @@ -14,6 +14,16 @@
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
#include "mlasi_kleidiai.h"


// Thread-local reusable buffers to reduce allocation overhead across tiles.
struct KaiTlsBuffers {
std::vector<float> output_tile;
std::vector<float> bias_zero;
std::vector<std::byte> rhs_packed;
std::vector<std::byte> lhs_packed;
};
static thread_local KaiTlsBuffers g_kai_tls;
Reviewer comment (Member):
Please note that ORT as a library cannot control when a thread_local variable is deallocated. As the destructor of this variable does not depend on anything except the C++ runtime, I feel it's fine.
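
To make the lifetime point concrete, a self-contained sketch (Buf and worker are illustrative names, not ORT code). Each thread gets its own instance, and the C++ runtime destroys it at thread exit; the library cannot schedule that:

#include <cstdio>
#include <thread>
#include <vector>

struct Buf {
    std::vector<float> data;
    ~Buf() { std::puts("per-thread Buf destroyed"); }
};
static thread_local Buf g_buf;  // one instance per thread, like g_kai_tls

void worker() {
    g_buf.data.resize(1024);  // capacity persists across calls on the same thread
}

int main() {
    std::thread t(worker);
    t.join();    // the worker thread's Buf is destroyed here, at thread exit
    worker();    // main's own instance; destroyed during normal program exit
    return 0;
}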


size_t
MLASCALL
ArmKleidiAI::MlasGemmPackBSize(
@@ -51,7 +61,6 @@ Return Value:
// Compute the number of bytes required to hold the packed buffer.
//
size_t bytes = 0;

if (TransA == CblasNoTrans) {
switch (TransB) {
case CblasNoTrans:
@@ -125,15 +134,18 @@ Return Value:
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

- // pass zeroed bias values
- const std::vector<float> bias(N);
// Ensure size and zero the used span (only newly added elements are value-initialized by resize).
if (g_kai_tls.bias_zero.size() != N) {
g_kai_tls.bias_zero.resize(N);
}
std::fill_n(g_kai_tls.bias_zero.data(), N, 0.0f);

switch (TransB) {
case CblasNoTrans:
- kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
break;
case CblasTrans:
- kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr);
kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
break;
default:
return false;
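
Aside on the std::fill_n above: std::vector::resize value-initializes only elements beyond the current size, so a buffer reused across calls can keep stale contents unless the full span is zeroed explicitly. A minimal sketch of that semantics:

#include <cassert>
#include <vector>

int main() {
    std::vector<float> v(4, 7.0f);  // {7, 7, 7, 7}
    v.resize(2);                    // {7, 7}
    v.resize(4);                    // {7, 7, 0, 0}: only the regrown tail is zeroed
    assert(v[0] == 7.0f && v[2] == 0.0f);
    return 0;
}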
@@ -225,59 +237,69 @@ Return Value:
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

- if (M < m_step && N < n_step && !Data->BIsPacked) {
if ((M < m_step || N < n_step) && !Data->BIsPacked) {
// Fallback to MLAS
return false;
}

std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
KaiPackedData.resize(BatchSize);

size_t LhsPackedStride = 0;
std::byte* LhsPackedData = nullptr;

LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
- auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
- LhsPackedData = LhsPacked.get();
- std::unique_ptr<std::byte[]> RhsPacked{nullptr};
size_t lhs_resize = 0;
if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) {
    // size_t wraparound detected, exit
    return false;
}

if (g_kai_tls.lhs_packed.capacity() < lhs_resize) {
    g_kai_tls.lhs_packed.reserve(lhs_resize);
}
g_kai_tls.lhs_packed.resize(lhs_resize);
LhsPackedData = g_kai_tls.lhs_packed.data();

// RHS packed buffer: use TLS reusable vector to minimize allocations
size_t RhsPackedStride = 0;
std::byte* RhsPackedData = nullptr;

// It is assumed that either all B batches are packed or none are
if (Data[0].BIsPacked) {
// We have already decided the matmul variant we are using before having values for M, N, K
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
KaiPackedData[batch_idx].B = Data[batch_idx].B;
});
} else {
// Multithread pack lhs and rhs
- size_t RhsPackedStride = 0;
- std::byte* RhsPackedData = nullptr;

RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
- RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
- RhsPackedData = RhsPacked.get();
size_t rhs_resize = 0;
if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) {
    // size_t wraparound detected, exit
    return false;
}

if (g_kai_tls.rhs_packed.capacity() < rhs_resize) {
    g_kai_tls.rhs_packed.reserve(rhs_resize);
}
g_kai_tls.rhs_packed.resize(rhs_resize);
RhsPackedData = g_kai_tls.rhs_packed.data();

MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
// lhs odd, rhs even
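// i.e. task 2*i packs the RHS of batch i, task 2*i+1 packs the LHS of batch i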
if (batch_idx & 0x1) {
batch_idx >>= 1;

std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);

kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);

KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
} else {
batch_idx >>= 1;

std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);

- ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);
- KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K,
reinterpret_cast<const float*>(Data[batch_idx].B),
Data[batch_idx].ldb, RhsPackedPtr);
}
});
}
@@ -303,6 +325,14 @@ Return Value:
dim[1] = MlasDivRoundup(M, m_step);
dim[2] = MlasDivRoundup(N, n_step);

// Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
// Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
size_t max_tile_elems = 0;
if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
// size_t wraparound detected for the tile size; fall back to MLAS
return false;
}

MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
// compute B,M,N index from iteration index
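// The grid is BatchSize x M-tiles x N-tiles (dim[0] x dim[1] x dim[2]),
// linearized row-major, so dividing by dim[1] * dim[2] recovers the batch index.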
ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
@@ -314,18 +344,18 @@ Return Value:
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);

- auto BTile = reinterpret_cast<const void*>(
-     reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
- );
const std::byte* B_base = Data[0].BIsPacked
? reinterpret_cast<const std::byte*>(Data[BIdx].B)
: (RhsPackedData + RhsPackedStride * BIdx);
auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);

// Get lhs tile, A
const size_t lhs_packed_offset =
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);

- auto ATile = reinterpret_cast<const float*>(
-     reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
- );
const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
auto ATile = reinterpret_cast<const float*>(A_base + lhs_packed_offset);

auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
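// e.g. M = 100, m_step = 32: MIdx 0..2 get full 32-row tiles and MIdx 3 gets the 4-row remainder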
@@ -336,9 +366,18 @@ Return Value:
MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
NIdx * n_step * sizeof(float)
);
- // Allocate temporary buffer for raw A*B result
- std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
- float* temp_tile = OutputTile.data();
// Allocate temporary buffer for raw A*B result (TLS reusable buffer)
size_t tile_elems = TileSizeM * TileSizeN;

if (g_kai_tls.output_tile.capacity() < tile_elems) {
    g_kai_tls.output_tile.reserve(tile_elems);
}
// Resize the tile to the required size and zero the used span
g_kai_tls.output_tile.resize(tile_elems);

float* temp_tile = g_kai_tls.output_tile.data();
std::fill_n(temp_tile, tile_elems, 0.0f);

if (UseSME2) {
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(