84
84
from torch .profiler import ProfilerActivity , profile , record_function
85
85
from torch .utils ._contextlib import _DecoratorContextManager
86
86
from triton .runtime .cache import FileCacheManager
87
- from video_reader import PyVideoReader
88
87
89
88
logger = logging .getLogger (__name__ )
90
89
@@ -758,17 +757,24 @@ def load_image(
758
757
759
758
def load_video (video_file : Union [str , bytes ], use_gpu : bool = True ):
760
759
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
761
- from video_reader import PyVideoReader
760
+ from decord import VideoReader , cpu , gpu
761
+
762
+ try :
763
+ from decord .bridge import decord_bridge
764
+
765
+ ctx = gpu (0 )
766
+ _ = decord_bridge .get_ctx_device (ctx )
767
+ except Exception :
768
+ ctx = cpu (0 )
762
769
763
- device = "cuda" if use_gpu and torch .cuda .is_available () else None
764
770
tmp_file = None
765
771
vr = None
766
772
try :
767
773
if isinstance (video_file , bytes ):
768
774
tmp_file = tempfile .NamedTemporaryFile (delete = False , suffix = ".mp4" )
769
775
tmp_file .write (video_file )
770
776
tmp_file .close ()
771
- vr = PyVideoReader (tmp_file .name , device = device , threads = 0 )
777
+ vr = VideoReader (tmp_file .name , ctx = ctx )
772
778
elif isinstance (video_file , str ):
773
779
if video_file .startswith (("http://" , "https://" )):
774
780
timeout = int (os .getenv ("REQUEST_TIMEOUT" , "10" ))
@@ -778,22 +784,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
778
784
for chunk in response .iter_content (chunk_size = 8192 ):
779
785
tmp_file .write (chunk )
780
786
tmp_file .close ()
781
- vr = PyVideoReader (tmp_file .name , device = device , threads = 0 )
787
+ vr = VideoReader (tmp_file .name , ctx = ctx )
782
788
elif video_file .startswith ("data:" ):
783
789
_ , encoded = video_file .split ("," , 1 )
784
790
video_bytes = base64 .b64decode (encoded )
785
791
tmp_file = tempfile .NamedTemporaryFile (delete = False , suffix = ".mp4" )
786
792
tmp_file .write (video_bytes )
787
793
tmp_file .close ()
788
- vr = PyVideoReader (tmp_file .name , device = device , threads = 0 )
794
+ vr = VideoReader (tmp_file .name , ctx = ctx )
789
795
elif os .path .isfile (video_file ):
790
- vr = PyVideoReader (video_file , device = device , threads = 0 )
796
+ vr = VideoReader (video_file , ctx = ctx )
791
797
else :
792
798
video_bytes = base64 .b64decode (video_file )
793
799
tmp_file = tempfile .NamedTemporaryFile (delete = False , suffix = ".mp4" )
794
800
tmp_file .write (video_bytes )
795
801
tmp_file .close ()
796
- vr = PyVideoReader (tmp_file .name , device = device , threads = 0 )
802
+ vr = VideoReader (tmp_file .name , ctx = ctx )
797
803
else :
798
804
raise ValueError (f"Unsupported video input type: { type (video_file )} " )
799
805
0 commit comments