Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions unified-runtime/source/adapters/offload/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@
struct ur_kernel_handle_t_ : RefCounted {

// Simplified version of the CUDA adapter's argument implementation
struct OffloadKernelArguments {
struct alignas(32) OffloadKernelArguments {
Copy link
Contributor

@pbalcer pbalcer Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what this is for. Is this what CUDA/HIP drivers expect? Shouldn't we move this requirement/functionality into the plugins?

static constexpr size_t MaxParamBytes = 4096u;
using args_t = std::array<char, MaxParamBytes>;
using final_buffer_t = std::array<char, MaxParamBytes>;
using args_t = std::vector<char>;
using args_size_t = std::vector<size_t>;
using args_ptr_t = std::vector<void *>;
args_t Storage;
size_t StorageUsed = 0;
using args_offset_t = std::vector<size_t>;
final_buffer_t RealisedBuffer;
args_t ParamStorage;
args_size_t ParamSizes;
args_ptr_t Pointers;
args_offset_t Pointers;
bool Dirty = true;
size_t RealisedSpace;

struct MemObjArg {
ur_mem_handle_t_ *Mem;
Expand All @@ -47,12 +50,12 @@ struct ur_kernel_handle_t_ : RefCounted {
ParamSizes.resize(Index + 1);
}
ParamSizes[Index] = Size;
// Calculate the insertion point in the array.
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
std::begin(ParamSizes) + Index, 0);
// Update the stored value for the argument.
std::memcpy(&Storage[InsertPos], Arg, Size);
Pointers[Index] = &Storage[InsertPos];

auto Base = ParamStorage.size();
ParamStorage.resize(Base + Size);
std::memcpy(&ParamStorage[Base], Arg, Size);
Pointers[Index] = Base;
Dirty = true;
}

void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
Expand All @@ -66,14 +69,44 @@ struct ur_kernel_handle_t_ : RefCounted {
}
}
MemObjArgs.push_back(MemObjArg{hMem, Index, Flags});
Dirty = true;
}

const args_ptr_t &getPointers() const noexcept { return Pointers; }
// Packs every argument held in ParamStorage into RealisedBuffer, inserting
// padding so that each value starts on an aligned boundary, and records the
// number of bytes used in RealisedSpace. Returns immediately unless Dirty is
// set, and clears Dirty once the buffer has been rebuilt. Calls abort() if
// the packed arguments do not fit in RealisedBuffer.
void realise() noexcept {
  if (!Dirty) {
    return;
  }

  // Bytes still free in RealisedBuffer. std::align only subtracts the padding
  // it inserts, so the bytes of each copied value are subtracted explicitly
  // below; otherwise the out-of-space check would under-count.
  size_t Space = RealisedBuffer.size();
  void *Cursor = RealisedBuffer.data();
  for (size_t I = 0; I < Pointers.size(); I++) {
    const size_t Size = ParamSizes[I];
    if (Size == 0) {
      // Nothing to copy for this slot, and 0 is not a valid alignment for
      // std::align.
      continue;
    }

    // Align the value to the largest power of two that divides its size
    // (which is Size itself whenever Size is a power of two); std::align
    // requires a power-of-two alignment.
    // TODO: This is probably not correct, but UR doesn't allow specifying
    // the alignment of arguments
    const size_t Alignment = Size & (~Size + 1);
    if (!std::align(Alignment, Size, Cursor, Space)) {
      // Ran out of space. TODO: Handle properly
      abort();
    }

    std::memcpy(Cursor, &ParamStorage[Pointers[I]], Size);
    Cursor = static_cast<char *>(Cursor) + Size;
    Space -= Size;
  }

  Dirty = false;
  RealisedSpace = RealisedBuffer.size() - Space;
}

// Rebuilds the packed argument buffer if needed (via realise()) and returns
// a pointer to its first byte.
const char *getStorage() noexcept {
  realise();
  return RealisedBuffer.data();
}

size_t getStorageSize() const noexcept {
return std::accumulate(std::begin(ParamSizes), std::end(ParamSizes), 0);
// Rebuilds the packed argument buffer if needed (via realise()) and returns
// the byte count recorded in RealisedSpace.
size_t getStorageSize() noexcept {
  realise();
  const size_t UsedBytes = RealisedSpace;
  return UsedBytes;
}
};

Expand Down
Loading