Skip to content

Commit 24fab67

Browse files
authored
Clean up getDeviceIndexOfCurrentQueue (#2060)
# Motivation Clean up `getDeviceIndexOfCurrentQueue` since it duplicates `current_device` and provides no additional functionality.
1 parent 6508096 commit 24fab67

File tree

6 files changed

+60
-95
lines changed

6 files changed

+60
-95
lines changed

src/ATen/native/xpu/sycl/BatchNormKernels.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ int get_num_threads_by_dev_max_group_size(
185185
int get_prefer_simd(int numPlane, int nHw) {
186186
// decide SIMD: SIMD32 or SIMD16
187187

188-
auto dev_id = at::xpu::getDeviceIndexOfCurrentQueue();
189-
188+
auto dev_id = at::xpu::current_device();
190189
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
191190
auto sub_group_size = dev_prop->sub_group_sizes;
192191
int simd = sub_group_size[1];

src/ATen/native/xpu/sycl/Norm.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,7 @@ class NormConfig {
269269
}
270270

271271
void get_max_vec_size() {
272-
auto dev_id = getDeviceIndexOfCurrentQueue();
273-
int total_resource = syclMaxWorkItemsPerTile(dev_id);
272+
int64_t total_resource = syclMaxWorkItemsPerTile();
274273

275274
constexpr int float4_size = sizeof(float) * 4;
276275
max_vec_size = float4_size / element_size_bytes;

src/ATen/native/xpu/sycl/SoftMaxKernels.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,8 +1559,7 @@ void spatial_softmax_forward(
15591559
canUse32BitIndexMath(input) && canUse32BitIndexMath(output);
15601560

15611561
// decide SIMD: SIMD32 or SIMD16
1562-
auto dev_id = at::xpu::getDeviceIndexOfCurrentQueue();
1563-
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
1562+
auto* dev_prop = at::xpu::getCurrentDeviceProperties();
15641563
auto sub_group_size = dev_prop->sub_group_sizes;
15651564
int SIMD = sub_group_size[1];
15661565
if (SIMD == SIMD32) {
@@ -1749,8 +1748,7 @@ void spatial_softmax_backward(
17491748
canUse32BitIndexMath(output) && canUse32BitIndexMath(gradOutput);
17501749

17511750
// decide SIMD: SIMD32 or SIMD16
1752-
auto* dev_prop =
1753-
at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue());
1751+
auto* dev_prop = at::xpu::getCurrentDeviceProperties();
17541752
auto sub_group_size = dev_prop->sub_group_sizes;
17551753
int SIMD = sub_group_size[1];
17561754
if (SIMD == SIMD32) {
@@ -1901,8 +1899,7 @@ Tensor& masked_softmax_forward(
19011899
canUse32BitIndexMath(input) && canUse32BitIndexMath(output);
19021900

19031901
// decide SIMD: SIMD32 or SIMD16
1904-
auto* dev_prop =
1905-
at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue());
1902+
auto* dev_prop = at::xpu::getCurrentDeviceProperties();
19061903
auto sub_group_size = dev_prop->sub_group_sizes;
19071904
int SIMD = sub_group_size[1];
19081905
if (SIMD == SIMD32) {
@@ -2026,8 +2023,7 @@ void masked_softmax_backward(
20262023
canUse32BitIndexMath(output) && canUse32BitIndexMath(gradOutput);
20272024

20282025
// decide SIMD: SIMD32 or SIMD16
2029-
auto* dev_prop =
2030-
at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue());
2026+
auto* dev_prop = at::xpu::getCurrentDeviceProperties();
20312027
auto sub_group_size = dev_prop->sub_group_sizes;
20322028
int SIMD = sub_group_size[1];
20332029
if (SIMD == SIMD32) {

src/ATen/native/xpu/sycl/TensorShapeKernels.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,7 @@ void split_with_sizes_copy_out_xpu_contiguous_no_cast(
669669
num_groups += div_up(split_chunk_size, GROUP_SIZE * BYTES_PER_THREAD);
670670
}
671671

672-
auto dev_id = getDeviceIndexOfCurrentQueue();
673-
int64_t tile_size = syclMaxWorkItemsPerTile(dev_id);
672+
int64_t tile_size = syclMaxWorkItemsPerTile();
674673
const int64_t max_groups = tile_size / GROUP_SIZE * 2.0;
675674

676675
// Make each thread process BYTES_PER_THREAD * iter_factor bytes to regulate

src/comm/DeviceProperties.h

Lines changed: 53 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
#include <ATen/xpu/XPUContext.h>
44

55
#include <comm/Runtime.h>
6-
#include <iostream>
76

87
namespace xpu {
98
namespace sycl {
109

1110
template <class KernelClass>
1211
static int64_t syclMaxWorkGroupSize(
13-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
14-
auto q = c10::xpu::getCurrentXPUStream(dev_id).queue();
15-
auto ctx = q.get_context();
16-
auto dev = q.get_device();
12+
at::DeviceIndex dev_id = at::xpu::current_device()) {
13+
auto& ctx = c10::xpu::get_device_context();
14+
auto& dev = c10::xpu::get_raw_device(dev_id);
1715

1816
auto kid = ::sycl::get_kernel_id<KernelClass>();
1917
// The kernel won't be built for devices except for the first device.
@@ -30,73 +28,69 @@ static int64_t syclMaxWorkGroupSize(
3028

3129
template <class KernelClass>
3230
static int64_t syclMaxWorkGroupSize(
33-
KernelClass /*kfn*/,
34-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
31+
const KernelClass& /*kfn*/,
32+
at::DeviceIndex dev_id = at::xpu::current_device()) {
3533
return syclMaxWorkGroupSize<KernelClass>(dev_id);
3634
}
3735

3836
static inline int64_t syclDeviceMaxWorkGroupSize(
39-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
37+
at::DeviceIndex dev_id = at::xpu::current_device()) {
4038
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
4139
return dev_prop->max_work_group_size;
4240
}
4341

4442
static inline int64_t syclMaxSubGroupSize(
45-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
43+
at::DeviceIndex dev_id = at::xpu::current_device()) {
4644
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
47-
auto subgroup_sizes = dev_prop->sub_group_sizes;
48-
uint64_t max_val = 0;
49-
for (auto i : subgroup_sizes) {
50-
if (i > max_val)
51-
max_val = i;
52-
}
53-
return max_val;
45+
const auto& subgroup_sizes = dev_prop->sub_group_sizes;
46+
TORCH_CHECK(
47+
!subgroup_sizes.empty(),
48+
"The device subgroup sizes is empty, please check the device status.");
49+
return *std::max_element(subgroup_sizes.begin(), subgroup_sizes.end());
5450
}
5551

5652
static inline int64_t syclMinSubGroupSize(
57-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
53+
at::DeviceIndex dev_id = at::xpu::current_device()) {
5854
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
59-
auto subgroup_sizes = dev_prop->sub_group_sizes;
60-
uint64_t min_val = dev_prop->max_work_group_size;
61-
for (auto i : subgroup_sizes) {
62-
if (i < min_val)
63-
min_val = i;
64-
}
65-
return min_val;
55+
const auto& subgroup_sizes = dev_prop->sub_group_sizes;
56+
TORCH_CHECK(
57+
!subgroup_sizes.empty(),
58+
"The device subgroup sizes is empty, please check the device status.");
59+
return *std::min_element(subgroup_sizes.begin(), subgroup_sizes.end());
6660
}
6761

6862
static inline int64_t syclMaxComputeUnitSize(
69-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
63+
at::DeviceIndex dev_id = at::xpu::current_device()) {
7064
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
7165
return dev_prop->max_compute_units;
7266
}
7367

7468
static inline int64_t syclGpuEuCount(
75-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
69+
at::DeviceIndex dev_id = at::xpu::current_device()) {
7670
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
7771
return dev_prop->gpu_eu_count;
7872
}
7973

8074
static inline int64_t syclGpuEuSimdWidth(
81-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
75+
at::DeviceIndex dev_id = at::xpu::current_device()) {
8276
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
8377
return dev_prop->gpu_eu_simd_width;
8478
}
8579

8680
static inline int64_t syclGpuHWThreadsPerEU(
87-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
81+
at::DeviceIndex dev_id = at::xpu::current_device()) {
8882
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
8983
return dev_prop->gpu_hw_threads_per_eu;
9084
}
9185

9286
static inline int64_t syclGpuEUCountPerSubslice(
93-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
87+
at::DeviceIndex dev_id = at::xpu::current_device()) {
9488
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
9589
return dev_prop->gpu_eu_count_per_subslice;
9690
}
9791

9892
static inline int64_t syclMaxWorkItemsPerTile(
99-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
93+
at::DeviceIndex dev_id = at::xpu::current_device()) {
10094
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
10195
int64_t eu_cnt = dev_prop->gpu_eu_count;
10296
int64_t simd_width = syclMaxSubGroupSize(dev_id);
@@ -105,110 +99,92 @@ static inline int64_t syclMaxWorkItemsPerTile(
10599
}
106100

107101
static inline int64_t syclMaxWorkItemsPerSubSlice(
108-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
102+
at::DeviceIndex dev_id = at::xpu::current_device()) {
109103
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
110104
int64_t simd_width = syclMaxSubGroupSize(dev_id);
111105
int64_t eu_count = dev_prop->gpu_eu_count_per_subslice;
112106
return simd_width * eu_count;
113107
}
114108

115109
static inline int64_t syclMaxWorkItemsPerEU(
116-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
110+
at::DeviceIndex dev_id = at::xpu::current_device()) {
117111
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
118112
int64_t simd_width = syclMaxSubGroupSize(dev_id);
119113
int64_t hw_threads = dev_prop->gpu_hw_threads_per_eu;
120114
return simd_width * hw_threads;
121115
}
122116

123117
static inline int64_t syclMaxNumSubGroups(
124-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
118+
at::DeviceIndex dev_id = at::xpu::current_device()) {
125119
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
126120
return dev_prop->max_num_sub_groups;
127121
}
128122

129123
static inline int64_t syclMaxDSSNum(
130-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
124+
at::DeviceIndex dev_id = at::xpu::current_device()) {
131125
int64_t dss_num =
132126
syclMaxComputeUnitSize(dev_id) / syclGpuEUCountPerSubslice(dev_id);
133127
return dss_num;
134128
}
135129

136130
static inline size_t syclGlobalMemSize(
137-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
131+
at::DeviceIndex dev_id = at::xpu::current_device()) {
138132
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
139133
return dev_prop->global_mem_size;
140134
}
141135

142136
static inline int64_t syclLocalMemSize(
143-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
137+
at::DeviceIndex dev_id = at::xpu::current_device()) {
144138
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
145139
return dev_prop->local_mem_size;
146140
}
147141

148142
template <typename T>
149143
uint32_t syclPrefVectorWidth(
150-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
144+
at::DeviceIndex dev_id = at::xpu::current_device()) {
151145
(void)dev_id; // Suppress unused variable warning
152146

153147
// Hot fix. This is the preferred vector width for GPUs up to LNL/BMG.
154-
uint32_t vec_width = 16;
148+
constexpr uint32_t vec_width = 16;
155149

156-
if (std::is_same<T, char>::value) {
157-
return vec_width / sizeof(char);
158-
}
159-
if (std::is_same<T, short>::value) {
160-
return vec_width / sizeof(short);
161-
}
162-
if (std::is_same<T, int>::value) {
163-
return vec_width / sizeof(int);
164-
}
165-
if (std::is_same<T, int64_t>::value) {
166-
return vec_width / sizeof(int64_t);
167-
}
168-
if (std::is_same<T, float>::value) {
169-
return vec_width / sizeof(float);
170-
}
171-
if (std::is_same<T, double>::value) {
172-
return vec_width / sizeof(double);
150+
if constexpr (
151+
std::is_same_v<T, char> || std::is_same_v<T, short> ||
152+
std::is_same_v<T, int> || std::is_same_v<T, int64_t> ||
153+
std::is_same_v<T, float> || std::is_same_v<T, double> ||
154+
std::is_same_v<T, ::sycl::half>) {
155+
return vec_width / sizeof(T);
156+
} else {
157+
throw std::invalid_argument(
158+
"Invalid data type to fetch preferred vector width!");
173159
}
174-
if (std::is_same<T, ::sycl::half>::value) {
175-
return vec_width / sizeof(::sycl::half);
176-
}
177-
throw std::invalid_argument(
178-
"Invalid data type to fetch preferred vector width!");
179160
}
180161

181162
template <typename T>
182163
uint32_t syclNativeVectorWidth(
183-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
164+
at::DeviceIndex dev_id = at::xpu::current_device()) {
184165
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
185-
if (std::is_same<T, char>::value) {
166+
if constexpr (std::is_same_v<T, char>) {
186167
return dev_prop->native_vector_width_char;
187-
}
188-
if (std::is_same<T, short>::value) {
168+
} else if constexpr (std::is_same_v<T, short>) {
189169
return dev_prop->native_vector_width_short;
190-
}
191-
if (std::is_same<T, int>::value) {
170+
} else if constexpr (std::is_same_v<T, int>) {
192171
return dev_prop->native_vector_width_int;
193-
}
194-
if (std::is_same<T, int64_t>::value) {
172+
} else if constexpr (std::is_same_v<T, int64_t>) {
195173
return dev_prop->native_vector_width_long;
196-
}
197-
if (std::is_same<T, float>::value) {
174+
} else if constexpr (std::is_same_v<T, float>) {
198175
return dev_prop->native_vector_width_float;
199-
}
200-
if (std::is_same<T, double>::value) {
176+
} else if constexpr (std::is_same_v<T, double>) {
201177
return dev_prop->native_vector_width_double;
202-
}
203-
if (std::is_same<T, ::sycl::half>::value) {
178+
} else if constexpr (std::is_same_v<T, ::sycl::half>) {
204179
return dev_prop->native_vector_width_half;
180+
} else {
181+
throw std::invalid_argument(
182+
"Invalid data type to fetch native vector width!");
205183
}
206-
throw std::invalid_argument(
207-
"Invalid data type to fetch native vector width!");
208184
}
209185

210186
static inline bool syclHasFloat64(
211-
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
187+
at::DeviceIndex dev_id = at::xpu::current_device()) {
212188
auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
213189
return dev_prop->has_fp64;
214190
}

src/comm/Runtime.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44

55
namespace at::xpu {
66

7-
static inline at::DeviceIndex getDeviceIndexOfCurrentQueue() {
8-
return c10::xpu::getCurrentXPUStream().device_index();
9-
}
10-
117
static inline sycl::queue& getCurrentSYCLQueue() {
128
return c10::xpu::getCurrentXPUStream().queue();
139
}

0 commit comments

Comments
 (0)