@@ -188,6 +188,68 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
188
188
}
189
189
}
190
190
191
+ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t hContext,
192
+ ur_device_handle_t hDevice,
193
+ ur_usm_pool_desc_t *pPoolDesc)
194
+ : hContext(hContext) {
195
+ // TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
196
+ auto disjointPoolConfigs = initializeDisjointPoolConfig ();
197
+
198
+ if (disjointPoolConfigs.has_value ()) {
199
+ if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
200
+ for (auto &config : disjointPoolConfigs.value ().Configs ) {
201
+ config.MaxPoolableSize = limits->maxPoolableSize ;
202
+ config.SlabMinSize = limits->minDriverAllocSize ;
203
+ }
204
+ }
205
+ } else {
206
+ // If pooling is disabled, do nothing.
207
+ UR_LOG (INFO, " USM pooling is disabled. Skiping pool limits adjustment." );
208
+ }
209
+
210
+ // Create pool descriptor for single device provided
211
+ std::vector<usm::pool_descriptor> descriptors;
212
+ {
213
+ auto &desc = descriptors.emplace_back ();
214
+ desc.poolHandle = this ;
215
+ desc.hContext = hContext;
216
+ desc.hDevice = hDevice;
217
+ desc.type = UR_USM_TYPE_DEVICE;
218
+ }
219
+ {
220
+ auto &desc = descriptors.emplace_back ();
221
+ desc.poolHandle = this ;
222
+ desc.hContext = hContext;
223
+ desc.hDevice = hDevice;
224
+ desc.type = UR_USM_TYPE_SHARED;
225
+ desc.deviceReadOnly = false ;
226
+ }
227
+ {
228
+ auto &desc = descriptors.emplace_back ();
229
+ desc.poolHandle = this ;
230
+ desc.hContext = hContext;
231
+ desc.hDevice = hDevice;
232
+ desc.type = UR_USM_TYPE_SHARED;
233
+ desc.deviceReadOnly = true ;
234
+ }
235
+
236
+ for (auto &desc : descriptors) {
237
+ std::unique_ptr<UsmPool> usmPool;
238
+ if (disjointPoolConfigs.has_value ()) {
239
+ auto &poolConfig =
240
+ disjointPoolConfigs.value ().Configs [descToDisjoinPoolMemType (desc)];
241
+ auto pool = usm::makeDisjointPool (makeProvider (desc), poolConfig);
242
+ usmPool = std::make_unique<UsmPool>(this , std::move (pool));
243
+ } else {
244
+ auto pool = usm::makeProxyPool (makeProvider (desc));
245
+ usmPool = std::make_unique<UsmPool>(this , std::move (pool));
246
+ }
247
+ UMF_CALL_THROWS (
248
+ umfPoolSetTag (usmPool->umfPool .get (), usmPool.get (), nullptr ));
249
+ poolManager.addPool (desc, std::move (usmPool));
250
+ }
251
+ }
252
+
191
253
/// Returns the context this USM pool was created for.
ur_context_handle_t ur_usm_pool_handle_t_::getContextHandle() const {
  return hContext;
}
@@ -358,27 +420,27 @@ size_t ur_usm_pool_handle_t_::getTotalReservedSize() {
358
420
}
359
421
360
422
size_t ur_usm_pool_handle_t_::getPeakReservedSize () {
361
- size_t totalAllocatedSize = 0 ;
423
+ size_t maxPeakSize = 0 ;
362
424
umf_result_t umfRet = UMF_RESULT_SUCCESS;
363
425
poolManager.forEachPool ([&](UsmPool *p) {
364
426
umf_memory_provider_handle_t hProvider = nullptr ;
365
- size_t allocatedSize = 0 ;
427
+ size_t peakSize = 0 ;
366
428
umfRet = umfPoolGetMemoryProvider (p->umfPool .get (), &hProvider);
367
429
if (umfRet != UMF_RESULT_SUCCESS) {
368
430
return false ;
369
431
}
370
432
371
- umfRet = umfCtlGet (" umf.provider.by_handle.{}.stats.peak_memory" ,
372
- &allocatedSize, sizeof (allocatedSize ), hProvider);
433
+ umfRet = umfCtlGet (" umf.provider.by_handle.{}.stats.peak_memory" , &peakSize,
434
+ sizeof (peakSize ), hProvider);
373
435
if (umfRet != UMF_RESULT_SUCCESS) {
374
436
return false ;
375
437
}
376
438
377
- totalAllocatedSize += allocatedSize ;
439
+ maxPeakSize = std::max (maxPeakSize, peakSize) ;
378
440
return true ;
379
441
});
380
442
381
- return umfRet == UMF_RESULT_SUCCESS ? totalAllocatedSize : 0 ;
443
+ return umfRet == UMF_RESULT_SUCCESS ? maxPeakSize : 0 ;
382
444
}
383
445
384
446
size_t ur_usm_pool_handle_t_::getTotalUsedSize () {
@@ -460,6 +522,32 @@ ur_result_t urUSMPoolGetInfo(
460
522
return exceptionToResult (std::current_exception ());
461
523
}
462
524
525
+ ur_result_t urUSMPoolCreateExp (ur_context_handle_t hContext,
526
+ ur_device_handle_t hDevice,
527
+ ur_usm_pool_desc_t *pPoolDesc,
528
+ ur_usm_pool_handle_t *pPool) try {
529
+ *pPool = new ur_usm_pool_handle_t_ (hContext, hDevice, pPoolDesc);
530
+ hContext->addUsmPool (*pPool);
531
+ return UR_RESULT_SUCCESS;
532
+ } catch (umf_result_t e) {
533
+ return umf::umf2urResult (e);
534
+ } catch (...) {
535
+ return exceptionToResult (std::current_exception ());
536
+ }
537
+
538
+ ur_result_t urUSMPoolDestroyExp (ur_context_handle_t , ur_device_handle_t ,
539
+ ur_usm_pool_handle_t hPool) try {
540
+ if (hPool->RefCount .release ()) {
541
+ hPool->getContextHandle ()->removeUsmPool (hPool);
542
+ delete hPool;
543
+ }
544
+ return UR_RESULT_SUCCESS;
545
+ } catch (umf_result_t e) {
546
+ return umf::umf2urResult (e);
547
+ } catch (...) {
548
+ return exceptionToResult (std::current_exception ());
549
+ }
550
+
463
551
ur_result_t urUSMPoolGetInfoExp (ur_usm_pool_handle_t hPool,
464
552
ur_usm_pool_info_t propName, void *pPropValue,
465
553
size_t *pPropSizeRet) {
@@ -497,6 +585,28 @@ ur_result_t urUSMPoolGetInfoExp(ur_usm_pool_handle_t hPool,
497
585
return UR_RESULT_SUCCESS;
498
586
}
499
587
588
+ ur_result_t urUSMPoolSetInfoExp (ur_usm_pool_handle_t /* hPool*/ ,
589
+ ur_usm_pool_info_t propName,
590
+ void * /* pPropValue*/ , size_t propSize) {
591
+ if (propSize < sizeof (size_t )) {
592
+ return UR_RESULT_ERROR_INVALID_SIZE;
593
+ }
594
+
595
+ switch (propName) {
596
+ // TODO: Support for pool release threshold and maximum size hints.
597
+ case UR_USM_POOL_INFO_RELEASE_THRESHOLD_EXP:
598
+ case UR_USM_POOL_INFO_MAXIMUM_SIZE_EXP:
599
+ // TODO: Allow user to overwrite pool peak statistics.
600
+ case UR_USM_POOL_INFO_RESERVED_HIGH_EXP:
601
+ case UR_USM_POOL_INFO_USED_HIGH_EXP:
602
+ break ;
603
+ default :
604
+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
605
+ }
606
+
607
+ return UR_RESULT_SUCCESS;
608
+ }
609
+
500
610
ur_result_t urUSMPoolGetDefaultDevicePoolExp (ur_context_handle_t hContext,
501
611
ur_device_handle_t ,
502
612
ur_usm_pool_handle_t *pPool) {
0 commit comments