Skip to content

Commit cb2257c

Browse files
committed
CUDA pinned memory fixes
1 parent 1414bc4 commit cb2257c

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "core/common/inlined_containers.h"
88
#include "core/framework/allocator.h"
9+
#include "core/providers/shared_library/provider_api.h"
910
#include <mutex>
1011

1112
namespace onnxruntime {
@@ -56,9 +57,13 @@ class CUDAPinnedAllocator : public IAllocator {
5657
CUDAPinnedAllocator(const char* name, OrtDevice::DeviceId device_id)
5758
: IAllocator(
5859
OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator,
59-
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
60-
device_id),
61-
OrtMemTypeCPUOutput)) {}
60+
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
61+
0 /*CPU device always has id 0*/),
62+
OrtMemTypeCPUOutput)) {
63+
if (device_id != 0) {
64+
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Device ID was non 0 for a CudaPinned allocator. The ID will be ignored.";
65+
}
66+
}
6267

6368
void* Alloc(size_t size) override;
6469
void Free(void* p) override;

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ std::vector<AllocatorPtr> NvExecutionProvider::CreatePreferredAllocators() {
13041304

13051305
AllocatorCreationInfo pinned_allocator_info(
13061306
[](OrtDevice::DeviceId device_id) {
1307-
return std::make_unique<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
1307+
return std::make_unique<CUDAPinnedAllocator>(CUDA_PINNED, device_id);
13081308
},
13091309
narrow<OrtDevice::DeviceId>(device_id_));
13101310

@@ -3258,8 +3258,8 @@ OrtDevice NvExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const
32583258
if (mem_type == OrtMemTypeCPUInput)
32593259
return OrtDevice();
32603260
if (mem_type == OrtMemTypeCPUOutput)
3261-
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
3262-
default_device_.Id());
3261+
return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
3262+
0 /*CPU device id is always 0*/);
32633263
return default_device_;
32643264
}
32653265

0 commit comments

Comments
 (0)