Skip to content

Commit 8b31a2f

Browse files
author
Ihar Hrachyshka
authored
Fix get_balanced_memory for MPS (#3464)
This also fixes a failure in test_get_balanced_memory: ``` assert {0: 215, 1: 300} == {0: 300, 1: 300} [...] tests/test_modeling_utils.py:871: AssertionError ``` Signed-off-by: Ihar Hrachyshka <[email protected]>
1 parent 3f636d6 commit 8b31a2f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/accelerate/utils/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,8 @@ def get_balanced_memory(
970970
expected_device_type = "xpu"
971971
elif is_hpu_available():
972972
expected_device_type = "hpu"
973+
elif is_mps_available():
974+
expected_device_type = "mps"
973975
else:
974976
expected_device_type = "cuda"
975977
num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])

0 commit comments

Comments
 (0)