3
3
#include < ATen/xpu/XPUContext.h>
4
4
5
5
#include < comm/Runtime.h>
6
- #include < iostream>
7
6
8
7
namespace xpu {
9
8
namespace sycl {
10
9
11
10
template <class KernelClass >
12
11
static int64_t syclMaxWorkGroupSize (
13
- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
14
- auto q = c10::xpu::getCurrentXPUStream (dev_id).queue ();
15
- auto ctx = q.get_context ();
16
- auto dev = q.get_device ();
12
+ at::DeviceIndex dev_id = at::xpu::current_device()) {
13
+ auto & ctx = c10::xpu::get_device_context ();
14
+ auto & dev = c10::xpu::get_raw_device (dev_id);
17
15
18
16
auto kid = ::sycl::get_kernel_id<KernelClass>();
19
17
// The kernel won't be built for devices except for the first device.
@@ -30,73 +28,69 @@ static int64_t syclMaxWorkGroupSize(
30
28
31
29
template <class KernelClass >
32
30
static int64_t syclMaxWorkGroupSize (
33
- KernelClass /* kfn*/ ,
34
- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
31
+ const KernelClass& /* kfn*/ ,
32
+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
35
33
return syclMaxWorkGroupSize<KernelClass>(dev_id);
36
34
}
37
35
38
36
// Device-wide maximum work-group size from the cached device properties
// (not kernel-specific; see syclMaxWorkGroupSize for the per-kernel query).
static inline int64_t syclDeviceMaxWorkGroupSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->max_work_group_size;
}
43
41
44
42
// Largest sub-group size advertised in the device properties.
// Fails loudly if the device reports no sub-group sizes at all.
static inline int64_t syclMaxSubGroupSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  const auto& sizes = at::xpu::getDeviceProperties(dev_id)->sub_group_sizes;
  TORCH_CHECK(
      !sizes.empty(),
      "The device subgroup sizes is empty, please check the device status.");
  return *std::max_element(sizes.begin(), sizes.end());
}
55
51
56
52
// Smallest sub-group size advertised in the device properties.
// Fails loudly if the device reports no sub-group sizes at all.
static inline int64_t syclMinSubGroupSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  const auto& sizes = at::xpu::getDeviceProperties(dev_id)->sub_group_sizes;
  TORCH_CHECK(
      !sizes.empty(),
      "The device subgroup sizes is empty, please check the device status.");
  return *std::min_element(sizes.begin(), sizes.end());
}
67
61
68
62
// Maximum number of compute units reported by the device properties.
static inline int64_t syclMaxComputeUnitSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->max_compute_units;
}
73
67
74
68
// GPU execution-unit (EU) count from the device properties.
static inline int64_t syclGpuEuCount(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->gpu_eu_count;
}
79
73
80
74
// Physical SIMD width of a GPU EU from the device properties.
static inline int64_t syclGpuEuSimdWidth(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->gpu_eu_simd_width;
}
85
79
86
80
// Hardware thread count per GPU EU from the device properties.
static inline int64_t syclGpuHWThreadsPerEU(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->gpu_hw_threads_per_eu;
}
91
85
92
86
// EU count per sub-slice from the device properties.
static inline int64_t syclGpuEUCountPerSubslice(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->gpu_eu_count_per_subslice;
}
97
91
98
92
static inline int64_t syclMaxWorkItemsPerTile (
99
- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
93
+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
100
94
auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
101
95
int64_t eu_cnt = dev_prop->gpu_eu_count ;
102
96
int64_t simd_width = syclMaxSubGroupSize (dev_id);
@@ -105,110 +99,92 @@ static inline int64_t syclMaxWorkItemsPerTile(
105
99
}
106
100
107
101
// Work items resident on one sub-slice: widest sub-group size times the
// number of EUs per sub-slice.
static inline int64_t syclMaxWorkItemsPerSubSlice(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  const int64_t sg_width = syclMaxSubGroupSize(dev_id);
  const int64_t eus_per_subslice =
      at::xpu::getDeviceProperties(dev_id)->gpu_eu_count_per_subslice;
  return sg_width * eus_per_subslice;
}
114
108
115
109
// Work items resident on one EU: widest sub-group size times the number of
// hardware threads per EU.
static inline int64_t syclMaxWorkItemsPerEU(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  const int64_t sg_width = syclMaxSubGroupSize(dev_id);
  const int64_t threads_per_eu =
      at::xpu::getDeviceProperties(dev_id)->gpu_hw_threads_per_eu;
  return sg_width * threads_per_eu;
}
122
116
123
117
// Maximum number of sub-groups per work-group from the device properties.
static inline int64_t syclMaxNumSubGroups(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->max_num_sub_groups;
}
128
122
129
123
// Number of (dual) sub-slices, derived as total compute units divided by
// EUs per sub-slice.
static inline int64_t syclMaxDSSNum(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return syclMaxComputeUnitSize(dev_id) / syclGpuEUCountPerSubslice(dev_id);
}
135
129
136
130
// Global memory size in bytes from the device properties.
static inline size_t syclGlobalMemSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->global_mem_size;
}
141
135
142
136
// Local (shared) memory size in bytes from the device properties.
static inline int64_t syclLocalMemSize(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->local_mem_size;
}
147
141
148
142
template <typename T>
149
143
uint32_t syclPrefVectorWidth (
150
- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
144
+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
151
145
(void )dev_id; // Suppress unused variable warning
152
146
153
147
// Hot fix. This is the preferred vector width for GPUs up to LNL/BMG.
154
- uint32_t vec_width = 16 ;
148
+ constexpr uint32_t vec_width = 16 ;
155
149
156
- if (std::is_same<T, char >::value) {
157
- return vec_width / sizeof (char );
158
- }
159
- if (std::is_same<T, short >::value) {
160
- return vec_width / sizeof (short );
161
- }
162
- if (std::is_same<T, int >::value) {
163
- return vec_width / sizeof (int );
164
- }
165
- if (std::is_same<T, int64_t >::value) {
166
- return vec_width / sizeof (int64_t );
167
- }
168
- if (std::is_same<T, float >::value) {
169
- return vec_width / sizeof (float );
170
- }
171
- if (std::is_same<T, double >::value) {
172
- return vec_width / sizeof (double );
150
+ if constexpr (
151
+ std::is_same_v<T, char > || std::is_same_v<T, short > ||
152
+ std::is_same_v<T, int > || std::is_same_v<T, int64_t > ||
153
+ std::is_same_v<T, float > || std::is_same_v<T, double > ||
154
+ std::is_same_v<T, ::sycl::half>) {
155
+ return vec_width / sizeof (T);
156
+ } else {
157
+ throw std::invalid_argument (
158
+ " Invalid data type to fetch preferred vector width!" );
173
159
}
174
- if (std::is_same<T, ::sycl::half>::value) {
175
- return vec_width / sizeof (::sycl::half);
176
- }
177
- throw std::invalid_argument (
178
- " Invalid data type to fetch preferred vector width!" );
179
160
}
180
161
181
162
template <typename T>
182
163
uint32_t syclNativeVectorWidth (
183
- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
164
+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
184
165
auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
185
- if (std::is_same <T, char >::value ) {
166
+ if constexpr (std::is_same_v <T, char >) {
186
167
return dev_prop->native_vector_width_char ;
187
- }
188
- if (std::is_same<T, short >::value) {
168
+ } else if constexpr (std::is_same_v<T, short >) {
189
169
return dev_prop->native_vector_width_short ;
190
- }
191
- if (std::is_same<T, int >::value) {
170
+ } else if constexpr (std::is_same_v<T, int >) {
192
171
return dev_prop->native_vector_width_int ;
193
- }
194
- if (std::is_same<T, int64_t >::value) {
172
+ } else if constexpr (std::is_same_v<T, int64_t >) {
195
173
return dev_prop->native_vector_width_long ;
196
- }
197
- if (std::is_same<T, float >::value) {
174
+ } else if constexpr (std::is_same_v<T, float >) {
198
175
return dev_prop->native_vector_width_float ;
199
- }
200
- if (std::is_same<T, double >::value) {
176
+ } else if constexpr (std::is_same_v<T, double >) {
201
177
return dev_prop->native_vector_width_double ;
202
- }
203
- if (std::is_same<T, ::sycl::half>::value) {
178
+ } else if constexpr (std::is_same_v<T, ::sycl::half>) {
204
179
return dev_prop->native_vector_width_half ;
180
+ } else {
181
+ throw std::invalid_argument (
182
+ " Invalid data type to fetch native vector width!" );
205
183
}
206
- throw std::invalid_argument (
207
- " Invalid data type to fetch native vector width!" );
208
184
}
209
185
210
186
// Whether the device supports double-precision (fp64) arithmetic.
static inline bool syclHasFloat64(
    at::DeviceIndex dev_id = at::xpu::current_device()) {
  return at::xpu::getDeviceProperties(dev_id)->has_fp64;
}
0 commit comments