 #include "tensor_helper.h"
 #include <string>
 
-Napi::FunctionReference InferenceSessionWrap::constructor;
+Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor;
+Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor;
+
+Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
+  return InferenceSessionWrap::ortTensorConstructor;
+}
 
 Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
 #if defined(USE_DML) && defined(_WIN32)
@@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
       Ort::Global<void>::api_ == nullptr, env,
       "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
       "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");
-  auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
-  env.SetInstanceData(ortEnv);
+
   // initialize binding
   Napi::HandleScope scope(env);
 
   Napi::Function func = DefineClass(
       env, "InferenceSession",
-      {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
+      {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel),
+       InstanceMethod("run", &InferenceSessionWrap::Run),
        InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
+       InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling),
        InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
        InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
 
-  constructor = Napi::Persistent(func);
-  constructor.SuppressDestruct();
+  wrappedSessionConstructor = Napi::Persistent(func);
+  wrappedSessionConstructor.SuppressDestruct();
   exports.Set("InferenceSession", func);
 
   Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
   exports.Set("listSupportedBackends", listSupportedBackends);
 
+  Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce);
+  exports.Set("initOrtOnce", initOrtOnce);
+
   return exports;
 }
 
+Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+  Napi::HandleScope scope(env);
+
+  int log_level = info[0].As<Napi::Number>().Int32Value();
+
+  Ort::Env* ortEnv = env.GetInstanceData<Ort::Env>();
+  if (ortEnv == nullptr) {
+    ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"};
+    env.SetInstanceData(ortEnv);
+  }
+
+  Napi::Function tensorConstructor = info[1].As<Napi::Function>();
+  ortTensorConstructor = Napi::Persistent(tensorConstructor);
+  ortTensorConstructor.SuppressDestruct();
+
+  return env.Undefined();
+}
+
 InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
     : Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {}
 
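Note: initOrtOnce replaces the unconditional Ort::Env creation that used to happen in Init, making environment setup lazy and letting the JS layer choose the log severity and register the Tensor constructor. A minimal TypeScript sketch of the expected call; the addon path and the Tensor import are assumptions for illustration, not part of this diff:

    import { Tensor } from 'onnxruntime-common'; // assumed Tensor class, cached natively as ortTensorConstructor

    // assumed path to the compiled native addon
    const binding = require('../bin/onnxruntime_binding.node');

    // 2 == ORT_LOGGING_LEVEL_WARNING in the OrtLoggingLevel enum; the first call
    // creates the process-wide Ort::Env, later calls reuse the instance data
    // (note the Tensor constructor reference is re-persisted on every call).
    binding.initOrtOnce(2, Tensor);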
@@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
                                  ? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
                                  : ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
     }
+
+    // cache preferred output locations
+    ParsePreferredOutputLocations(info[argsLength - 1].As<Napi::Object>(), outputNames_, preferredOutputLocations_);
+    if (preferredOutputLocations_.size() > 0) {
+      ioBinding_ = std::make_unique<Ort::IoBinding>(*session_);
+    }
   } catch (Napi::Error const& e) {
     throw e;
   } catch (std::exception const& e) {
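Note: ParsePreferredOutputLocations reads the locations from the last loadModel argument, and ioBinding_ is created only when at least one non-default location is requested, so the plain CPU path pays no extra cost. A hedged TypeScript sketch of the options the JS layer might pass; the key name preferredOutputLocation and the 'gpu-buffer' value mirror onnxruntime-web and are assumptions here, since this diff does not define the JS-side shape:

    // 'logits' is a placeholder output name for illustration
    const sessionOptions = {
      executionProviders: ['webgpu'],
      preferredOutputLocation: { logits: 'gpu-buffer' }, // parsed natively into preferredOutputLocations_
    };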
@@ -167,15 +201,16 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
   std::vector<bool> reuseOutput;
   size_t inputIndex = 0;
   size_t outputIndex = 0;
-  OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();
+  Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+  Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};
 
   try {
     for (auto& name : inputNames_) {
       if (feed.Has(name)) {
         inputIndex++;
         inputNames_cstr.push_back(name.c_str());
         auto value = feed.Get(name);
-        inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
+        inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
       }
     }
     for (auto& name : outputNames_) {
@@ -184,7 +219,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
         outputNames_cstr.push_back(name.c_str());
         auto value = fetch.Get(name);
         reuseOutput.push_back(!value.IsNull());
-        outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
+        outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
       }
     }
 
@@ -193,19 +228,47 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
       runOptions = Ort::RunOptions{};
       ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
     }
+    if (preferredOutputLocations_.size() == 0) {
+      session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
+                    inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
+                    inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
+                    outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
 
-    session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
-                  inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
-                  inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
-                  outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
+      Napi::Object result = Napi::Object::New(env);
 
-    Napi::Object result = Napi::Object::New(env);
+      for (size_t i = 0; i < outputIndex; i++) {
+        result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i])));
+      }
+      return scope.Escape(result);
+    } else {
+      // IO binding
+      ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env,
+                              "Preferred output locations must have the same size as output names.");
 
-    for (size_t i = 0; i < outputIndex; i++) {
-      result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
-    }
+      for (size_t i = 0; i < inputIndex; i++) {
+        ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]);
+      }
+      for (size_t i = 0; i < outputIndex; i++) {
+        // TODO: support preallocated output tensor (outputValues[i])
+
+        if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) {
+          ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo);
+        } else {
+          ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo);
+        }
+      }
+
+      session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_);
+
+      auto outputs = ioBinding_->GetOutputValues();
+      ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");
 
-    return scope.Escape(result);
+      Napi::Object result = Napi::Object::New(env);
+      for (size_t i = 0; i < outputIndex; i++) {
+        result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i])));
+      }
+      return scope.Escape(result);
+    }
   } catch (Napi::Error const& e) {
     throw e;
   } catch (std::exception const& e) {
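Note: in the IO-binding branch, outputs whose preferred location is DATA_LOCATION_GPU_BUFFER are bound to the WebGPU_Buffer memory info, so their data can stay on the device instead of being copied back after Run. A sketch of consuming such an output, assuming the node binding mirrors onnxruntime-web's location/gpuBuffer/getData() tensor surface (none of which is defined in this diff):

    // session and feeds as created earlier; 'logits' is a placeholder name
    const results = await session.run(feeds);
    const out = results.logits;

    if (out.location === 'gpu-buffer') {
      const gpuBuffer = out.gpuBuffer;     // handle to device memory, no copy yet
      const cpuData = await out.getData(); // explicit device-to-host download
    }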
@@ -218,13 +281,29 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
   ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
   ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
 
+  this->ioBinding_.reset(nullptr);
+
   this->defaultRunOptions_.reset(nullptr);
   this->session_.reset(nullptr);
 
   this->disposed_ = true;
   return env.Undefined();
 }
 
+Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) {
+  Napi::Env env = info.Env();
+  ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+  ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
+
+  Napi::EscapableHandleScope scope(env);
+
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  auto filename = session_->EndProfilingAllocated(allocator);
+  Napi::String filenameValue = Napi::String::From(env, filename.get());
+  return scope.Escape(filenameValue);
+}
+
 Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) {
   Napi::Env env = info.Env();
   Napi::EscapableHandleScope scope(env);
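Note: EndProfiling wraps Ort::Session::EndProfilingAllocated and escapes the returned profile file name as a JS string. A small usage sketch; enableProfiling is the session-option name from onnxruntime-common and an assumption here, not something this diff defines:

    // session assumed created with { enableProfiling: true } in its options
    const profileFile: string = session.endProfiling();
    console.log(`ONNX Runtime profile written to ${profileFile}`);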
@@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo
 #ifdef USE_DML
   result.Set(result.Length(), createObject("dml", true));
 #endif
+#ifdef USE_WEBGPU
+  result.Set(result.Length(), createObject("webgpu", true));
+#endif
 #ifdef USE_CUDA
   result.Set(result.Length(), createObject("cuda", false));
 #endif
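Note: with the new USE_WEBGPU block, WebGPU-enabled builds advertise the backend the same way the dml entry does. A quick probe from JS, assuming the { name, bundled } object shape used by the binding's TypeScript declarations (the field names are not shown in this diff):

    const backends: Array<{ name: string; bundled: boolean }> = binding.listSupportedBackends();
    if (backends.some((b) => b.name === 'webgpu')) {
      // the WebGPU execution provider is available in this build
    }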