Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions unified-runtime/source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,58 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,

return UR_RESULT_SUCCESS;
}

ur_result_t doMemcpyRect(
ur_command_t Command, ur_queue_handle_t hQueue, void *DestPtr,
ur_rect_offset_t DestOffset, size_t DestPitch, size_t DestSlice,
ol_device_handle_t DestDevice, void *SrcPtr, ur_rect_offset_t SrcOffset,
size_t SrcPitch, size_t SrcSlice, ol_device_handle_t SrcDevice,
ur_rect_region_t Region, bool blocking, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
ol_queue_handle_t Queue;
OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));

ol_memcpy_rect_t DestRect{DestPtr,
{static_cast<uint32_t>(DestOffset.x),
static_cast<uint32_t>(DestOffset.y),
static_cast<uint32_t>(DestOffset.z)},
DestPitch,
DestSlice};
ol_memcpy_rect_t SrcRect{SrcPtr,
{static_cast<uint32_t>(SrcOffset.x),
static_cast<uint32_t>(SrcOffset.y),
static_cast<uint32_t>(SrcOffset.z)},
SrcPitch,
SrcSlice};
ol_dimensions_t OlRegion{static_cast<uint32_t>(Region.width),
static_cast<uint32_t>(Region.height),
static_cast<uint32_t>(Region.depth)};

if (blocking) {
// Ensure all work in the queue is complete
OL_RETURN_ON_ERR(olSyncQueue(Queue));
OL_RETURN_ON_ERR(olMemcpyRect(nullptr, DestRect, DestDevice, SrcRect,
SrcDevice, OlRegion));
if (phEvent) {
*phEvent = ur_event_handle_t_::createEmptyEvent(Command, hQueue);
}
return UR_RESULT_SUCCESS;
}

OL_RETURN_ON_ERR(
olMemcpyRect(Queue, DestRect, DestDevice, SrcRect, SrcDevice, OlRegion));
if (phEvent) {
auto *Event = new ur_event_handle_t_(Command, hQueue);
if (auto Res = olCreateEvent(Queue, &Event->OffloadEvent)) {
delete Event;
return offloadResultToUR(Res);
};
*phEvent = Event;
}

return UR_RESULT_SUCCESS;
}
} // namespace

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
Expand All @@ -282,6 +334,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
ur_rect_offset_t bufferOrigin, ur_rect_offset_t hostOrigin,
ur_rect_region_t region, size_t bufferRowPitch, size_t bufferSlicePitch,
size_t hostRowPitch, size_t hostSlicePitch, void *pDst,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);

return doMemcpyRect(UR_COMMAND_MEM_BUFFER_READ_RECT, hQueue, pDst, hostOrigin,
hostRowPitch, hostSlicePitch, Adapter->HostDevice, DevPtr,
bufferOrigin, bufferRowPitch, bufferSlicePitch,
hQueue->OffloadDevice, region, blockingRead,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
Expand All @@ -294,6 +363,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
ur_rect_offset_t bufferOrigin, ur_rect_offset_t hostOrigin,
ur_rect_region_t region, size_t bufferRowPitch, size_t bufferSlicePitch,
size_t hostRowPitch, size_t hostSlicePitch, void *pSrc,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
char *DevPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);

return doMemcpyRect(
UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, DevPtr, bufferOrigin,
bufferRowPitch, bufferSlicePitch, hQueue->OffloadDevice, pSrc, hostOrigin,
hostRowPitch, hostSlicePitch, Adapter->HostDevice, region, blockingWrite,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
Expand All @@ -310,6 +396,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
ur_mem_handle_t hBufferDst, ur_rect_offset_t srcOrigin,
ur_rect_offset_t dstOrigin, ur_rect_region_t region, size_t srcRowPitch,
size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
char *SrcPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBufferSrc->Mem).Ptr);
char *DstPtr =
reinterpret_cast<char *>(std::get<BufferMem>(hBufferDst->Mem).Ptr);

return doMemcpyRect(UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, DstPtr,
dstOrigin, dstRowPitch, dstSlicePitch,
hQueue->OffloadDevice, SrcPtr, srcOrigin, srcRowPitch,
srcSlicePitch, hQueue->OffloadDevice, region, false,
numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, const void *pPattern,
size_t patternSize, size_t offset, size_t size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,13 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
pDdiTable->pfnEventsWaitWithBarrierExt = urEnqueueEventsWaitWithBarrierExt;
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy;
pDdiTable->pfnMemBufferCopyRect = nullptr;
pDdiTable->pfnMemBufferCopyRect = urEnqueueMemBufferCopyRect;
pDdiTable->pfnMemBufferFill = urEnqueueMemBufferFill;
pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
pDdiTable->pfnMemBufferRead = urEnqueueMemBufferRead;
pDdiTable->pfnMemBufferReadRect = nullptr;
pDdiTable->pfnMemBufferReadRect = urEnqueueMemBufferReadRect;
pDdiTable->pfnMemBufferWrite = urEnqueueMemBufferWrite;
pDdiTable->pfnMemBufferWriteRect = nullptr;
pDdiTable->pfnMemBufferWriteRect = urEnqueueMemBufferWriteRect;
pDdiTable->pfnMemImageCopy = nullptr;
pDdiTable->pfnMemImageRead = nullptr;
pDdiTable->pfnMemImageWrite = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,23 @@ TEST_P(urEnqueueMemBufferCopyRectTestWithParam, Success) {
// Zero destination buffer to begin with since the write may not cover the
// whole buffer.
const uint8_t zero = 0x0;
ur_event_handle_t FillEvent;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, dst_buffer, &zero, sizeof(zero),
0, dst_buffer_size, 0, nullptr,
nullptr));
&FillEvent));

// Enqueue the rectangular copy between the buffers.
ur_event_handle_t CopyEvent;
ASSERT_SUCCESS(urEnqueueMemBufferCopyRect(
queue, src_buffer, dst_buffer, src_buffer_origin, dst_buffer_origin,
region, src_buffer_row_pitch, src_buffer_slice_pitch,
dst_buffer_row_pitch, dst_buffer_slice_pitch, 0, nullptr, nullptr));
dst_buffer_row_pitch, dst_buffer_slice_pitch, 1, &FillEvent, &CopyEvent));

std::vector<uint8_t> output(dst_buffer_size, 0x0);
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, dst_buffer,
/* is_blocking */ true, 0,
dst_buffer_size, output.data(), 0,
nullptr, nullptr));
dst_buffer_size, output.data(), 1,
&CopyEvent, nullptr));

// Do host side equivalent.
std::vector<uint8_t> expected(dst_buffer_size, 0x0);
Expand Down
Loading