Skip to content

Commit 73e0a4b

Browse files
[WebGPU] allow WGSL template generation (#25130)
### Description

This PR introduces [WGSL-Template](https://github.com/fs-eire/wgsl-template) into the WebGPU EP to enable template-based WGSL source generation. The purpose of this change is to improve the readability of the WGSL source code and to reduce future maintenance cost, which is currently high because the existing WGSL code generation logic is difficult to understand.

### Explanation

- build
  - for static code generator: `--use_webgpu`
  - for dynamic code generator: `--use_webgpu --use_wgsl_template dynamic`
- the template files use the extension `.wgsl.template` and are located in the `core/providers/webgpu/` folder.
- the corresponding C++ file should use `shader_helper.ApplyTemplate<...>()` to apply the template (see pad.cc for an example).
- for documentation
  - https://github.com/fs-eire/wgsl-template

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b6ce761 commit 73e0a4b

File tree

16 files changed

+927
-67
lines changed

16 files changed

+927
-67
lines changed

.github/workflows/linux-wasm-ci-build-and-test-workflow.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,28 @@ jobs:
4545
with:
4646
submodules: recursive
4747

48+
- name: Set up Node.js
49+
uses: actions/setup-node@v4
50+
with:
51+
node-version: "22"
52+
53+
# The image used in this workflow has an old version of Node.js installed at /bin/node and npm at /bin/npm.
54+
# We need to remove the system Node.js and npm, otherwise CMake will not be able to find the correct Node.js and npm executables.
55+
- name: Remove the system Node.js
56+
run: |
57+
if [ -f /bin/node ]; then
58+
rm /bin/node
59+
fi
60+
if [ -f /bin/npm ]; then
61+
rm /bin/npm
62+
fi
63+
4864
- name: Set up Python
4965
uses: actions/setup-python@v5
5066
with:
5167
python-version: "3.12"
5268
architecture: ${{ env.buildArch }}
69+
5370
- uses: microsoft/onnxruntime-github-actions/[email protected]
5471
with:
5572
vcpkg-version: '2025.06.13'

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ option(onnxruntime_USE_CANN "Build with CANN support" OFF)
134134
option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF)
135135
option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF)
136136
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)
137+
# Selects the WGSL template code generator: "static" (default) or "dynamic".
# option() only declares boolean cache entries, so a string-valued setting has to be
# declared with set(... CACHE STRING ...) for "static" to be the effective default.
set(onnxruntime_WGSL_TEMPLATE "static" CACHE STRING "Specify the code generator for WGSL template. Default is static.")
# Advertise the accepted values to cmake-gui / ccmake.
set_property(CACHE onnxruntime_WGSL_TEMPLATE PROPERTY STRINGS "static" "dynamic")
137138
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
138139
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
139140
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)

cmake/deps.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v
5757
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
5858
dawn;https://github.com/google/dawn/archive/9733be39e18186961d503e064874afe3e9ceb8d1.zip;2a4017c32892b90d072a9102eba90ae691fae36d
5959
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244
60+
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794

cmake/external/onnxruntime_external_deps.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,16 @@ if (onnxruntime_USE_WEBGPU)
772772
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
773773
list(APPEND onnxruntime_EXTERNAL_LIBRARIES webgpu_glfw glfw)
774774
endif()
775+
776+
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
777+
onnxruntime_fetchcontent_declare(
778+
duktape
779+
URL ${DEP_URL_duktape}
780+
URL_HASH SHA1=${DEP_SHA1_duktape}
781+
EXCLUDE_FROM_ALL
782+
)
783+
onnxruntime_fetchcontent_makeavailable(duktape)
784+
endif()
775785
endif()
776786

777787
if(onnxruntime_USE_COREML)

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
1010
add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1)
1111
endif()
12+
if (onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
13+
if (onnxruntime_DISABLE_EXCEPTIONS)
14+
message(FATAL_ERROR "Dynamic WGSL template generation requires exception handling to be enabled.")
15+
endif()
16+
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
17+
message(FATAL_ERROR "Dynamic WGSL template generation is not supported when targeting WebAssembly.")
18+
endif()
19+
add_definitions(-DORT_WGSL_TEMPLATE_DYNAMIC=1)
20+
elseif (NOT onnxruntime_WGSL_TEMPLATE STREQUAL "static")
21+
message(FATAL_ERROR "Unsupported value for onnxruntime_WGSL_TEMPLATE: ${onnxruntime_WGSL_TEMPLATE}. Supported values are 'static' or 'dynamic'.")
22+
endif()
1223
file(GLOB_RECURSE onnxruntime_providers_webgpu_cc_srcs CONFIGURE_DEPENDS
1324
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.h"
1425
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.cc"
@@ -20,6 +31,7 @@
2031

2132
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
2233
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
34+
target_compile_features(onnxruntime_providers_webgpu PRIVATE cxx_std_20)
2335
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
2436
onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
2537

@@ -117,4 +129,95 @@
117129
endif()
118130

119131
add_dependencies(onnxruntime_providers_webgpu ${onnxruntime_EXTERNAL_DEPENDENCIES})
132+
133+
if (onnxruntime_WGSL_TEMPLATE)
134+
# Define the WGSL templates directory and output directory
135+
set(WGSL_TEMPLATES_DIR "${ONNXRUNTIME_ROOT}/core/providers/webgpu/wgsl_templates")
136+
set(WGSL_GENERATED_ROOT "${CMAKE_CURRENT_BINARY_DIR}/wgsl_generated")
137+
138+
# Find npm and node executables
139+
# NAMES is needed to list several candidate executable names: in the short-hand
# find_program(<VAR> name [path...]) signature, arguments after the first are search paths.
# "npm.cmd" is listed first so the Windows launcher script is preferred when present.
find_program(NPM_EXECUTABLE NAMES "npm.cmd" "npm" REQUIRED)
140+
if(NOT NPM_EXECUTABLE)
141+
message(FATAL_ERROR "npm is required for WGSL template generation but was not found")
142+
endif()
143+
find_program(NODE_EXECUTABLE "node" REQUIRED)
144+
if (NOT NODE_EXECUTABLE)
145+
message(FATAL_ERROR "Node is required for WGSL template generation but was not found")
146+
endif()
147+
148+
# Install npm dependencies
149+
add_custom_command(
150+
OUTPUT "${WGSL_TEMPLATES_DIR}/node_modules/.install_complete"
151+
COMMAND ${NPM_EXECUTABLE} ci
152+
COMMAND ${CMAKE_COMMAND} -E touch "${WGSL_TEMPLATES_DIR}/node_modules/.install_complete"
153+
DEPENDS "${WGSL_TEMPLATES_DIR}/package.json" "${WGSL_TEMPLATES_DIR}/package-lock.json"
154+
WORKING_DIRECTORY ${WGSL_TEMPLATES_DIR}
155+
COMMENT "Installing npm dependencies for WGSL template generation"
156+
VERBATIM
157+
)
158+
159+
if (onnxruntime_WGSL_TEMPLATE STREQUAL "static")
160+
set(WGSL_GENERATED_DIR "${WGSL_GENERATED_ROOT}/wgsl_template_gen")
161+
# set(WGSL_GEN_OUTPUTS "${WGSL_GENERATED_DIR}/index.h" "${WGSL_GENERATED_DIR}/index_impl.h")
162+
# Define the output files that will be generated
163+
set(WGSL_GENERATED_INDEX_H "${WGSL_GENERATED_DIR}/index.h")
164+
set(WGSL_GENERATED_INDEX_IMPL_H "${WGSL_GENERATED_DIR}/index_impl.h")
165+
elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
166+
set(WGSL_GENERATED_DIR "${WGSL_GENERATED_ROOT}/dynamic")
167+
# set(WGSL_GEN_OUTPUTS "${WGSL_GENERATED_DIR}/templates.js")
168+
set(WGSL_GENERATED_TEMPLATES_JS "${WGSL_GENERATED_DIR}/templates.js")
169+
endif()
170+
171+
# Ensure the output directory exists
172+
file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR})
173+
174+
# Find all WGSL template input files
175+
# CONFIGURE_DEPENDS makes the build re-check this glob so that added or removed
# *.wgsl.template files trigger a re-configure instead of leaving a stale DEPENDS list
# (same convention as the provider source glob in this file).
file(GLOB_RECURSE WGSL_TEMPLATE_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template")
176+
177+
# Set wgsl-gen command line options as a list
178+
set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose")
179+
if (onnxruntime_WGSL_TEMPLATE STREQUAL "static")
180+
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
181+
list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal")
182+
else()
183+
list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp")
184+
endif()
185+
elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
186+
list(APPEND WGSL_GEN_OPTIONS "--generator" "dynamic")
187+
endif()
188+
189+
# Generate WGSL templates
190+
add_custom_command(
191+
OUTPUT ${WGSL_GENERATED_INDEX_H} ${WGSL_GENERATED_INDEX_IMPL_H} ${WGSL_GENERATED_TEMPLATES_JS}
192+
COMMAND ${NPM_EXECUTABLE} run gen -- ${WGSL_GEN_OPTIONS}
193+
DEPENDS "${WGSL_TEMPLATES_DIR}/node_modules/.install_complete" ${WGSL_TEMPLATE_FILES}
194+
WORKING_DIRECTORY ${WGSL_TEMPLATES_DIR}
195+
COMMENT "Generating WGSL templates from *.wgsl.template files"
196+
COMMAND_EXPAND_LISTS
197+
VERBATIM
198+
)
199+
200+
# Create a target to represent the generation step
201+
add_custom_target(onnxruntime_webgpu_wgsl_generation
202+
DEPENDS ${WGSL_GENERATED_INDEX_H} ${WGSL_GENERATED_INDEX_IMPL_H} ${WGSL_GENERATED_TEMPLATES_JS}
203+
SOURCES ${WGSL_TEMPLATE_FILES}
204+
)
205+
206+
if (onnxruntime_WGSL_TEMPLATE STREQUAL "static")
207+
# Add the generated directory to include paths
208+
target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT})
209+
elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic")
210+
add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c")
211+
target_compile_features(duktape_static PRIVATE c_std_99)
212+
target_link_libraries(onnxruntime_providers_webgpu duktape_static)
213+
target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src)
214+
# Define the path to the generated templates.js file
215+
target_compile_definitions(onnxruntime_providers_webgpu PRIVATE
216+
"ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"")
217+
endif()
218+
219+
# Make sure generation happens before building the provider
220+
add_dependencies(onnxruntime_providers_webgpu onnxruntime_webgpu_wgsl_generation)
221+
endif()
222+
120223
set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")

onnxruntime/core/providers/webgpu/shader_helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "core/providers/webgpu/shader_variable.h"
1414
#include "core/providers/webgpu/string_utils.h"
1515

16+
#include "core/providers/webgpu/wgsl_templates/wgsl_gen.h"
17+
1618
namespace onnxruntime {
1719
namespace webgpu {
1820

onnxruntime/core/providers/webgpu/tensor/pad.cc

Lines changed: 5 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,73 +19,11 @@ Status PadProgram::GenerateShaderCode(ShaderHelper& shader) const {
1919
}
2020
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);
2121

22-
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
23-
std::string constant_value_str = std::string("let constant_value = ") +
24-
(is_float16_ ? "bitcast<vec2<f16>>(uniforms.constant_value)[0];\n" : "bitcast<output_value_t>(uniforms.constant_value);\n");
25-
if (dim_value_zero_) {
26-
// Only Constant mode needs fill output if the one dim value or mores dims' values of input are zero.
27-
shader.MainFunctionBody() << constant_value_str
28-
<< "output[global_idx] = constant_value;\n";
29-
return Status::OK();
30-
}
31-
32-
shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
33-
<< " var input_index = u32(0);\n"
34-
<< " var use_pad_value = false;\n"
35-
<< " var in_coord = i32(0);\n";
36-
37-
const int rank = output.Rank();
38-
std::string output_indices_str = "i32(" + GetElementAt("output_indices", "dim", rank) + ")";
39-
std::string lower_pads_str = GetElementAt("uniforms.lower_pads", "dim", rank);
40-
std::string data_shape_str = "i32(" + GetElementAt("uniforms.data_shape", "dim", rank) + ")";
41-
std::string data_stride_str = rank == 1 ? "" : " * " + GetElementAt("uniforms.data_stride", "dim", rank - 1);
42-
SS(axis_body_ss, 1024);
43-
switch (mode_) {
44-
case Mode::Constant:
45-
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << " || " << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
46-
<< " use_pad_value = true;\n";
47-
break;
48-
case Mode::Edge:
49-
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n"
50-
<< " in_coord = 0;\n"
51-
<< " } else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
52-
<< " in_coord = " << data_shape_str + " - 1;\n";
53-
break;
54-
case Mode::Reflect:
55-
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << " || " << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
56-
<< " in_coord = " << output_indices_str << " - " << lower_pads_str << ";\n"
57-
<< " if (in_coord < 0) {\n"
58-
<< " in_coord = -in_coord;\n"
59-
<< " }\n"
60-
<< " {\n"
61-
<< " let _2n_1 = 2 * (" << data_shape_str << " - 1);\n"
62-
<< " in_coord = in_coord % _2n_1;\n"
63-
<< " if(in_coord >= " << data_shape_str << ") {\n"
64-
<< " in_coord = _2n_1 - in_coord;\n"
65-
<< " }\n"
66-
<< " }\n";
67-
break;
68-
case Mode::Wrap:
69-
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n"
70-
<< " in_coord = " << data_shape_str << " + " << output_indices_str << " - " << lower_pads_str + ";\n"
71-
<< " } else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
72-
<< " in_coord = " << output_indices_str << " - " << lower_pads_str << " - " << data_shape_str << ";\n";
73-
break;
74-
default:
75-
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported mode type: ", static_cast<int>(mode_));
76-
}
77-
axis_body_ss << " } else {\n"
78-
<< " " << "in_coord = " << output_indices_str << " - " << lower_pads_str << ";\n"
79-
<< " }\n";
80-
81-
shader.MainFunctionBody() << " for (var dim = 0; dim < " << rank << " && !use_pad_value; dim++) {\n"
82-
<< SS_GET(axis_body_ss)
83-
<< " input_index += select(u32(in_coord)" << data_stride_str << ", u32(in_coord), dim == " << rank - 1 << ");\n"
84-
<< " }\n"
85-
<< " " << constant_value_str
86-
<< " " << output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)");
87-
88-
return Status::OK();
22+
return WGSL_TEMPLATE_APPLY(shader, "tensor/pad.wgsl.template",
23+
WGSL_TEMPLATE_PARAMETER(dim_value_zero, dim_value_zero_),
24+
WGSL_TEMPLATE_PARAMETER(is_float16, is_float16_),
25+
WGSL_TEMPLATE_PARAMETER(pad_mode, mode_),
26+
WGSL_TEMPLATE_VARIABLE(output, output));
8927
}
9028

9129
Status Pad::ComputeInternal(ComputeContext& context) const {
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#define PAD_MODE_CONSTANT 0
2+
#define PAD_MODE_REFLECT 1
3+
#define PAD_MODE_EDGE 2
4+
#define PAD_MODE_WRAP 3
5+
6+
#use guardAgainstOutOfBoundsWorkgroupSizes
7+
#use getElementAt
8+
#use .offsetToIndices .setByOffset .rank
9+
10+
#param dim_value_zero
11+
#param is_float16
12+
#param pad_mode
13+
14+
$MAIN {
15+
guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);
16+
17+
let constant_value =
18+
#if is_float16
19+
bitcast<vec2<f16>>(uniforms.constant_value)[0];
20+
#else
21+
bitcast<output_value_t>(uniforms.constant_value);
22+
#endif
23+
24+
#if dim_value_zero
25+
output[global_idx] = constant_value;
26+
#else
27+
let output_indices = output.offsetToIndices(global_idx);
28+
var input_index = u32(0);
29+
var use_pad_value = false;
30+
var in_coord = i32(0);
31+
32+
for (var dim = 0; dim < output.rank && !use_pad_value; dim++) {
33+
let output_index = i32(getElementAt(output_indices, dim, output.rank));
34+
let lower_pads = getElementAt(uniforms.lower_pads, dim, output.rank);
35+
let data_shape = i32(getElementAt(uniforms.data_shape, dim, output.rank));
36+
#if pad_mode == PAD_MODE_CONSTANT
37+
if (output_index < lower_pads || output_index >= data_shape + lower_pads) {
38+
use_pad_value = true;
39+
#elif pad_mode == PAD_MODE_EDGE
40+
if (output_index < lower_pads) {
41+
in_coord = 0;
42+
} else if (output_index >= data_shape + lower_pads) {
43+
in_coord = data_shape - 1;
44+
#elif pad_mode == PAD_MODE_REFLECT
45+
if (output_index < lower_pads || output_index >= data_shape + lower_pads) {
46+
in_coord = output_index - lower_pads;
47+
if (in_coord < 0) {
48+
in_coord = -in_coord;
49+
}
50+
let _2n_1 = 2 * (data_shape - 1);
51+
in_coord = in_coord % _2n_1;
52+
if (in_coord >= data_shape) {
53+
in_coord = _2n_1 - in_coord;
54+
}
55+
#else // PAD_MODE_WRAP
56+
if (output_index < lower_pads) {
57+
in_coord = data_shape + output_index - lower_pads;
58+
} else if (output_index >= data_shape + lower_pads) {
59+
in_coord = output_index - data_shape - lower_pads;
60+
#endif // pad_mode
61+
} else {
62+
in_coord = output_index - lower_pads;
63+
}
64+
65+
input_index += select(u32(in_coord)
66+
#if output.rank > 1
67+
* getElementAt(uniforms.data_stride, dim, output.rank - 1)
68+
#endif
69+
, u32(in_coord), dim == output.rank - 1);
70+
}
71+
72+
output.setByOffset(global_idx, select(data[input_index], constant_value, use_pad_value));
73+
#endif
74+
} // MAIN
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
node_modules/

0 commit comments

Comments
 (0)