diff --git a/timm/data/loader.py b/timm/data/loader.py index 313f33efe4..b6804f6fb5 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -123,10 +123,10 @@ def __init__( def __iter__(self): first = True if self.is_cuda: - stream = torch.cuda.Stream() + stream = torch.cuda.Stream(device=self.device) stream_context = partial(torch.cuda.stream, stream=stream) elif self.is_npu: - stream = torch.npu.Stream() + stream = torch.npu.Stream(device=self.device) stream_context = partial(torch.npu.stream, stream=stream) else: stream = None @@ -148,9 +148,9 @@ def __iter__(self): if stream is not None: if self.is_cuda: - torch.cuda.current_stream().wait_stream(stream) + torch.cuda.current_stream(device=self.device).wait_stream(stream) elif self.is_npu: - torch.npu.current_stream().wait_stream(stream) + torch.npu.current_stream(device=self.device).wait_stream(stream) input = next_input target = next_target diff --git a/timm/data/naflex_loader.py b/timm/data/naflex_loader.py index d615bd63f9..7c6509d311 100644 --- a/timm/data/naflex_loader.py +++ b/timm/data/naflex_loader.py @@ -91,10 +91,10 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: """ first = True if self.is_cuda: - stream = torch.cuda.Stream() + stream = torch.cuda.Stream(device=self.device) stream_context = partial(torch.cuda.stream, stream=stream) elif self.is_npu: - stream = torch.npu.Stream() + stream = torch.npu.Stream(device=self.device) stream_context = partial(torch.npu.stream, stream=stream) else: stream = None @@ -152,9 +152,9 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]: if stream is not None: if self.is_cuda: - torch.cuda.current_stream().wait_stream(stream) + torch.cuda.current_stream(device=self.device).wait_stream(stream) elif self.is_npu: - torch.npu.current_stream().wait_stream(stream) + torch.npu.current_stream(device=self.device).wait_stream(stream) input_dict = next_input_dict target = next_target