Skip to content

Commit ab67572

Browse files
committed
memory resident limit to enabled peers
1 parent ca92dc1 commit ab67572

File tree

12 files changed

+238
-10
lines changed

12 files changed

+238
-10
lines changed

unified-runtime/source/adapters/level_zero/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ if(UR_BUILD_ADAPTER_L0_V2)
150150
${CMAKE_CURRENT_SOURCE_DIR}/helpers/kernel_helpers.cpp
151151
${CMAKE_CURRENT_SOURCE_DIR}/helpers/memory_helpers.cpp
152152
${CMAKE_CURRENT_SOURCE_DIR}/helpers/mutable_helpers.cpp
153-
${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
154153
${CMAKE_CURRENT_SOURCE_DIR}/virtual_mem.cpp
155154
${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp
156155
${CMAKE_CURRENT_SOURCE_DIR}/sampler.hpp
@@ -188,6 +187,7 @@ if(UR_BUILD_ADAPTER_L0_V2)
188187
${CMAKE_CURRENT_SOURCE_DIR}/v2/queue_create.cpp
189188
${CMAKE_CURRENT_SOURCE_DIR}/v2/queue_immediate_in_order.cpp
190189
${CMAKE_CURRENT_SOURCE_DIR}/v2/usm.cpp
190+
${CMAKE_CURRENT_SOURCE_DIR}/v2/usm_p2p.cpp
191191
)
192192
install_ur_library(ur_adapter_level_zero_v2)
193193

unified-runtime/source/adapters/level_zero/context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ ur_result_t urContextCreate(
4242

4343
Context->initialize();
4444
*RetContext = reinterpret_cast<ur_context_handle_t>(Context);
45+
// TODO: delete below 'if' when memory isolation in the context is implemented in the driver
4546
if (IndirectAccessTrackingEnabled) {
4647
std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
4748
Platform->Contexts.push_back(*RetContext);

unified-runtime/source/adapters/level_zero/device.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <stdarg.h>
1717
#include <string>
1818
#include <unordered_map>
19+
#include <unordered_set>
1920
#include <vector>
2021

2122
#include "adapters/level_zero/platform.hpp"
@@ -236,6 +237,9 @@ struct ur_device_handle_t_ : ur_object {
236237
std::unordered_map<ur_exp_image_native_handle_t, ze_image_handle_t>
237238
ZeOffsetToImageHandleMap;
238239

240+
// Ids of the peer devices for which the user enabled P2P access via urUsmP2PEnablePeerAccessExp (entries are removed again by urUsmP2PDisablePeerAccessExp)
241+
std::unordered_set<DeviceId> p2pDeviceIds;
242+
239243
// unique ephemeral identifier of the device in the adapter
240244
std::optional<DeviceId> Id;
241245
};

unified-runtime/source/adapters/level_zero/platform.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,10 @@ struct ur_platform_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter>,
9191
uint32_t VersionMinor,
9292
uint32_t VersionBuild);
9393

94-
// Keep track of all contexts in the platform. This is needed to manage
94+
// Keep track of all contexts in the platform. In v1 L0 this is needed to manage
9595
// a lifetime of memory allocations in each context when there are kernels
96-
// with indirect access.
97-
// TODO: should be deleted when memory isolation in the context is implemented
98-
// in the driver.
96+
// with indirect access. In v2 it is used during ext_oneapi_enable_peer_access
97+
// and ext_oneapi_disable_peer_access calls.
9998
std::list<ur_context_handle_t> Contexts;
10099
ur_shared_mutex ContextsMutex;
101100

unified-runtime/source/adapters/level_zero/v2/context.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,24 @@ static std::vector<ur_device_handle_t>
1818
filterP2PDevices(ur_device_handle_t hSourceDevice,
1919
const std::vector<ur_device_handle_t> &devices) {
2020
std::vector<ur_device_handle_t> p2pDevices;
21+
22+
std::optional<std::unordered_set<DeviceId>> p2pDevicesEnabledByUser;
23+
2124
for (auto &device : devices) {
2225
if (device == hSourceDevice) {
2326
continue;
2427
}
2528

29+
if (!p2pDevicesEnabledByUser.has_value())
30+
{
31+
std::shared_lock<ur_shared_mutex> Lock(hSourceDevice->Mutex);
32+
p2pDevicesEnabledByUser.emplace(hSourceDevice->p2pDeviceIds);
33+
}
34+
35+
if (p2pDevicesEnabledByUser->count(*device->Id) == 0) {
36+
continue;
37+
}
38+
2639
ze_bool_t p2p;
2740
ZE2UR_CALL_THROWS(zeDeviceCanAccessPeer,
2841
(device->ZeDevice, hSourceDevice->ZeDevice, &p2p));
@@ -126,8 +139,39 @@ void ur_context_handle_t_::removeUsmPool(ur_usm_pool_handle_t hPool) {
126139
usmPoolHandles.remove(hPool);
127140
}
128141

129-
const std::vector<ur_device_handle_t> &
130-
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) const {
142+
// Records `newPeerDevice` as a P2P peer of `hDevice` in this context's
// per-device peer list (`p2pAccessDevices`).
//
// hDevice       - device whose peer list is extended.
// newPeerDevice - peer device to append to that list.
//
// The call is idempotent: if a device with the same Id is already listed the
// list is left untouched. The previous revision only guarded against
// duplicates with an assert (which was additionally ill-formed: `0 = ...`),
// so release builds would have stored the peer twice.
//
// NOTE(review): unlike removeResidentDevice, this does not notify the USM
// pools in `usmPoolHandles` — confirm whether that asymmetry is intended.
void ur_context_handle_t_::addResidentDevice(ur_device_handle_t hDevice,
                                             ur_device_handle_t newPeerDevice) {
  std::scoped_lock<ur_shared_mutex> lock(Mutex);

  // `Id` is a std::optional; unwrap it the same way getP2PDevices does.
  auto &peers = p2pAccessDevices.at(hDevice->Id.value());

  const bool alreadyListed = std::any_of(
      std::begin(peers), std::end(peers),
      [&](const auto peer) { return newPeerDevice->Id == peer->Id; });
  if (alreadyListed) {
    return;
  }

  peers.push_back(newPeerDevice);
}
152+
// Removes `oldPeerDevice` from the P2P peer list of `hDevice` in this context
// and then forwards the removal to every USM pool registered in
// `usmPoolHandles`.
//
// hDevice       - device whose peer list is shrunk.
// oldPeerDevice - peer device to drop from that list.
//
// If the peer is not listed, nothing is changed and the pools are not
// notified. The previous revision relied on an assert for that case and
// would have called erase() with the end iterator (undefined behaviour) in
// builds where asserts are disabled.
void ur_context_handle_t_::removeResidentDevice(
    ur_device_handle_t hDevice, ur_device_handle_t oldPeerDevice) {
  std::scoped_lock<ur_shared_mutex> lock(Mutex);

  // `Id` is a std::optional; unwrap it the same way getP2PDevices does.
  auto &peers = p2pAccessDevices.at(hDevice->Id.value());

  // Erase-remove drops every entry with a matching Id, so no duplicate can
  // survive (the original asserted this after a single erase).
  const auto firstStale = std::remove_if(
      std::begin(peers), std::end(peers), [oldPeerDevice](const auto peer) {
        return oldPeerDevice->Id == peer->Id;
      });
  if (firstStale == std::end(peers)) {
    return;
  }
  peers.erase(firstStale, std::end(peers));

  // The pool-level method takes both devices; the original call passed none.
  for (auto poolHandle : usmPoolHandles) {
    poolHandle->removeResidentDevice(hDevice, oldPeerDevice);
  }
}
171+
172+
std::vector<ur_device_handle_t>
173+
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) {
174+
std::scoped_lock<ur_shared_mutex> lock(Mutex);
131175
return p2pAccessDevices[hDevice->Id.value()];
132176
}
133177

@@ -145,6 +189,10 @@ ur_result_t urContextCreate(uint32_t deviceCount,
145189

146190
*phContext =
147191
new ur_context_handle_t_(zeContext, deviceCount, phDevices, true);
192+
{
193+
std::scoped_lock<ur_shared_mutex> Lock(hPlatform->ContextsMutex);
194+
hPlatform->Contexts.push_back(*phContext);
195+
}
148196
return UR_RESULT_SUCCESS;
149197
} catch (...) {
150198
return exceptionToResult(std::current_exception());
@@ -182,6 +230,14 @@ ur_result_t urContextRetain(ur_context_handle_t hContext) try {
182230
}
183231

184232
ur_result_t urContextRelease(ur_context_handle_t hContext) try {
233+
auto Platform = hContext->getPlatform();
234+
auto &Contexts = Platform->Contexts;
235+
{
236+
std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
237+
auto It = std::find(Contexts.begin(), Contexts.end(), hContext);
238+
UR_ASSERT(It != Contexts.end(), UR_RESULT_ERROR_INVALID_CONTEXT);
239+
Contexts.erase(It);
240+
}
185241
return hContext->release();
186242
} catch (...) {
187243
return exceptionToResult(std::current_exception());

unified-runtime/source/adapters/level_zero/v2/context.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ struct ur_context_handle_t_ : ur_object {
3535

3636
void addUsmPool(ur_usm_pool_handle_t hPool);
3737
void removeUsmPool(ur_usm_pool_handle_t hPool);
38+
void addResidentDevice(ur_device_handle_t hDevice, ur_device_handle_t peerDevice);
39+
void removeResidentDevice(ur_device_handle_t hDevice, ur_device_handle_t peerDevice);
3840

3941
template <typename Func> void forEachUsmPool(Func func) {
4042
std::shared_lock<ur_shared_mutex> lock(Mutex);
@@ -44,8 +46,8 @@ struct ur_context_handle_t_ : ur_object {
4446
}
4547
}
4648

47-
const std::vector<ur_device_handle_t> &
48-
getP2PDevices(ur_device_handle_t hDevice) const;
49+
std::vector<ur_device_handle_t>
50+
getP2PDevices(ur_device_handle_t hDevice);
4951

5052
v2::event_pool &getNativeEventsPool() { return nativeEventsPool; }
5153
v2::command_list_cache_t &getCommandListCache() { return commandListCache; }

unified-runtime/source/adapters/level_zero/v2/memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ void *ur_discrete_buffer_handle_t::getDevicePtr(
273273
return getActiveDeviceAlloc(offset);
274274
}
275275

276-
auto &p2pDevices = hContext->getP2PDevices(hDevice);
276+
auto p2pDevices = hContext->getP2PDevices(hDevice);
277277
auto p2pAccessible = std::find(p2pDevices.begin(), p2pDevices.end(),
278278
activeAllocationDevice) != p2pDevices.end();
279279

unified-runtime/source/adapters/level_zero/v2/usm.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
#include <umf/providers/provider_level_zero.h>
1919

20+
#include "memory_pool_internal.h"
21+
2022
static inline void UMF_CALL_THROWS(umf_result_t res) {
2123
if (res != UMF_RESULT_SUCCESS) {
2224
throw res;
@@ -311,6 +313,14 @@ void ur_usm_pool_handle_t_::cleanupPoolsForQueue(void *hQueue) {
311313
});
312314
}
313315

316+
// Forwards the (device, peer) pair to poolManager.addResidentDevice.
void ur_usm_pool_handle_t_::addResidentDevice(
    ur_device_handle_t hDevice,
    ur_device_handle_t peerDevice) {
  poolManager.addResidentDevice(hDevice, peerDevice);
}
319+
320+
// Forwards the (device, peer) pair to poolManager.removeResidentDevice.
void ur_usm_pool_handle_t_::removeResidentDevice(
    ur_device_handle_t hDevice,
    ur_device_handle_t peerDevice) {
  poolManager.removeResidentDevice(hDevice, peerDevice);
}
323+
314324
namespace ur::level_zero {
315325
ur_result_t urUSMPoolCreate(
316326
/// [in] handle of the context object

unified-runtime/source/adapters/level_zero/v2/usm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ struct ur_usm_pool_handle_t_ : ur_object {
4848

4949
void cleanupPools();
5050
void cleanupPoolsForQueue(void *hQueue);
51+
void addResidentDevice(ur_device_handle_t hDevice, ur_device_handle_t peerDevice);
52+
void removeResidentDevice(ur_device_handle_t hDevice, ur_device_handle_t peerDevice);
5153

5254
private:
5355
ur_context_handle_t hContext;
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===----------- usm_p2p.cpp - L0 Adapter ---------------------------------===//
2+
//
3+
// Copyright (C) 2023 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "logger/ur_logger.hpp"
12+
#include "ur_level_zero.hpp"
13+
14+
namespace ur::level_zero {
15+
16+
// Enables P2P access from `commandDevice` to `peerDevice`.
//
// The peer's Id is recorded in commandDevice->p2pDeviceIds, then every
// context currently registered in the platform's `Contexts` list is told to
// add the peer via addResidentDevice.
//
// Returns UR_RESULT_SUCCESS. Enabling a peer that is already enabled is a
// no-op, so the contexts are not asked to add the same peer twice.
ur_result_t urUsmP2PEnablePeerAccessExp(ur_device_handle_t commandDevice,
                                        ur_device_handle_t peerDevice) {
  bool newlyEnabled = false;
  {
    // The set is modified here, so an exclusive lock is required; the
    // previous std::shared_lock only grants shared (reader) ownership.
    std::scoped_lock<ur_shared_mutex> Lock(commandDevice->Mutex);
    newlyEnabled = commandDevice->p2pDeviceIds.insert(*peerDevice->Id).second;
  }

  if (!newlyEnabled) {
    return UR_RESULT_SUCCESS;
  }

  auto Platform = commandDevice->Platform;
  {
    std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
    for (auto Context : Platform->Contexts) {
      Context->addResidentDevice(commandDevice, peerDevice);
    }
  }

  return UR_RESULT_SUCCESS;
}
33+
34+
// Disables P2P access from `commandDevice` to `peerDevice`.
//
// The peer's Id is erased from commandDevice->p2pDeviceIds, then every
// context currently registered in the platform's `Contexts` list is told to
// drop the peer via removeResidentDevice.
//
// Returns UR_RESULT_SUCCESS. Disabling a peer that was not enabled is a
// no-op, so the contexts are not asked to remove a peer they never had.
ur_result_t urUsmP2PDisablePeerAccessExp(ur_device_handle_t commandDevice,
                                         ur_device_handle_t peerDevice) {
  bool wasEnabled = false;
  {
    // The set is modified here, so an exclusive lock is required; the
    // previous std::shared_lock only grants shared (reader) ownership.
    std::scoped_lock<ur_shared_mutex> Lock(commandDevice->Mutex);
    wasEnabled = commandDevice->p2pDeviceIds.erase(*peerDevice->Id) != 0;
  }

  if (!wasEnabled) {
    return UR_RESULT_SUCCESS;
  }

  auto Platform = commandDevice->Platform;
  {
    std::scoped_lock<ur_shared_mutex> Lock(Platform->ContextsMutex);
    for (auto Context : Platform->Contexts) {
      Context->removeResidentDevice(commandDevice, peerDevice);
    }
  }

  // The original erased the same Id a second time here; the first erase
  // above already removed it, so the duplicate was dropped.
  return UR_RESULT_SUCCESS;
}
55+
56+
// Queries a P2P capability between `commandDevice` and `peerDevice`.
//
// propName     - which capability to query:
//                  UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORT  -> 1 when the P2P
//                    properties report the ACCESS flag and
//                    zeDeviceCanAccessPeer reports access, otherwise 0;
//                  UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORT -> 1 when the P2P
//                    properties report the ATOMICS flag, otherwise 0.
// propSize, pPropValue, pPropSizeRet - forwarded to UrReturnHelper, which
//                writes the int result.
//
// Returns UR_RESULT_ERROR_INVALID_ENUMERATION for any other propName,
// otherwise whatever the return helper yields. A failing Level Zero call is
// presumably turned into an early error return by ZE2UR_CALL — confirm
// against the macro definition.
ur_result_t urUsmP2PPeerAccessGetInfoExp(ur_device_handle_t commandDevice,
                                         ur_device_handle_t peerDevice,
                                         ur_exp_peer_info_t propName,
                                         size_t propSize, void *pPropValue,
                                         size_t *pPropSizeRet) {

  UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

  int propertyValue = 0;
  switch (propName) {
  case UR_EXP_PEER_INFO_UR_PEER_ACCESS_SUPPORT: {
    ZeStruct<ze_device_p2p_properties_t> p2pProperties;
    ZE2UR_CALL(zeDeviceGetP2PProperties,
               (commandDevice->ZeDevice, peerDevice->ZeDevice, &p2pProperties));
    const bool p2pAccessSupported =
        (p2pProperties.flags & ZE_DEVICE_P2P_PROPERTY_FLAG_ACCESS) != 0;

    ze_bool_t p2pDeviceSupported = false;
    ZE2UR_CALL(
        zeDeviceCanAccessPeer,
        (commandDevice->ZeDevice, peerDevice->ZeDevice, &p2pDeviceSupported));

    propertyValue = (p2pAccessSupported && p2pDeviceSupported) ? 1 : 0;
    break;
  }
  case UR_EXP_PEER_INFO_UR_PEER_ATOMICS_SUPPORT: {
    ZeStruct<ze_device_p2p_properties_t> p2pProperties;
    ZE2UR_CALL(zeDeviceGetP2PProperties,
               (commandDevice->ZeDevice, peerDevice->ZeDevice, &p2pProperties));
    // Normalise to 0/1: the raw masked value is the flag's bit value, which
    // is not 1 for the ATOMICS flag, so it must not be returned directly.
    propertyValue =
        (p2pProperties.flags & ZE_DEVICE_P2P_PROPERTY_FLAG_ATOMICS) != 0 ? 1
                                                                         : 0;
    break;
  }
  default: {
    return UR_RESULT_ERROR_INVALID_ENUMERATION;
  }
  }

  return ReturnValue(propertyValue);
}
95+
} // namespace ur::level_zero

0 commit comments

Comments
 (0)