Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { env, apis } from '../env.js';
// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web/webgpu';
import * as ONNX_WEB from 'onnxruntime-web/all';

export { Tensor } from 'onnxruntime-common';

Expand Down Expand Up @@ -61,6 +61,9 @@ if (apis.IS_NODE_ENV) {
if (apis.IS_WEBGPU_AVAILABLE) {
supportedExecutionProviders.push('webgpu');
}
if(apis.IS_WEBNN_AVAILABLE) {
supportedExecutionProviders.push('webnn');
}
supportedExecutionProviders.push('wasm');
defaultExecutionProviders = ['wasm'];
}
Expand Down
4 changes: 4 additions & 0 deletions src/env.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const IS_BROWSER_ENV = typeof self !== 'undefined';
const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self;
const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;

const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
Expand All @@ -55,6 +56,9 @@ export const apis = Object.freeze({
/** Whether the WebGPU API is available */
IS_WEBGPU_AVAILABLE,

/** Whether the WebNN API is available */
IS_WEBNN_AVAILABLE,

/** Whether the Node.js process API is available */
IS_PROCESS_AVAILABLE,

Expand Down
1 change: 1 addition & 0 deletions src/utils/devices.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export const DEVICE_TYPES = Object.freeze({
gpu: 'gpu', // Auto-detect GPU
wasm: 'wasm', // WebAssembly
webgpu: 'webgpu', // WebGPU
webnn: 'webnn', // WebNN
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
});
Expand Down
1 change: 1 addition & 0 deletions src/utils/dtypes.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({
[DEVICE_TYPES.gpu]: DATA_TYPES.fp32,
[DEVICE_TYPES.wasm]: DATA_TYPES.q8,
[DEVICE_TYPES.webgpu]: DATA_TYPES.fp32,
[DEVICE_TYPES.webnn]: DATA_TYPES.fp32,
[DEVICE_TYPES.cuda]: DATA_TYPES.fp32,
[DEVICE_TYPES.dml]: DATA_TYPES.fp32,
});
Expand Down