Skip to content

Commit 0757149

Browse files
vasunvidia and ksivaman authored
Fix ring_exchange RS to support CUDA graph capture (#811)
Signed-off-by: Vasudevan Rengasamy <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 816dd45 commit 0757149

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

transformer_engine/pytorch/csrc/comm_gemm_overlap.h

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
9999
}
100100
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
101101

102-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
102+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
103103
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
104104
cudaStream_t stream;
105105
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
@@ -596,7 +596,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
596596
ubuf_byte_ptr += ubuf_chunk_bytes;
597597
}
598598

599-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
599+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
600600
for (int i = 0; i < std::min(num_max_streams, tp_size); i++) {
601601
cudaStream_t stream;
602602
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
@@ -691,7 +691,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
691691
assert(pre_gelu_out.numel() == 0);
692692

693693
// Catch up the default torch stream
694-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
694+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
695695
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
696696
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
697697
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
@@ -974,7 +974,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
974974
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
975975

976976
// Catch up the main stream
977-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
977+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
978978
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
979979
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
980980

@@ -1055,8 +1055,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
10551055
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
10561056

10571057
// Catch up the main stream
1058-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
1058+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
10591059
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
1060+
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
1061+
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
10601062
for (int i = 0; i < _stream_compute.size(); i++) {
10611063
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0));
10621064
}
@@ -1113,13 +1115,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
11131115
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
11141116
torch::sum_out(rs_output, reduce_buf, 0);
11151117
}
1118+
for (int i = 0; i < _stream_compute.size(); i++) {
1119+
CHECK_CUDA(
1120+
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
1121+
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
1122+
}
1123+
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
1124+
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
11161125
}
11171126

11181127
/*
11191128
** Copy input to _ubufs[0]
11201129
*/
11211130
void copy_input_to_ubuf(torch::Tensor input, bool chunk) {
1122-
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
1131+
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
11231132
if (chunk) {
11241133
// Copy input to the target ubuf chunk by rank offset
11251134
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {

0 commit comments

Comments (0)