Skip to content

Commit 87cde7f

Browse files
authored
Use preferredInputLayout to select model for different backends (#328)
1 parent e1a8542 commit 87cde7f

File tree

6 files changed

+19
-27
lines changed

6 files changed

+19
-27
lines changed

common/utils.js

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,10 @@ export function permuteData(array, dims, axes) {
432432
return [permutedData, shape];
433433
}
434434

435-
export function getDefaultLayout(deviceType) {
436-
if (deviceType.indexOf('cpu') != -1) {
437-
return 'nhwc';
438-
} else if (deviceType.indexOf('gpu') != -1 ||
439-
deviceType.indexOf('npu') != -1) {
440-
return 'nchw';
441-
}
435+
export async function getDefaultLayout(deviceType) {
436+
const context = await navigator.ml.createContext({deviceType});
437+
const limits = context.opSupportLimits();
438+
return limits.preferredInputLayout ?? 'nchw';
442439
}
443440

444441
/**

face_recognition/main.js

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ $('#backendBtns .btn').on('change', async (e) => {
5454
if (inputType === 'camera') {
5555
await stopCamRender();
5656
}
57-
layout = utils.getDefaultLayout($(e.target).attr('id'));
57+
[backend, deviceType] = $(e.target).attr('id').split('_');
58+
layout = await utils.getDefaultLayout(deviceType);
5859
await main();
5960
});
6061

@@ -296,8 +297,6 @@ function constructNetObject(type) {
296297
async function main() {
297298
try {
298299
if (fdModelName === '') return;
299-
[backend, deviceType] =
300-
$('input[name="backend"]:checked').attr('id').split('_');
301300
ui.handleClick(disabledSelectors, true);
302301
if (isFirstTimeLoad) $('#hint').hide();
303302
const [numRuns, powerPreference] = utils.getUrlParams();
@@ -359,7 +358,7 @@ async function main() {
359358
utils.sizeOfShape(frInputOptions.inputShape)));
360359
for (let i = 0; i < numRuns; i++) {
361360
if (numRuns > 1) {
362-
// clear all predicted embeddings for benckmarking
361+
// clear all predicted embeddings for benchmarking
363362
targetEmbeddings = null;
364363
searchEmbeddings = null;
365364
}

facial_landmark_detection/main.js

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ $('#backendBtns .btn').on('change', async (e) => {
5050
if (inputType === 'camera') {
5151
await stopCamRender();
5252
}
53-
layout = utils.getDefaultLayout($(e.target).attr('id'));
53+
[backend, deviceType] = $(e.target).attr('id').split('_');
54+
layout = await utils.getDefaultLayout(deviceType);
5455
await main();
5556
});
5657

@@ -232,8 +233,6 @@ function constructNetObject(type) {
232233
async function main() {
233234
try {
234235
if (fdModelName === '') return;
235-
[backend, deviceType] =
236-
$('input[name="backend"]:checked').attr('id').split('_');
237236
ui.handleClick(disabledSelectors, true);
238237
if (isFirstTimeLoad) $('#hint').hide();
239238
const [numRuns, powerPreference] = utils.getUrlParams();

image_classification/main.js

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,17 @@ $('#backendBtns .btn').on('change', async (e) => {
9898
if (inputType === 'camera') {
9999
await stopCamRender();
100100
}
101-
const backendId = $(e.target).attr('id');
102-
layout = utils.getDefaultLayout(backendId);
103-
[backend, deviceType] = backendId.split('_');
101+
[backend, deviceType] = $(e.target).attr('id').split('_');
102+
layout = await utils.getDefaultLayout(deviceType);
104103
// Only show the supported models for each deviceType. Now fp16 nchw models
105104
// are only supported on gpu/npu.
106-
if (backendId == 'webnn_gpu') {
105+
if (deviceType == 'gpu') {
107106
ui.handleBtnUI('#float16Label', false);
108107
ui.handleBtnUI('#float32Label', false);
109108
ui.handleBtnUI('#uint8Label', true);
110109
$('#float32').click();
111110
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
112-
} else if (backendId == 'webnn_npu') {
111+
} else if (deviceType == 'npu') {
113112
ui.handleBtnUI('#float16Label', false);
114113
ui.handleBtnUI('#float32Label', true);
115114
ui.handleBtnUI('#uint8Label', true);

object_detection/main.js

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,15 @@ $('#backendBtns .btn').on('change', async (e) => {
6868
if (inputType === 'camera') {
6969
await stopCamRender();
7070
}
71-
const backendId = $(e.target).attr('id');
72-
layout = utils.getDefaultLayout(backendId);
73-
[backend, deviceType] = backendId.split('_');
71+
[backend, deviceType] = $(e.target).attr('id').split('_');
72+
layout = await utils.getDefaultLayout(deviceType);
7473
// Only show the supported models for each deviceType. Now fp16 nchw models
7574
// are only supported on gpu/npu.
76-
if (backendId == 'webnn_gpu') {
75+
if (deviceType == 'gpu') {
7776
ui.handleBtnUI('#float16Label', false);
7877
ui.handleBtnUI('#float32Label', false);
7978
utils.displayAvailableModels(modelList, modelIds, deviceType, dataType);
80-
} else if (backendId == 'webnn_npu') {
79+
} else if (deviceType == 'npu') {
8180
ui.handleBtnUI('#float16Label', false);
8281
ui.handleBtnUI('#float32Label', true);
8382
$('#float16').click();

semantic_segmentation/main.js

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ $('#backendBtns .btn').on('change', async (e) => {
5353
if (inputType === 'camera') {
5454
await stopCamRender();
5555
}
56-
layout = utils.getDefaultLayout($(e.target).attr('id'));
56+
[backend, deviceType] = $(e.target).attr('id').split('_');
57+
layout = await utils.getDefaultLayout(deviceType);
5758
await main();
5859
});
5960

@@ -326,8 +327,6 @@ function constructNetObject(type) {
326327
export async function main() {
327328
try {
328329
if (modelName === '') return;
329-
[backend, deviceType] =
330-
$('input[name="backend"]:checked').attr('id').split('_');
331330
ui.handleClick(disabledSelectors, true);
332331
if (isFirstTimeLoad) $('#hint').hide();
333332
let start;

0 commit comments

Comments
 (0)