
Commit 315c463

jma127 authored and facebook-github-bot committed
Add periodic CUDA cache cleanup (#882)
Summary: This adds a periodic call to `torch.cuda.empty_cache()` in order to mitigate memory fragmentation in the PyTorch CUDA caching allocator, which can cause OOMs on models approaching the GPU memory limit. By default, this will occur every 64 updates.

Performance considerations:

- I've benchmarked this on a reasonably large model with a memory footprint of 16 GB, and the overhead with the default setting is <0.2%. With `update-freq > 1`, the cost is mitigated even further.
- This behavior can be disabled with a value of zero.

Pull Request resolved: fairinternal/fairseq-py#882

Differential Revision: D17742386

Pulled By: jma127

fbshipit-source-id: 68d8f93f798d6818b5efc3d67d43b52dfb8b2865
1 parent de348d1 commit 315c463
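
A simplified sketch of the behavior described in the summary above: after every `empty_cache_freq` optimizer updates, cached blocks are released back to the CUDA driver so fragmentation does not accumulate. This is not fairseq's actual Trainer; `train_one_update` is a hypothetical placeholder, and the gating condition mirrors the trainer.py hunk shown further down.

import torch

def train_one_update():
    # hypothetical stand-in for a forward/backward/optimizer step
    pass

empty_cache_freq = 64  # 0 disables the periodic cleanup

for num_updates in range(1, 257):
    train_one_update()
    # same gating arithmetic as the trainer.py hunk below
    if (empty_cache_freq > 0
            and (num_updates + empty_cache_freq - 1) % empty_cache_freq == 0
            and torch.cuda.is_available()):
        torch.cuda.empty_cache()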

2 files changed (+10, -0)

fairseq/options.py

Lines changed: 2 additions & 0 deletions
@@ -193,6 +193,8 @@ def get_parser(desc, default_task='translation'):
                         help='threshold FP16 loss scale from below')
     parser.add_argument('--user-dir', default=None,
                         help='path to a python module containing custom extensions (tasks and/or architectures)')
+    parser.add_argument('--empty-cache-freq', default=0, type=int,
+                        help='how often to clear the PyTorch CUDA cache (0 to disable)')

     from fairseq.registry import REGISTRIES
     for registry_name, REGISTRY in REGISTRIES.items():
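
A minimal standalone sketch (not fairseq's full option parser) of how the new flag behaves once parsed: the default of 0 leaves the periodic cache clearing disabled, and any positive integer sets the clearing frequency.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--empty-cache-freq', default=0, type=int,
                    help='how often to clear the PyTorch CUDA cache (0 to disable)')

args = parser.parse_args(['--empty-cache-freq', '64'])
print(args.empty_cache_freq)                   # 64
print(parser.parse_args([]).empty_cache_freq)  # 0, i.e. disabled by default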

fairseq/trainer.py

Lines changed: 8 additions & 0 deletions
@@ -426,6 +426,14 @@ def maybe_no_sync():

             if 'nll_loss' in logging_output:
                 self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
+
+            # clear CUDA cache to reduce memory fragmentation
+            if (self.args.empty_cache_freq > 0 and
+                    ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
+                     self.args.empty_cache_freq) == 0 and
+                    torch.cuda.is_available() and
+                    not self.args.cpu):
+                torch.cuda.empty_cache()
         except OverflowError as e:
             print('| WARNING: overflow detected, ' + str(e))
             self.zero_grad()
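
The modulo test above uses an offset of `empty_cache_freq - 1`, so (assuming `get_num_updates()` has already been incremented to 1 by the first update) the cache is cleared on updates 1, 65, 129, ... rather than only after the 64th update. A quick standalone check of that schedule, independent of fairseq:

def fires(num_updates, freq):
    # same arithmetic as the condition in the hunk above
    return freq > 0 and (num_updates + freq - 1) % freq == 0

print([n for n in range(1, 200) if fires(n, 64)])  # [1, 65, 129, 193]
print([n for n in range(1, 6) if fires(n, 1)])     # [1, 2, 3, 4, 5] -> every update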
