@@ -18,11 +18,24 @@ static std::vector<ur_device_handle_t>
18
18
filterP2PDevices (ur_device_handle_t hSourceDevice,
19
19
const std::vector<ur_device_handle_t > &devices) {
20
20
std::vector<ur_device_handle_t > p2pDevices;
21
+
22
+ std::optional<std::unordered_set<DeviceId>> p2pDevicesEnabledByUser;
23
+
21
24
for (auto &device : devices) {
22
25
if (device == hSourceDevice) {
23
26
continue ;
24
27
}
25
28
29
+ if (!p2pDevicesEnabledByUser.has_value ())
30
+ {
31
+ std::shared_lock<ur_shared_mutex> Lock (hSourceDevice->Mutex );
32
+ p2pDevicesEnabledByUser.emplace (hSourceDevice->p2pDeviceIds );
33
+ }
34
+
35
+ if (p2pDevicesEnabledByUser->count (*device->Id ) == 0 ) {
36
+ continue ;
37
+ }
38
+
26
39
ze_bool_t p2p;
27
40
ZE2UR_CALL_THROWS (zeDeviceCanAccessPeer,
28
41
(device->ZeDevice , hSourceDevice->ZeDevice , &p2p));
@@ -126,8 +139,39 @@ void ur_context_handle_t_::removeUsmPool(ur_usm_pool_handle_t hPool) {
126
139
usmPoolHandles.remove (hPool);
127
140
}
128
141
129
- const std::vector<ur_device_handle_t > &
130
- ur_context_handle_t_::getP2PDevices (ur_device_handle_t hDevice) const {
142
+ void ur_context_handle_t_::addResidentDevice (ur_device_handle_t hDevice, ur_device_handle_t newPeerDevice) {
143
+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
144
+ auto & pDevices = p2pAccessDevices.at (hDevice->Id );
145
+
146
+ assert (0 = std::count_if (
147
+ std::begin (pDevices), std::end (pDevices),
148
+ [&](const auto pDevice) { return newPeerDevice->Id == pDevice->Id ; }));
149
+
150
+ pDevices.push_back (newPeerDevice);
151
+ }
152
+ void ur_context_handle_t_::removeResidentDevice (ur_device_handle_t hDevice, ur_device_handle_t oldPeerDevice) {
153
+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
154
+ auto & pDevices = p2pAccessDevices.at (hDevice->Id );
155
+
156
+ const auto & findOldDevice = [&] {
157
+ return std::find_if (
158
+ std::begin (pDevices), std::end (pDevices),
159
+ [oldPeerDevice](const auto pDevice) { return oldPeerDevice->Id == pDevice->Id ; });
160
+ };
161
+
162
+ auto pDeviceIt = findOldDevice ();
163
+ assert (pDeviceIt != std::end (pDevices));
164
+ pDevices.erase (pDeviceIt);
165
+ assert (findOldDevice () == std::end (pDevices));
166
+
167
+ for (auto poolHandle : usmPoolHandles) {
168
+ poolHandle->removeResidentDevice ();
169
+ }
170
+ }
171
+
172
+ std::vector<ur_device_handle_t >
173
+ ur_context_handle_t_::getP2PDevices (ur_device_handle_t hDevice) {
174
+ std::scoped_lock<ur_shared_mutex> lock (Mutex);
131
175
return p2pAccessDevices[hDevice->Id .value ()];
132
176
}
133
177
@@ -145,6 +189,10 @@ ur_result_t urContextCreate(uint32_t deviceCount,
145
189
146
190
*phContext =
147
191
new ur_context_handle_t_ (zeContext, deviceCount, phDevices, true );
192
+ {
193
+ std::scoped_lock<ur_shared_mutex> Lock (hPlatform->ContextsMutex );
194
+ hPlatform->Contexts .push_back (*phContext);
195
+ }
148
196
return UR_RESULT_SUCCESS;
149
197
} catch (...) {
150
198
return exceptionToResult (std::current_exception ());
@@ -182,6 +230,14 @@ ur_result_t urContextRetain(ur_context_handle_t hContext) try {
182
230
}
183
231
184
232
ur_result_t urContextRelease (ur_context_handle_t hContext) try {
233
+ auto Platform = hContext->getPlatform ();
234
+ auto &Contexts = Platform->Contexts ;
235
+ {
236
+ std::scoped_lock<ur_shared_mutex> Lock (Platform->ContextsMutex );
237
+ auto It = std::find (Contexts.begin (), Contexts.end (), hContext);
238
+ UR_ASSERT (It != Contexts.end (), UR_RESULT_ERROR_INVALID_CONTEXT);
239
+ Contexts.erase (It);
240
+ }
185
241
return hContext->release ();
186
242
} catch (...) {
187
243
return exceptionToResult (std::current_exception ());
0 commit comments