Skip to content

Commit 28efacf

Browse files
authored
[MigraphX] Fix potential synchronization problem when ORT_ENABLE_STREAM is true (#22589)
### Description Replace `hipMemcpy` with `hipMemcpyWithStream` ### Motivation and Context `hipMemcpy` uses default stream, which may be out of synchronization with the current stream when ORT_ENABLE_STREAM is defined.
1 parent 7acbd51 commit 28efacf

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

onnxruntime/core/providers/migraphx/gpu_data_transfer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
5757
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
5858
} else {
5959
// copy from other CPU memory to GPU, this is blocking
60-
HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
60+
HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
6161
}
6262
} else if (src_device.Type() == OrtDevice::GPU) {
6363
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));

onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
14451445
std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
14461446
auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
14471447
void* output_data = output_tensor.GetTensorMutableRawData();
1448-
HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
1448+
HIP_CALL_THROW(hipMemcpyWithStream(output_data,
1449+
gpu_res.data(),
1450+
res_shape.bytes(),
1451+
hipMemcpyDeviceToDevice,
1452+
static_cast<hipStream_t>(rocm_stream)));
14491453
}
14501454
}
14511455
};

0 commit comments

Comments
 (0)