Skip to content

Commit 3c604eb

Browse files
ksivaman authored and ptrendx committed
Avoid amax roll for non-run modules (#825)
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 9f0a4a4 commit 3c604eb

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

transformer_engine/common/recipe/delayed_scaling.cu

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -197,16 +197,18 @@ kernel_bulk(
197197
const auto last_amax = ((amax_reduction_buffer != nullptr)
198198
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ?
199199
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0];
200-
for (size_t off = 0; off < length; off += bsize) {
201-
const size_t i = off + tid;
202-
float a = 0;
203-
if (i < length) {
204-
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
205-
amax = fmaxf(amax, a);
206-
}
207-
__syncthreads(); // Inplace roll
208-
if (i < length) {
209-
amax_history[i*stride] = (i > 0) ? a : 0;
200+
if (last_amax != 0.0f) {
201+
for (size_t off = 0; off < length; off += bsize) {
202+
const size_t i = off + tid;
203+
float a = 0;
204+
if (i < length) {
205+
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
206+
amax = fmaxf(amax, a);
207+
}
208+
__syncthreads(); // Inplace roll
209+
if (i < length) {
210+
amax_history[i*stride] = (i > 0) ? a : 0;
211+
}
210212
}
211213
}
212214

0 commit comments

Comments
 (0)