@@ -11,14 +11,20 @@ namespace onnxruntime {
11
11
namespace webgpu {
12
12
13
13
Status ExpandProgram::GenerateShaderCode (ShaderHelper& shader) const {
14
- const auto & input = shader.AddInput (" input" , ShaderUsage::UseUniform);
14
+ const auto & input = shader.AddInput (" input" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias );
15
15
const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
16
-
17
- shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.data_size" )
18
- << " let output_indices = " << output.OffsetToIndices (" global_idx" ) << " ;\n "
19
- << " let input_offset = " << input.BroadcastedIndicesToOffset (" output_indices" , output) << " ;\n "
20
- << output.SetByOffset (" global_idx" , input.GetByOffset (" input_offset" ));
21
-
16
+ shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.data_size" );
17
+ if (input.NumComponents () != output.NumComponents ()) {
18
+ const auto & output_indices = shader.AddIndices (" output_indices" );
19
+ shader.MainFunctionBody () << " let output_indices = " << output_indices.OffsetToIndices (" global_idx * 4" ) << " ;\n "
20
+ << " let input_offset = " << input.BroadcastedIndicesToOffset (" output_indices" , output_indices) << " ;\n "
21
+ << " let value = vec4<input_value_t>(" << input.GetByOffset (" input_offset" ) << " );\n "
22
+ << output.SetByOffset (" global_idx" , " value" );
23
+ } else {
24
+ shader.MainFunctionBody () << " let output_indices = " << output.OffsetToIndices (" global_idx" ) << " ;\n "
25
+ << " let input_offset = " << input.BroadcastedIndicesToOffset (" output_indices" , output) << " ;\n "
26
+ << output.SetByOffset (" global_idx" , input.GetByOffset (" input_offset" ));
27
+ }
22
28
return Status::OK ();
23
29
}
24
30
@@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
28
34
29
35
auto output_dims = input_shape_tensor->DataAsSpan <int64_t >();
30
36
TensorShape output_shape{};
31
- ORT_RETURN_IF_ERROR (ComputeBroadcastOutputShape (Node ().Name (), input_tensor->Shape (), output_dims, output_shape));
37
+ TensorShape input_shape = input_tensor->Shape ();
38
+ ORT_RETURN_IF_ERROR (ComputeBroadcastOutputShape (Node ().Name (), input_shape, output_dims, output_shape));
32
39
33
40
auto * output_tensor = context.Output (0 , output_shape);
34
- uint32_t data_size = gsl::narrow<uint32_t >(output_shape.Size ());
41
+ const int components_i = input_shape.IsScalar () ? 1 : input_shape[input_shape.NumDimensions () - 1 ] % 4 == 0 ? 4
42
+ : 1 ;
43
+ const int components_o = output_shape.IsScalar () ? 1 : output_shape[output_shape.NumDimensions () - 1 ] % 4 == 0 ? 4
44
+ : 1 ;
45
+ uint32_t data_size = gsl::narrow<uint32_t >(output_shape.Size () / components_o);
46
+
35
47
ExpandProgram program{};
36
48
program
37
- .AddInputs ({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
38
- .AddOutputs ({{output_tensor, ProgramTensorMetadataDependency::Rank }})
49
+ .AddInputs ({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i }})
50
+ .AddOutputs ({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o }})
39
51
.SetDispatchGroupSize ((data_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
40
52
.AddUniformVariables ({
41
53
{data_size},
42
54
});
55
+ if (components_i != components_o) {
56
+ program.AddIndices (output_shape);
57
+ }
43
58
return context.RunProgram (program);
44
59
}
45
60
@@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
55
70
KernelDefBuilder ().TypeConstraint(" T" , TYPE).InputMemoryType(OrtMemTypeCPU, 1 ), \
56
71
KERNEL_CLASS);
57
72
58
- WEBGPU_EXPAND_VERSIONED_KERNEL (Expand, 8 , 12 , Expand, WebGpuSupportedFloatTypes ())
59
- WEBGPU_EXPAND_KERNEL(Expand, 13 , Expand, WebGpuSupportedFloatTypes ())
73
+ WEBGPU_EXPAND_VERSIONED_KERNEL (Expand, 8 , 12 , Expand, WebGpuSupportedNumberTypes ())
74
+ WEBGPU_EXPAND_KERNEL(Expand, 13 , Expand, WebGpuSupportedNumberTypes ())
60
75
61
76
} // namespace webgpu
62
77
} // namespace onnxruntime
0 commit comments