Skip to content

Commit 742594c

Browse files
authored
Clears GPU Cache when there are no more active sessions (#22490)
Fixes #21574
1 parent 08cc261 commit 742594c

File tree

5 files changed

+45
-1
lines changed

5 files changed

+45
-1
lines changed

js/web/lib/wasm/jsep/backend-webgpu.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,10 @@ export class WebGpuBackend {
902902
this.sessionStatus = 'default';
903903
}
904904

905+
onCreateSession(): void {
906+
this.gpuDataManager.onCreateSession();
907+
}
908+
905909
onReleaseSession(sessionId: number): void {
906910
this.unregisterBuffers(sessionId);
907911
if (this.capturedCommandList.has(sessionId)) {

js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ export interface GpuDataManager {
6464
*/
6565
dispose(): void;
6666

67+
/**
68+
* Notify the manager that a session has been created (increments the active session count).
69+
*/
70+
onCreateSession(): void;
71+
6772
/**
6873
* release session related data.
6974
* @param sessionId - specify the session ID.
@@ -200,6 +205,9 @@ class GpuDataManagerImpl implements GpuDataManager {
200205
// a SessionID -> GPUBuffer[] mapping.
201206
private capturedPendingBuffers: Map<number, GPUBuffer[]>;
202207

208+
// The number of active sessions; the storage cache is cleared when this drops to zero.
209+
private sessionCount: number;
210+
203211
constructor(private backend: WebGpuBackend) {
204212
this.storageCache = new Map();
205213
this.freeBuffers = new Map();
@@ -213,6 +221,8 @@ class GpuDataManagerImpl implements GpuDataManager {
213221
this.freeBuffers.set(key, []);
214222
this.freeUniformBuffers.set(key, []);
215223
}
224+
225+
this.sessionCount = 0;
216226
}
217227

218228
upload(id: GpuDataId, data: Uint8Array): void {
@@ -360,7 +370,12 @@ class GpuDataManagerImpl implements GpuDataManager {
360370
release(id: GpuDataId): number {
361371
const cachedData = this.storageCache.get(id);
362372
if (!cachedData) {
363-
throw new Error('releasing data does not exist');
373+
if (this.storageCache.size === 0) {
374+
// cache was previously cleared, no need to release anything.
375+
return 0;
376+
} else {
377+
throw new Error('releasing data does not exist');
378+
}
364379
}
365380

366381
LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`);
@@ -460,6 +475,10 @@ class GpuDataManagerImpl implements GpuDataManager {
460475
this.capturedPendingBuffers = new Map();
461476
}
462477

478+
onCreateSession() {
479+
this.sessionCount += 1;
480+
}
481+
463482
onReleaseSession(sessionId: number) {
464483
// release the captured pending buffers.
465484
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
@@ -469,6 +488,16 @@ class GpuDataManagerImpl implements GpuDataManager {
469488
});
470489
this.capturedPendingBuffers.delete(sessionId);
471490
}
491+
492+
// release the storage cache if no active sessions.
493+
this.sessionCount -= 1;
494+
if (this.sessionCount === 0) {
495+
LOG_DEBUG('warning', () => '[WebGPU] Clearing webgpu buffer cache');
496+
this.storageCache.forEach((storage) => {
497+
storage.gpuData.buffer.destroy();
498+
});
499+
this.storageCache = new Map();
500+
}
472501
}
473502
}
474503

js/web/lib/wasm/wasm-core-impl.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@ export const createSession = async (
317317
checkLastError("Can't create a session.");
318318
}
319319

320+
wasm.jsepOnCreateSession?.();
321+
320322
// clear current MLContext after session creation
321323
if (wasm.currentContext) {
322324
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);

js/web/lib/wasm/wasm-types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ export declare namespace JSEP {
141141
* @param sessionId - specify the session ID.
142142
*/
143143
jsepOnRunStart: (sessionId: number) => void;
144+
/**
145+
* [exported from pre-jsep.js] Notify the backend that a session has been created. This function will be called after _OrtCreateSession() is
146+
* called.
147+
* @returns
148+
*/
149+
jsepOnCreateSession: () => void;
144150
/**
145151
* [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is
146152
* called.

onnxruntime/wasm/pre-jsep.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ Module['jsepInit'] = (name, params) => {
192192
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
193193
return backend['createDownloader'](gpuBuffer, size, type);
194194
};
195+
// Forwards the "session created" notification to the backend.
// The hook is typed as `() => void` and is invoked without arguments, and
// backend.onCreateSession() takes none, so no session ID is accepted or forwarded.
Module['jsepOnCreateSession'] = () => {
  backend['onCreateSession']();
};
195198
Module['jsepOnReleaseSession'] = sessionId => {
196199
backend['onReleaseSession'](sessionId);
197200
};

0 commit comments

Comments
 (0)