We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3f636d6 commit 8b31a2fCopy full SHA for 8b31a2f
src/accelerate/utils/modeling.py
@@ -970,6 +970,8 @@ def get_balanced_memory(
970
expected_device_type = "xpu"
971
elif is_hpu_available():
972
expected_device_type = "hpu"
973
+ elif is_mps_available():
974
+ expected_device_type = "mps"
975
else:
976
expected_device_type = "cuda"
977
num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])
0 commit comments