Skip to content

Commit 824c75f

Browse files
committed
resolve comments (part1)
1 parent cbb050d commit 824c75f

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
168168
const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
169169
const bool has_present_key = output_count > 1 && past_key;
170170
const bool has_attention_bias = attention_bias != nullptr;
171-
const int tile_size = 12;
171+
constexpr int tile_size = 12;
172172
const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1);
173173

174174
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,

onnxruntime/core/providers/webgpu/shader_variable.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace onnxruntime {
1414
namespace webgpu {
1515

1616
namespace {
17-
constexpr static const std::string_view STORAGE_TYPE[] = {
17+
constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = {
1818
"f32", // Float32
1919
"vec2<f32>", // Float32x2
2020
"vec4<f32>", // Float32x4
@@ -34,8 +34,9 @@ constexpr static const std::string_view STORAGE_TYPE[] = {
3434
"vec2<u32>", // Uint8x8
3535
"vec4<u32>", // Uint8x16
3636
};
37+
constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY);
3738

38-
constexpr static const std::string_view VALUE_TYPE[] = {
39+
constexpr static const std::string_view VALUE_TYPE_ARRAY[] = {
3940
"f32", // Float32
4041
"vec2<f32>", // Float32x2
4142
"vec4<f32>", // Float32x4
@@ -55,8 +56,9 @@ constexpr static const std::string_view VALUE_TYPE[] = {
5556
"vec2<u32>", // Uint8x8 (vec2<u32> as 2x4 elements of uint8)
5657
"vec4<u32>", // Uint8x16 (vec4<u32> as 4x4 elements of uint8)
5758
};
59+
constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY);
5860

59-
constexpr static const std::string_view ELEMENT_TYPE[] = {
61+
constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = {
6062
"f32", // Float32
6163
"f32", // Float32x2
6264
"f32", // Float32x4
@@ -76,6 +78,7 @@ constexpr static const std::string_view ELEMENT_TYPE[] = {
7678
"u32", // Uint8x8
7779
"u32", // Uint8x16
7880
};
81+
constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY);
7982

8083
inline std::string GetIndicesType(int rank) {
8184
return rank < 2 ? "u32"

onnxruntime/core/providers/webgpu/tensor/where.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ Status ComputeOutputShape(const TensorShape& cond_shape,
3131
if (i < y_rank)
3232
y_dim = y_shape[y_rank - 1 - i];
3333

34-
int64_t output_dim = std::max(std::max(cond_dim, x_dim), y_dim);
34+
int64_t output_dim = std::max({cond_dim, x_dim, y_dim});
3535
// special case to handle a dim of 0 which can be broadcast with a 1
3636
if (output_dim == 1)
37-
output_dim = std::min(std::min(cond_dim, x_dim), y_dim);
37+
output_dim = std::min({cond_dim, x_dim, y_dim});
3838

3939
const auto node_name = "Where";
4040
if (cond_dim != output_dim && cond_dim != 1)

0 commit comments

Comments
 (0)