@@ -99,7 +99,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
     }
     _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());

-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
       cudaStream_t stream;
       cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
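Editorial note on this and the following hunks: at::cuda::getDefaultCUDAStream() always returns the device's legacy default stream, while at::cuda::getCurrentCUDAStream() returns whichever stream PyTorch is currently dispatching work to (for example a stream set from Python with torch.cuda.stream(...)). Anchoring the userbuffer event fences on the current stream keeps the overlap kernels ordered against the caller's real work. Below is a minimal, self-contained sketch of the idea; fence_side_stream_on_torch and its parameters are illustrative names, not members of this file, and error checking is omitted for brevity.

// Editorial sketch, not part of the PR: order a side stream after all work
// PyTorch has already enqueued, whichever stream that work went to.
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>

void fence_side_stream_on_torch(cudaStream_t side_stream, cudaEvent_t fence) {
  // The *current* stream is the one torch ops are being launched on right now;
  // the *default* stream may be idle if the caller used torch.cuda.stream(...).
  at::cuda::CUDAStream torch_stream = at::cuda::getCurrentCUDAStream();
  cudaEventRecord(fence, (cudaStream_t)torch_stream);  // mark "inputs ready"
  cudaStreamWaitEvent(side_stream, fence, 0);          // side stream waits for it
}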
@@ -596,7 +596,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       ubuf_byte_ptr += ubuf_chunk_bytes;
     }

-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     for (int i = 0; i < std::min(num_max_streams, tp_size); i++) {
       cudaStream_t stream;
       cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
@@ -691,7 +691,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     assert(pre_gelu_out.numel() == 0);

     // Catch up the default torch stream
-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
@@ -974,7 +974,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       B_scale_inverse = B_scale_inverse[B_fp8_tensor];

     // Catch up the main stream
-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));

@@ -1055,8 +1055,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       B_scale_inverse = B_scale_inverse[B_fp8_tensor];

     // Catch up the main stream
-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
     for (int i = 0; i < _stream_compute.size(); i++) {
       CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
     }
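Editorial note: the two added waits complete the "fork" half of a fork/join pattern. One event (_start_compute) is recorded on the caller's current stream, and every helper stream (send, recv, and each compute stream) waits on it before launching anything, so the helpers cannot read inputs the main stream has not finished producing. A generic sketch of that fan-out, with illustrative names only and error checks omitted:

// Editorial sketch, not part of the PR: fan a single "inputs ready" event out
// to several worker streams.
#include <cuda_runtime.h>
#include <vector>

void fork_from(cudaStream_t producer, cudaEvent_t start_event,
               const std::vector<cudaStream_t> &workers) {
  cudaEventRecord(start_event, producer);    // capture work enqueued so far
  for (cudaStream_t w : workers) {
    cudaStreamWaitEvent(w, start_event, 0);  // each worker waits for it
  }
}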
@@ -1113,13 +1115,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
           reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
       torch::sum_out(rs_output, reduce_buf, 0);
     }
+    for (int i = 0; i < _stream_compute.size(); i++) {
+      CHECK_CUDA(
+          cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
+      CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
+    }
+    CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
   }

   /*
   ** Copy input to _ubufs[0]
   */
   void copy_input_to_ubuf(torch::Tensor input, bool chunk) {
-    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
     if (chunk) {
       // Copy input to the target ubuf chunk by rank offset
       if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
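Editorial note: the added _stop_compute/_stop_send block is the matching "join": each compute stream records _stop_compute and the main stream waits on it, then the send stream records _stop_send and the main stream waits again, so anything the caller enqueues next on the current stream runs only after every overlapped chunk and the final send have finished. A generic sketch of that fan-in, again with illustrative names and error checks omitted:

// Editorial sketch, not part of the PR: join worker streams back into the
// consumer stream.
#include <cuda_runtime.h>
#include <vector>

void join_into(cudaStream_t consumer, cudaEvent_t stop_event,
               const std::vector<cudaStream_t> &workers) {
  for (cudaStream_t w : workers) {
    // Re-recording one event per worker is valid: cudaStreamWaitEvent latches
    // whatever the most recent cudaEventRecord captured when it is called.
    cudaEventRecord(stop_event, w);
    cudaStreamWaitEvent(consumer, stop_event, 0);
  }
}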