From 362d5e072384181c79d2cbf177514345d0e1f965 Mon Sep 17 00:00:00 2001
From: Ross Brunton
Date: Thu, 25 Sep 2025 15:31:12 +0100
Subject: [PATCH] [UR][Offload] Implement 2D/3D memcpies on buffers

---
 .../source/adapters/offload/enqueue.cpp       | 111 ++++++++++++++++++
 .../adapters/offload/ur_interface_loader.cpp  |   6 +-
 .../enqueue/urEnqueueMemBufferCopyRect.cpp    |  12 +-
 3 files changed, 122 insertions(+), 7 deletions(-)

diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp
index 3d419a30dc500..6e31bf9d52c01 100644
--- a/unified-runtime/source/adapters/offload/enqueue.cpp
+++ b/unified-runtime/source/adapters/offload/enqueue.cpp
@@ -268,6 +268,64 @@ ur_result_t doMemcpy(ur_command_t Command, ur_queue_handle_t hQueue,
 
   return UR_RESULT_SUCCESS;
 }
+
+// Copies the 2D/3D region `Region` from the source rect (SrcPtr, SrcOffset,
+// SrcPitch, SrcSlice) to the destination rect using olMemcpyRect, after
+// waiting on phEventWaitList on the next queue obtained from hQueue.
+// If `blocking` is set the queue is synced and the copy is issued with a null
+// queue; otherwise it is enqueued on that queue. When phEvent is non-null it
+// receives an event for `Command`.
+ur_result_t doMemcpyRect(
+    ur_command_t Command, ur_queue_handle_t hQueue, void *DestPtr,
+    ur_rect_offset_t DestOffset, size_t DestPitch, size_t DestSlice,
+    ol_device_handle_t DestDevice, void *SrcPtr, ur_rect_offset_t SrcOffset,
+    size_t SrcPitch, size_t SrcSlice, ol_device_handle_t SrcDevice,
+    ur_rect_region_t Region, bool blocking, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  ol_queue_handle_t Queue;
+  OL_RETURN_ON_ERR(hQueue->nextQueue(Queue));
+  OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList));
+
+  ol_memcpy_rect_t DestRect{DestPtr,
+                            {static_cast<uint32_t>(DestOffset.x),
+                             static_cast<uint32_t>(DestOffset.y),
+                             static_cast<uint32_t>(DestOffset.z)},
+                            DestPitch,
+                            DestSlice};
+  ol_memcpy_rect_t SrcRect{SrcPtr,
+                           {static_cast<uint32_t>(SrcOffset.x),
+                            static_cast<uint32_t>(SrcOffset.y),
+                            static_cast<uint32_t>(SrcOffset.z)},
+                           SrcPitch,
+                           SrcSlice};
+  ol_dimensions_t OlRegion{static_cast<uint32_t>(Region.width),
+                           static_cast<uint32_t>(Region.height),
+                           static_cast<uint32_t>(Region.depth)};
+
+  if (blocking) {
+    // Ensure all work in the queue is complete
+    OL_RETURN_ON_ERR(olSyncQueue(Queue));
+    OL_RETURN_ON_ERR(olMemcpyRect(nullptr, DestRect, DestDevice, SrcRect,
+                                  SrcDevice, OlRegion));
+    if (phEvent) {
+      *phEvent = ur_event_handle_t_::createEmptyEvent(Command, hQueue);
+    }
+    return UR_RESULT_SUCCESS;
+  }
+
+  OL_RETURN_ON_ERR(
+      olMemcpyRect(Queue, DestRect, DestDevice, SrcRect, SrcDevice, OlRegion));
+  if (phEvent) {
+    auto *Event = new ur_event_handle_t_(Command, hQueue);
+    if (auto Res = olCreateEvent(Queue, &Event->OffloadEvent)) {
+      delete Event;
+      return offloadResultToUR(Res);
+    }
+    *phEvent = Event;
+  }
+
+  return UR_RESULT_SUCCESS;
+}
 } // namespace
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
@@ -282,6 +340,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
                   numEventsInWaitList, phEventWaitList, phEvent);
 }
 
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
+    ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
+    ur_rect_offset_t bufferOrigin, ur_rect_offset_t hostOrigin,
+    ur_rect_region_t region, size_t bufferRowPitch, size_t bufferSlicePitch,
+    size_t hostRowPitch, size_t hostSlicePitch, void *pDst,
+    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
+    ur_event_handle_t *phEvent) {
+  char *DevPtr =
+      reinterpret_cast<char *>(std::get<BufferImpl>(hBuffer->Mem).Ptr);
+
+  return doMemcpyRect(UR_COMMAND_MEM_BUFFER_READ_RECT, hQueue, pDst, hostOrigin,
+                      hostRowPitch, hostSlicePitch, Adapter->HostDevice, DevPtr,
+                      bufferOrigin, bufferRowPitch, bufferSlicePitch,
+                      hQueue->OffloadDevice, region, blockingRead,
+                      numEventsInWaitList, phEventWaitList, phEvent);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
     ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
     size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
@@ -294,6 +369,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
                   blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
 }
 
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
+    ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
+    ur_rect_offset_t bufferOrigin, ur_rect_offset_t hostOrigin,
+    ur_rect_region_t region, size_t bufferRowPitch, size_t bufferSlicePitch,
+    size_t hostRowPitch, size_t hostSlicePitch, void *pSrc,
+    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
+    ur_event_handle_t *phEvent) {
+  char *DevPtr =
+      reinterpret_cast<char *>(std::get<BufferImpl>(hBuffer->Mem).Ptr);
+
+  return doMemcpyRect(
+      UR_COMMAND_MEM_BUFFER_WRITE_RECT, hQueue, DevPtr, bufferOrigin,
+      bufferRowPitch, bufferSlicePitch, hQueue->OffloadDevice, pSrc, hostOrigin,
+      hostRowPitch, hostSlicePitch, Adapter->HostDevice, region, blockingWrite,
+      numEventsInWaitList, phEventWaitList, phEvent);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
     ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
     ur_mem_handle_t hBufferDst, size_t srcOffset, size_t dstOffset, size_t size,
@@ -310,6 +402,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
                   phEventWaitList, phEvent);
 }
 
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
+    ur_queue_handle_t hQueue, ur_mem_handle_t hBufferSrc,
+    ur_mem_handle_t hBufferDst, ur_rect_offset_t srcOrigin,
+    ur_rect_offset_t dstOrigin, ur_rect_region_t region, size_t srcRowPitch,
+    size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch,
+    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
+    ur_event_handle_t *phEvent) {
+  char *SrcPtr =
+      reinterpret_cast<char *>(std::get<BufferImpl>(hBufferSrc->Mem).Ptr);
+  char *DstPtr =
+      reinterpret_cast<char *>(std::get<BufferImpl>(hBufferDst->Mem).Ptr);
+
+  return doMemcpyRect(UR_COMMAND_MEM_BUFFER_COPY_RECT, hQueue, DstPtr,
+                      dstOrigin, dstRowPitch, dstSlicePitch,
+                      hQueue->OffloadDevice, SrcPtr, srcOrigin, srcRowPitch,
+                      srcSlicePitch, hQueue->OffloadDevice, region, false,
+                      numEventsInWaitList, phEventWaitList, phEvent);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
     ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, const void *pPattern,
     size_t patternSize, size_t offset, size_t size,
diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
index fa3adfdaf92bd..0054f6cf8e8dc 100644
--- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
+++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp
@@ -175,13 +175,13 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
   pDdiTable->pfnEventsWaitWithBarrierExt = urEnqueueEventsWaitWithBarrierExt;
   pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;
   pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy;
-  pDdiTable->pfnMemBufferCopyRect = nullptr;
+  pDdiTable->pfnMemBufferCopyRect = urEnqueueMemBufferCopyRect;
   pDdiTable->pfnMemBufferFill = urEnqueueMemBufferFill;
   pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
   pDdiTable->pfnMemBufferRead = urEnqueueMemBufferRead;
-  pDdiTable->pfnMemBufferReadRect = nullptr;
+  pDdiTable->pfnMemBufferReadRect = urEnqueueMemBufferReadRect;
   pDdiTable->pfnMemBufferWrite = urEnqueueMemBufferWrite;
-  pDdiTable->pfnMemBufferWriteRect = nullptr;
+  pDdiTable->pfnMemBufferWriteRect = urEnqueueMemBufferWriteRect;
   pDdiTable->pfnMemImageCopy = nullptr;
   pDdiTable->pfnMemImageRead = nullptr;
   pDdiTable->pfnMemImageWrite = nullptr;
diff --git a/unified-runtime/test/conformance/enqueue/urEnqueueMemBufferCopyRect.cpp b/unified-runtime/test/conformance/enqueue/urEnqueueMemBufferCopyRect.cpp
index 6cde8bf1eae85..d5685604a9ee4 100644
--- a/unified-runtime/test/conformance/enqueue/urEnqueueMemBufferCopyRect.cpp
+++ b/unified-runtime/test/conformance/enqueue/urEnqueueMemBufferCopyRect.cpp
@@ -113,21 +113,25 @@ TEST_P(urEnqueueMemBufferCopyRectTestWithParam, Success) {
   // Zero destination buffer to begin with since the write may not cover the
   // whole buffer.
   const uint8_t zero = 0x0;
+  ur_event_handle_t FillEvent = nullptr;
   ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, dst_buffer, &zero, sizeof(zero),
                                         0, dst_buffer_size, 0, nullptr,
-                                        nullptr));
+                                        &FillEvent));
 
   // Enqueue the rectangular copy between the buffers.
+  ur_event_handle_t CopyEvent = nullptr;
   ASSERT_SUCCESS(urEnqueueMemBufferCopyRect(
       queue, src_buffer, dst_buffer, src_buffer_origin, dst_buffer_origin,
       region, src_buffer_row_pitch, src_buffer_slice_pitch,
-      dst_buffer_row_pitch, dst_buffer_slice_pitch, 0, nullptr, nullptr));
+      dst_buffer_row_pitch, dst_buffer_slice_pitch, 1, &FillEvent, &CopyEvent));
 
   std::vector<uint8_t> output(dst_buffer_size, 0x0);
   ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, dst_buffer,
                                         /* is_blocking */ true, 0,
-                                        dst_buffer_size, output.data(), 0,
-                                        nullptr, nullptr));
+                                        dst_buffer_size, output.data(), 1,
+                                        &CopyEvent, nullptr));
+  ASSERT_SUCCESS(urEventRelease(CopyEvent));
+  ASSERT_SUCCESS(urEventRelease(FillEvent));
 
   // Do host side equivalent.
   std::vector<uint8_t> expected(dst_buffer_size, 0x0);