
Commit 89c4d9f

TestLowCpuMemUsage UT get device by device_name (#6397)
Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
1 parent a7ffe54 commit 89c4d9f

File tree

1 file changed (+8 −2 lines)


tests/unit/inference/test_inference.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -298,6 +298,12 @@ def verify_injection(module):
     verify_injection(model)
 
 
+# Used to Get Device name
+def getDeviceId(local_rank):
+    device = torch.device(f"{get_accelerator().device_name(local_rank)}")
+    return device
+
+
 # Verify that test is valid
 def validate_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
     model, task = model_w_task
@@ -484,8 +490,8 @@ def test(
             pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.")
 
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
-
-        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=local_rank, framework="pt")
+        device = getDeviceId(local_rank)
+        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage": True}, device=device, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
         pipe.model = deepspeed.init_inference(pipe.model,
                                               mp_size=self.world_size,
```
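
For context, here is a minimal, self-contained sketch of what the new helper does at runtime. It assumes DeepSpeed's accelerator abstraction (`deepspeed.accelerator.get_accelerator`); the `__main__` usage block is illustrative and not part of this diff. `device_name(local_rank)` returns an accelerator-specific identifier (e.g. `"cuda:0"` on NVIDIA GPUs), so the test binds the pipeline to whatever accelerator is active instead of hard-coding a CUDA ordinal.

```python
# Sketch of the helper added in this commit, with hypothetical usage.
import os

import torch
from deepspeed.accelerator import get_accelerator


def getDeviceId(local_rank):
    # device_name(local_rank) yields a backend-qualified string such as
    # "cuda:0", so torch.device() resolves to the active accelerator
    # rather than assuming CUDA.
    device = torch.device(f"{get_accelerator().device_name(local_rank)}")
    return device


if __name__ == "__main__":
    # Hypothetical usage mirroring the test change.
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    print(getDeviceId(local_rank))  # e.g. device(type='cuda', index=0)
```

Passing a `torch.device` built this way, rather than the bare integer `local_rank` (which `transformers.pipeline` interprets as a CUDA ordinal), is what lets the `low_cpu_mem_usage` test run on non-CUDA accelerators.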
