3
3
4
4
"""Check GPU Health during training."""
5
5
import logging
6
+ import os
6
7
from collections import deque
7
8
from datetime import datetime
8
9
from typing import List , Optional , Tuple
9
10
10
- import torch
11
-
12
- try :
13
- import pynvml
14
- except ImportError :
15
- pynvml = None
16
-
17
- import os
18
-
19
11
import numpy as np
20
- from slack_sdk . webhook import WebhookClient
12
+ import torch
21
13
22
14
from composer .core import Callback , State
23
15
from composer .core .time import Timestamp
24
16
from composer .loggers import Logger
25
- from composer .utils import dist
17
+ from composer .utils import MissingConditionalImportError , dist
26
18
27
19
log = logging .getLogger (__name__ )
28
20
@@ -69,6 +61,14 @@ def __init__(
69
61
if not self .slack_webhook_url :
70
62
self .slack_webhook_url = os .environ .get ('SLACK_WEBHOOK_URL' , None )
71
63
64
+ if self .slack_webhook_url :
65
+ # fail fast if missing import
66
+ try :
67
+ import slack_sdk
68
+ del slack_sdk
69
+ except ImportError as e :
70
+ raise MissingConditionalImportError ('health_checker' , 'slack_sdk' , None ) from e
71
+
72
72
self .last_sample = 0
73
73
self .last_check = 0
74
74
@@ -133,6 +133,7 @@ def _alert(self, message: str, state: State) -> None:
133
133
134
134
logging .warning (message )
135
135
if self .slack_webhook_url :
136
+ from slack_sdk .webhook import WebhookClient
136
137
client = WebhookClient (url = self .slack_webhook_url )
137
138
client .send (text = message )
138
139
@@ -141,12 +142,13 @@ def _is_available() -> bool:
141
142
if not torch .cuda .is_available ():
142
143
return False
143
144
try :
145
+ import pynvml
144
146
pynvml .nvmlInit () # type: ignore
145
147
return True
148
+ except ImportError :
149
+ raise MissingConditionalImportError ('health_checker' , 'pynvml' , None )
146
150
except pynvml .NVMLError_LibraryNotFound : # type: ignore
147
151
logging .warning ('NVML not found, disabling GPU health checking' )
148
- except ImportError :
149
- logging .warning ('pynvml library not found, disabling GPU health checking.' )
150
152
except Exception as e :
151
153
logging .warning (f'Error initializing NVML: { e } ' )
152
154
@@ -168,13 +170,18 @@ def sample(self) -> None:
168
170
self .samples .append (sample )
169
171
170
172
def _sample (self ) -> Optional [List ]:
173
+ try :
174
+ import pynvml
175
+ except ImportError :
176
+ raise MissingConditionalImportError ('health_checker' , 'pynvml' , None )
177
+
171
178
try :
172
179
samples = []
173
- device_count = pynvml .nvmlDeviceGetCount () # type: ignore
180
+ device_count = pynvml .nvmlDeviceGetCount ()
174
181
for i in range (device_count ):
175
- handle = pynvml .nvmlDeviceGetHandleByIndex (i ) # type: ignore
176
- samples .append (pynvml .nvmlDeviceGetUtilizationRates (handle ).gpu ) # type: ignore
177
- except pynvml .NVMLError : # type: ignore
182
+ handle = pynvml .nvmlDeviceGetHandleByIndex (i )
183
+ samples .append (pynvml .nvmlDeviceGetUtilizationRates (handle ).gpu )
184
+ except pynvml .NVMLError :
178
185
return None
179
186
return samples
180
187
0 commit comments