22
22
struct ur_kernel_handle_t_ : RefCounted {
23
23
24
24
// Simplified version of the CUDA adapter's argument implementation
25
- struct OffloadKernelArguments {
25
+ struct alignas ( 32 ) OffloadKernelArguments {
26
26
static constexpr size_t MaxParamBytes = 4096u ;
27
- using args_t = std::array<char , MaxParamBytes>;
27
+ using final_buffer_t = std::array<char , MaxParamBytes>;
28
+ using args_t = std::vector<char >;
28
29
using args_size_t = std::vector<size_t >;
29
- using args_ptr_t = std::vector<void * >;
30
- args_t Storage ;
31
- size_t StorageUsed = 0 ;
30
+ using args_offset_t = std::vector<size_t >;
31
+ final_buffer_t RealisedBuffer ;
32
+ args_t ParamStorage ;
32
33
args_size_t ParamSizes;
33
- args_ptr_t Pointers;
34
+ args_offset_t Pointers;
35
+ bool Dirty = true ;
36
+ size_t RealisedSpace;
34
37
35
38
struct MemObjArg {
36
39
ur_mem_handle_t_ *Mem;
@@ -47,12 +50,12 @@ struct ur_kernel_handle_t_ : RefCounted {
47
50
ParamSizes.resize (Index + 1 );
48
51
}
49
52
ParamSizes[Index] = Size;
50
- // Calculate the insertion point in the array.
51
- size_t InsertPos = std::accumulate ( std::begin (ParamSizes),
52
- std::begin (ParamSizes) + Index, 0 );
53
- // Update the stored value for the argument.
54
- std::memcpy (&Storage[InsertPos], Arg, Size) ;
55
- Pointers[Index] = &Storage[InsertPos] ;
53
+
54
+ auto Base = ParamStorage. size ();
55
+ ParamStorage. resize (Base + Size );
56
+ std::memcpy (&ParamStorage[Base], Arg, Size);
57
+ Pointers[Index] = Base ;
58
+ Dirty = true ;
56
59
}
57
60
58
61
void addMemObjArg (int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -66,14 +69,44 @@ struct ur_kernel_handle_t_ : RefCounted {
66
69
}
67
70
}
68
71
MemObjArgs.push_back (MemObjArg{hMem, Index, Flags});
72
+ Dirty = true ;
69
73
}
70
74
71
- const args_ptr_t &getPointers () const noexcept { return Pointers; }
75
+ void realise () noexcept {
76
+ if (!Dirty) {
77
+ return ;
78
+ }
79
+
80
+ size_t Space = sizeof (RealisedBuffer);
81
+ void *Offset = &RealisedBuffer[0 ];
82
+ for (size_t I = 0 ; I < Pointers.size (); I++) {
83
+ void *ValueBase = &ParamStorage[Pointers[I]];
84
+ size_t Size = ParamSizes[I];
85
+
86
+ // Align the value to a multiple of the size
87
+ // TODO: This is probably not correct, but UR doesn't allow specifying
88
+ // the alignment of arguments
89
+ if (!std::align (Size, Size, Offset, Space)) {
90
+ // Ran out of space. TODO: Handle properly
91
+ abort ();
92
+ }
93
+
94
+ std::memcpy (Offset, ValueBase, Size);
95
+ Offset = &reinterpret_cast <char *>(Offset)[Size];
96
+ }
72
97
73
- const char *getStorage () const noexcept { return Storage.data (); }
98
+ Dirty = false ;
99
+ RealisedSpace = reinterpret_cast <char *>(Offset) - &RealisedBuffer[0 ];
100
+ }
101
+
102
+ const char *getStorage () noexcept {
103
+ realise ();
104
+ return &RealisedBuffer[0 ];
105
+ }
74
106
75
- size_t getStorageSize () const noexcept {
76
- return std::accumulate (std::begin (ParamSizes), std::end (ParamSizes), 0 );
107
+ size_t getStorageSize () noexcept {
108
+ realise ();
109
+ return RealisedSpace;
77
110
}
78
111
};
79
112
0 commit comments