Skip to content

Commit 7ba6ed6

Browse files
qjia7 authored and ankitm3k committed
[webgpu] Optimize Expand (microsoft#23052)
### Description
<!-- Describe your changes. -->
Use `components = 4` where possible. This is the native WebGPU implementation of the change from microsoft#22752.
1 parent 2e970ed commit 7ba6ed6

File tree

1 file changed

+28
-13
lines changed
  • onnxruntime/core/providers/webgpu/tensor

1 file changed

+28
-13
lines changed

onnxruntime/core/providers/webgpu/tensor/expand.cc

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,20 @@ namespace onnxruntime {
1111
namespace webgpu {
1212

1313
Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const {
14-
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
14+
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
1515
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
16-
17-
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
18-
<< " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
19-
<< " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
20-
<< output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
21-
16+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size");
17+
if (input.NumComponents() != output.NumComponents()) {
18+
const auto& output_indices = shader.AddIndices("output_indices");
19+
shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
20+
<< " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n "
21+
<< " let value = vec4<input_value_t>(" << input.GetByOffset("input_offset") << ");\n"
22+
<< output.SetByOffset("global_idx", "value");
23+
} else {
24+
shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
25+
<< " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
26+
<< output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
27+
}
2228
return Status::OK();
2329
}
2430

@@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
2834

2935
auto output_dims = input_shape_tensor->DataAsSpan<int64_t>();
3036
TensorShape output_shape{};
31-
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape));
37+
TensorShape input_shape = input_tensor->Shape();
38+
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape));
3239

3340
auto* output_tensor = context.Output(0, output_shape);
34-
uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size());
41+
const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4
42+
: 1;
43+
const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4
44+
: 1;
45+
uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size() / components_o);
46+
3547
ExpandProgram program{};
3648
program
37-
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
38-
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}})
49+
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}})
50+
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}})
3951
.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
4052
.AddUniformVariables({
4153
{data_size},
4254
});
55+
if (components_i != components_o) {
56+
program.AddIndices(output_shape);
57+
}
4358
return context.RunProgram(program);
4459
}
4560

@@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
5570
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
5671
KERNEL_CLASS);
5772

58-
WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes())
59-
WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes())
73+
WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes())
74+
WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes())
6075

6176
} // namespace webgpu
6277
} // namespace onnxruntime

0 commit comments

Comments
 (0)