
Commit dfada06

fs-eireankitm3k authored and committed
support WebGPU EP in Node.js binding (microsoft#22660)
### Description

This change enhances the Node.js binding with the following features:

- support WebGPU EP
- lazy initialization of `OrtEnv`
- ability to initialize ORT with the default log level setting from `ort.env.logLevel`
- session options:
  - `enableProfiling` and `profileFilePrefix`: support profiling
  - `externalData`: explicit external data (optional in the Node.js binding)
  - `optimizedModelFilePath`: allow dumping the optimized model for diagnosis purposes
  - `preferredOutputLocation`: support IO binding

`Tensor.download()` is not implemented in this PR. The build pipeline update is not included in this PR.
1 parent 703e951 commit dfada06

10 files changed: +479 −63 lines changed
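For illustration, the new session options could be exercised roughly as follows. This is a minimal sketch, not code from this commit: it assumes `onnxruntime-node` was built with `--use_webgpu`, and `model.onnx` and `weights.bin` are placeholder paths.

```ts
import * as ort from 'onnxruntime-node';

// a minimal sketch of the new options; all file names are placeholders
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  // dump the optimized model to disk for diagnosis
  optimizedModelFilePath: 'model.optimized.onnx',
  // explicit external data (optional in the Node.js binding)
  externalData: ['weights.bin'],
});
console.log(session.inputNames, session.outputNames);
```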

js/node/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11)
 
 project (onnxruntime-node)
 
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 
 add_compile_definitions(NAPI_VERSION=${napi_build_version})
 add_compile_definitions(ORT_API_MANUAL_INIT)
@@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api)
 
 # optional providers
 option(USE_DML "Build with DirectML support" OFF)
+option(USE_WEBGPU "Build with WebGPU support" OFF)
 option(USE_CUDA "Build with CUDA support" OFF)
 option(USE_TENSORRT "Build with TensorRT support" OFF)
 option(USE_COREML "Build with CoreML support" OFF)
@@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF)
 if(USE_DML)
   add_compile_definitions(USE_DML=1)
 endif()
+if(USE_WEBGPU)
+  add_compile_definitions(USE_WEBGPU=1)
+endif()
 if(USE_CUDA)
   add_compile_definitions(USE_CUDA=1)
 endif()
```

js/node/lib/backend.ts

Lines changed: 7 additions & 3 deletions
```diff
@@ -3,12 +3,14 @@
 
 import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common';
 
-import { Binding, binding } from './binding';
+import { Binding, binding, initOrt } from './binding';
 
 class OnnxruntimeSessionHandler implements InferenceSessionHandler {
   #inferenceSession: Binding.InferenceSession;
 
   constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) {
+    initOrt();
+
     this.#inferenceSession = new binding.InferenceSession();
     if (typeof pathOrBuffer === 'string') {
       this.#inferenceSession.loadModel(pathOrBuffer, options);
@@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
   readonly outputNames: string[];
 
   startProfiling(): void {
-    // TODO: implement profiling
+    // startProfiling is a no-op.
+    //
+    // if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded.
   }
   endProfiling(): void {
-    // TODO: implement profiling
+    this.#inferenceSession.endProfiling();
  }
 
   async run(
```
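As the rewritten comments state, `startProfiling()` stays a no-op because profiling is armed at load time via session options, while `endProfiling()` now forwards to the native binding. A minimal usage sketch, with placeholder model path, feed name, and prefix:

```ts
import * as ort from 'onnxruntime-node';

async function profileOneRun(input: ort.Tensor) {
  const session = await ort.InferenceSession.create('model.onnx', {
    enableProfiling: true, // profiling starts when the model is loaded
    profileFilePrefix: 'ort_profile', // prefix of the generated profile file
  });

  // 'input' is a placeholder; feed keys must match the model's input names
  const results = await session.run({ input });

  // stops profiling; the native layer writes the profile file
  session.endProfiling();
  return results;
}
```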

js/node/lib/binding.ts

Lines changed: 34 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import { InferenceSession, OnnxValue } from 'onnxruntime-common';
+import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common';
 
 type SessionOptions = InferenceSession.SessionOptions;
 type FeedsType = {
@@ -28,6 +28,8 @@ export declare namespace Binding {
 
     run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType;
 
+    endProfiling(): void;
+
     dispose(): void;
   }
 
@@ -48,4 +50,35 @@ export const binding =
     // eslint-disable-next-line @typescript-eslint/naming-convention
     InferenceSession: Binding.InferenceSessionConstructor;
     listSupportedBackends: () => Binding.SupportedBackend[];
+    initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void;
   };
+
+let ortInitialized = false;
+export const initOrt = (): void => {
+  if (!ortInitialized) {
+    ortInitialized = true;
+    let logLevel = 2;
+    if (env.logLevel) {
+      switch (env.logLevel) {
+        case 'verbose':
+          logLevel = 0;
+          break;
+        case 'info':
+          logLevel = 1;
+          break;
+        case 'warning':
+          logLevel = 2;
+          break;
+        case 'error':
+          logLevel = 3;
+          break;
+        case 'fatal':
+          logLevel = 4;
+          break;
+        default:
+          throw new Error(`Unsupported log level: ${env.logLevel}`);
+      }
+    }
+    binding.initOrtOnce(logLevel, Tensor);
+  }
+};
```
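The `initOrt` helper above maps the string levels of `ort.env.logLevel` onto ORT's numeric levels (0 = verbose through 4 = fatal, defaulting to 2 = warning) and calls `initOrtOnce` at most once per process. In practice that means the log level must be set before the first session is created; a sketch with a placeholder model path:

```ts
import * as ort from 'onnxruntime-node';

// maps to numeric level 0; an unsupported string would make initOrt() throw
ort.env.logLevel = 'verbose';

// the first session construction triggers initOrt(), which creates OrtEnv
// with the level above; changing env.logLevel afterwards has no effect
const session = await ort.InferenceSession.create('model.onnx');
```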

js/node/script/build.ts

Lines changed: 5 additions & 0 deletions
```diff
@@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator'];
 const REBUILD = !!buildArgs.rebuild;
 // --use_dml
 const USE_DML = !!buildArgs.use_dml;
+// --use_webgpu
+const USE_WEBGPU = !!buildArgs.use_webgpu;
 // --use_cuda
 const USE_CUDA = !!buildArgs.use_cuda;
 // --use_tensorrt
@@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') {
 if (USE_DML) {
   args.push('--CDUSE_DML=ON');
 }
+if (USE_WEBGPU) {
+  args.push('--CDUSE_WEBGPU=ON');
+}
 if (USE_CUDA) {
   args.push('--CDUSE_CUDA=ON');
 }
```

js/node/src/inference_session_wrap.cc

Lines changed: 100 additions & 18 deletions
```diff
@@ -11,7 +11,12 @@
 #include "tensor_helper.h"
 #include <string>
 
-Napi::FunctionReference InferenceSessionWrap::constructor;
+Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor;
+Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor;
+
+Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
+  return InferenceSessionWrap::ortTensorConstructor;
+}
 
 Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
 #if defined(USE_DML) && defined(_WIN32)
@@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
       Ort::Global<void>::api_ == nullptr, env,
       "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
      "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");
-  auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
-  env.SetInstanceData(ortEnv);
+
   // initialize binding
   Napi::HandleScope scope(env);
 
   Napi::Function func = DefineClass(
       env, "InferenceSession",
-      {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
+      {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel),
+       InstanceMethod("run", &InferenceSessionWrap::Run),
        InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
+       InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling),
        InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
        InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
 
-  constructor = Napi::Persistent(func);
-  constructor.SuppressDestruct();
+  wrappedSessionConstructor = Napi::Persistent(func);
+  wrappedSessionConstructor.SuppressDestruct();
   exports.Set("InferenceSession", func);
 
   Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
   exports.Set("listSupportedBackends", listSupportedBackends);
 
+  Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce);
+  exports.Set("initOrtOnce", initOrtOnce);
+
   return exports;
 }
 
+Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+  Napi::HandleScope scope(env);
+
+  int log_level = info[0].As<Napi::Number>().Int32Value();
+
+  Ort::Env* ortEnv = env.GetInstanceData<Ort::Env>();
+  if (ortEnv == nullptr) {
+    ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"};
+    env.SetInstanceData(ortEnv);
+  }
+
+  Napi::Function tensorConstructor = info[1].As<Napi::Function>();
+  ortTensorConstructor = Napi::Persistent(tensorConstructor);
+  ortTensorConstructor.SuppressDestruct();
+
+  return env.Undefined();
+}
+
 InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
     : Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {}
 
@@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
                                              ? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
                                              : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
     }
+
+    // cache preferred output locations
+    ParsePreferredOutputLocations(info[argsLength - 1].As<Napi::Object>(), outputNames_, preferredOutputLocations_);
+    if (preferredOutputLocations_.size() > 0) {
+      ioBinding_ = std::make_unique<Ort::IoBinding>(*session_);
+    }
   } catch (Napi::Error const& e) {
     throw e;
   } catch (std::exception const& e) {
@@ -167,15 +201,16 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
   std::vector<bool> reuseOutput;
   size_t inputIndex = 0;
   size_t outputIndex = 0;
-  OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();
+  Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};
 
   try {
     for (auto& name : inputNames_) {
       if (feed.Has(name)) {
         inputIndex++;
         inputNames_cstr.push_back(name.c_str());
         auto value = feed.Get(name);
-        inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
+        inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
       }
     }
     for (auto& name : outputNames_) {
@@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
         outputNames_cstr.push_back(name.c_str());
         auto value = fetch.Get(name);
         reuseOutput.push_back(!value.IsNull());
-        outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
+        outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
       }
     }
 
@@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
       runOptions = Ort::RunOptions{};
       ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
     }
+    if (preferredOutputLocations_.size() == 0) {
+      session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
+                    inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
+                    inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
+                    outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
 
-    session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
-                  inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
-                  inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
-                  outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
+      Napi::Object result = Napi::Object::New(env);
 
-    Napi::Object result = Napi::Object::New(env);
+      for (size_t i = 0; i < outputIndex; i++) {
+        result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i])));
+      }
+      return scope.Escape(result);
+    } else {
+      // IO binding
+      ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env,
+                              "Preferred output locations must have the same size as output names.");
 
-    for (size_t i = 0; i < outputIndex; i++) {
-      result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
-    }
+      for (size_t i = 0; i < inputIndex; i++) {
+        ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]);
+      }
+      for (size_t i = 0; i < outputIndex; i++) {
+        // TODO: support preallocated output tensor (outputValues[i])
+
+        if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) {
+          ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo);
+        } else {
+          ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo);
+        }
+      }
+
+      session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_);
+
+      auto outputs = ioBinding_->GetOutputValues();
+      ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");
 
-    return scope.Escape(result);
+      Napi::Object result = Napi::Object::New(env);
+      for (size_t i = 0; i < outputIndex; i++) {
+        result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i])));
+      }
+      return scope.Escape(result);
+    }
   } catch (Napi::Error const& e) {
     throw e;
   } catch (std::exception const& e) {
@@ -218,13 +281,29 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
   ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
   ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
 
+  this->ioBinding_.reset(nullptr);
+
   this->defaultRunOptions_.reset(nullptr);
   this->session_.reset(nullptr);
 
   this->disposed_ = true;
   return env.Undefined();
 }
 
+Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+  ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+  ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
+
+  Napi::EscapableHandleScope scope(env);
+
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  auto filename = session_->EndProfilingAllocated(allocator);
+  Napi::String filenameValue = Napi::String::From(env, filename.get());
+  return scope.Escape(filenameValue);
+}
+
 Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) {
   Napi::Env env = info.Env();
   Napi::EscapableHandleScope scope(env);
@@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo
 #ifdef USE_DML
   result.Set(result.Length(), createObject("dml", true));
 #endif
+#ifdef USE_WEBGPU
+  result.Set(result.Length(), createObject("webgpu", true));
+#endif
 #ifdef USE_CUDA
   result.Set(result.Length(), createObject("cuda", false));
 #endif
```
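On the TypeScript side, the `preferredOutputLocations_` cached in `LoadModel` come from the `preferredOutputLocation` session option; when a location resolves to `DATA_LOCATION_GPU_BUFFER`, `Run` goes through `Ort::IoBinding` and binds that output to the `WebGPU_Buffer` memory info so the result stays on the GPU. A hedged sketch of the intended usage — the output name `output` is a placeholder, and the per-output record form of the option follows onnxruntime-common's definition:

```ts
import * as ort from 'onnxruntime-node';

declare const inputTensor: ort.Tensor; // caller-provided input (placeholder)

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  // 'gpu-buffer' selects DATA_LOCATION_GPU_BUFFER, so this output is
  // bound via Ort::IoBinding and stays on the GPU
  preferredOutputLocation: { output: 'gpu-buffer' },
});

const results = await session.run({ input: inputTensor });
// results.output is backed by a GPU buffer; reading it back on the CPU
// needs Tensor.download(), which this PR leaves unimplemented
```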

js/node/src/inference_session_wrap.h

Lines changed: 27 additions & 1 deletion
```diff
@@ -12,9 +12,22 @@
 class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
  public:
   static Napi::Object Init(Napi::Env env, Napi::Object exports);
+  static Napi::FunctionReference& GetTensorConstructor();
+
   InferenceSessionWrap(const Napi::CallbackInfo& info);
 
  private:
+  /**
+   * [sync] initialize ONNX Runtime once.
+   *
+   * This function must be called before any other functions.
+   *
+   * @param arg0 a number specifying the log level.
+   *
+   * @returns undefined
+   */
+  static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info);
+
   /**
    * [sync] list supported backend list
    * @returns array with objects { "name": "cpu", requirementsInstalled: true }
@@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
    */
   Napi::Value Dispose(const Napi::CallbackInfo& info);
 
+  /**
+   * [sync] end the profiling.
+   * @param nothing
+   * @returns nothing
+   * @throw nothing
+   */
+  Napi::Value EndProfiling(const Napi::CallbackInfo& info);
+
   // private members
 
   // persistent constructor
-  static Napi::FunctionReference constructor;
+  static Napi::FunctionReference wrappedSessionConstructor;
+  static Napi::FunctionReference ortTensorConstructor;
 
   // session objects
   bool initialized_;
@@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
   std::vector<std::string> outputNames_;
   std::vector<ONNXType> outputTypes_;
   std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;
+
+  // preferred output locations
+  std::vector<int> preferredOutputLocations_;
+  std::unique_ptr<Ort::IoBinding> ioBinding_;
 };
```
