Skip to content

Commit 6ef13e3

Browse files
authored
[webgpu] extend cast version to 23 (#25235)
1 parent a532c8a commit 6ef13e3

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

onnxruntime/core/providers/webgpu/tensor/cast.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,28 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
5252
.TypeConstraint("T1", CastOpTypeConstraints())
5353
.TypeConstraint("T2", CastOpTypeConstraints()),
5454
Cast);
55+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
56+
Cast,
57+
kOnnxDomain,
58+
19, 20,
59+
kWebGpuExecutionProvider,
60+
(*KernelDefBuilder::Create())
61+
.TypeConstraint("T1", CastOpTypeConstraints())
62+
.TypeConstraint("T2", CastOpTypeConstraints()),
63+
Cast);
64+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
65+
Cast,
66+
kOnnxDomain,
67+
21, 22,
68+
kWebGpuExecutionProvider,
69+
(*KernelDefBuilder::Create())
70+
.TypeConstraint("T1", CastOpTypeConstraints())
71+
.TypeConstraint("T2", CastOpTypeConstraints()),
72+
Cast);
5573
ONNX_OPERATOR_KERNEL_EX(
5674
Cast,
5775
kOnnxDomain,
58-
19,
76+
23,
5977
kWebGpuExecutionProvider,
6078
(*KernelDefBuilder::Create())
6179
.TypeConstraint("T1", CastOpTypeConstraints())

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
123123
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast);
124124
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast);
125125
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast);
126-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast);
126+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast);
127+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast);
128+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast);
127129

128130
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip);
129131
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip);
@@ -455,7 +457,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
455457
KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast),
456458
KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast),
457459
KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast),
458-
KERNEL_CREATE_INFO(19, Cast),
460+
KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast),
461+
KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast),
462+
KERNEL_CREATE_INFO(23, Cast),
459463

460464
// // activations
461465
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip)>,

0 commit comments

Comments
 (0)