
Commit 5b73515

dskhudia authored and ravi-mosaicml committed
Automatic Stochastic depth on residual blocks (#1253)
1 parent 35e9e0a commit 5b73515

3 files changed: +193 −15 lines changed


composer/algorithms/stochastic_depth/stochastic_layers.py

Lines changed: 40 additions & 0 deletions
@@ -3,9 +3,15 @@
 
 """Stochastic forward functions for ResNet Bottleneck modules."""
 
+from typing import Optional
+
 import torch
+import torch.nn as nn
+from torch.fx import GraphModule
 from torchvision.models.resnet import Bottleneck
 
+__all__ = ['make_resnet_bottleneck_stochastic', 'BlockStochasticModule']
+
 
 def block_stochastic_forward(self, x):
     """ResNet Bottleneck forward function where the layers are randomly
@@ -101,3 +107,37 @@ def make_resnet_bottleneck_stochastic(module: Bottleneck, module_index: int, mod
     module.forward = stochastic_func.__get__(module)  # Bind new forward function to ResNet Bottleneck Module
 
     return module
+
+
+class BlockStochasticModule(nn.Module):
+    """A convenience class that stochastically executes the provided main path of a residual block.
+
+    Args:
+        main (GraphModule): Operators in the main (non-residual) path of a residual block.
+        residual (GraphModule | None): Operators, if any, in the residual path of a residual block.
+        drop_rate: The base probability of dropping this layer. Must be between 0.0 (inclusive) and 1.0 (inclusive).
+
+    Returns:
+        BlockStochasticModule: An instance of :class:`.BlockStochasticModule`.
+    """
+
+    def __init__(self, main: GraphModule, residual: Optional[GraphModule] = None, drop_rate: float = 0.2):
+        super().__init__()
+        self.drop_rate = torch.tensor(drop_rate)
+        self.main = main
+        self.residual = residual
+
+    def forward(self, x):
+        sample = (not self.training) or bool(torch.bernoulli(1 - self.drop_rate))
+        # main side is the non-residual connection
+        residual_result = x
+        # residual side may or may not have any operations
+        if self.residual:
+            residual_result = self.residual(x)
+
+        if sample:
+            main_result = self.main(x)
+            if not self.training:
+                main_result = main_result * (1 - self.drop_rate)
+            residual_result = torch.add(main_result, residual_result)
+        return residual_result
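For orientation, a minimal sketch (not part of this commit) of how the new BlockStochasticModule behaves when wrapped around toy branches; the toy convolutions, shapes, and drop rate below are illustrative only:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from composer.algorithms.stochastic_depth.stochastic_layers import BlockStochasticModule

# Trace small main/residual branches into GraphModules, as the class expects.
main = symbolic_trace(nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU()))
residual = symbolic_trace(nn.Conv2d(4, 4, 1))  # optional shortcut projection

block = BlockStochasticModule(main, residual, drop_rate=0.5)
x = torch.randn(2, 4, 8, 8)

block.train()
out_train = block(x)  # main branch is skipped with probability drop_rate

block.eval()
out_eval = block(x)   # main branch always runs, scaled by (1 - drop_rate)

At train time the main branch is dropped at random; at eval time it always executes, but its output is scaled by 1 - drop_rate so that its expected contribution matches training.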

composer/utils/fx_utils.py

Lines changed: 126 additions & 14 deletions
@@ -8,18 +8,20 @@
 
 import logging
 import operator
-from typing import Any, Callable, Dict, List, Mapping, Tuple, Union
+import re
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from torch.fx import Node
-from torch.fx.graph_module import GraphModule
+from torch.fx import GraphModule, Node
+from torch.fx.passes.split_utils import split_by_tags
 
+from composer.algorithms.stochastic_depth.stochastic_layers import BlockStochasticModule
 from composer.utils import ensure_tuple
 
 log = logging.getLogger(__name__)
 
-__all__ = ['count_op_instances', 'replace_op', 'fuse_parallel_linears']
+__all__ = ['count_op_instances', 'replace_op', 'fuse_parallel_linears', 'apply_stochastic_residual']
 
 
 def count_op_instances(gm: GraphModule, ops: Union[Callable, str, List[Union[Callable, str]]]) -> int:
@@ -111,28 +113,138 @@ def replace_op(gm: GraphModule, src_ops: Union[Callable, str, List[Union[Callabl
     return gm
 
 
-def detect_residual_pattern(gm: GraphModule):
-    """Search and replace the pattern with another.
+def _get_ancestors(node: Node) -> List[Node]:
+    ancestorNodes = []
+    while node.op != 'placeholder':
+        ancestorNodes.append(node)
+        node = node.all_input_nodes[0]
+    return ancestorNodes
+
+
+def _get_residual_block_nodes(nodeLHS: Node, nodeRHS: Node) -> Tuple[List[Node], List[Node]]:
+    """Walk backwards from nodeLHS and nodeRHS to the root and construct lists of their parents.
 
     Arguments:
-        gm (GraphModule): The source FX-traced graph.
+        nodeLHS (Node): left-hand side node for a binary operator
+        nodeRHS (Node): right-hand side node for a binary operator
 
     Returns:
-        GraphModule: Modified GraphModule.
+        (lhsAncestors, rhsAncestors): Two lists of nodes containing ancestors for ``nodeLHS`` and ``nodeRHS`` with
+            their common ancestors removed.
     """
-    raise NotImplementedError('detect_residual_pattern is currently not implemented.')
+    lhsAncestors = _get_ancestors(nodeLHS)
+    rhsAncestors = _get_ancestors(nodeRHS)
+
+    # Iterate from back and eliminate common nodes
+    while lhsAncestors and rhsAncestors and lhsAncestors[-1] == rhsAncestors[-1]:
+        lhsAncestors.pop()
+        rhsAncestors.pop()
+    lhsAncestors.reverse()
+    rhsAncestors.reverse()
+    return lhsAncestors, rhsAncestors
 
 
-def replace_residual_with_stochastic(gm: GraphModule):
-    """Replaces residual pattern with their stoachstic equivalent.
+def _attach_tag(nodes: List[Node], tag: str):
+    """Attach tag to the given nodes for the splitter."""
+    for node in nodes:
+        node.tag = tag  # type: ignore[attr-defined]
+
+
+def _tag_residual_nodes(gm: GraphModule) -> Tuple[List[str], int]:
+    """Tag nodes for splitting."""
+    # all nodes that are not a part of the residual blocks are tagged with "mainN_{count}".
+    # a tag is required for all nodes by split_by_tags
+    # Also an earlier tag can be repeated for later nodes.
+    count = 0
+    all_tags = []
+    # In this pass over all nodes, we just tag them
+    for node in gm.graph.nodes:
+        default_tag = f'mainN_{count}'
+        node.tag = default_tag
+        if default_tag not in all_tags:
+            all_tags.append(default_tag)
+        if node.op == 'call_function' and node.target in [torch.add, operator.add]:
+            assert len(node.all_input_nodes) == 2
+            node0, node1 = node.all_input_nodes[0], node.all_input_nodes[1]
+            lhs_nodes, rhs_nodes = _get_residual_block_nodes(node0, node1)
+            if lhs_nodes or rhs_nodes:
+                if len(lhs_nodes):
+                    _attach_tag(lhs_nodes, f'non_res_{count}')
+                    all_tags.append(f'non_res_{count}')
+                if len(rhs_nodes):
+                    _attach_tag(rhs_nodes, f'residual_{count}')
+                    all_tags.append(f'residual_{count}')
+                add_tag = f'addN_{count}'
+                if add_tag not in all_tags:
+                    all_tags.append(add_tag)
+                node.tag = add_tag
+                count += 1
+    return all_tags, count
+
+
+def _get_residual_modules(gm: GraphModule, node: Node) -> Tuple[Optional[GraphModule], Optional[GraphModule], int]:
+    """Returns GraphModules for the main and residual branches.
+
+    node.op is assumed to be a call_module
+    """
+    pattern = re.compile(r'non_res_(\d+)|residual_(\d+)')
+    matches = pattern.match(str(node.target))
+    if matches:
+        idx = int(matches[1]) if matches[1] else int(matches[2])
+        main_submod = getattr(gm, f'non_res_{idx}')
+        residual_submod = getattr(gm, f'residual_{idx}', None)
+        return main_submod, residual_submod, idx
+    else:
+        return None, None, 0
+
+
+def _replace_residual_pattern(gm: GraphModule,
+                              original_node: Node,
+                              replacement_module: str,
+                              has_residual_ops: bool = False) -> None:
+    """Replaces main, residual and add_node with the ``replacement_module``.
+
+    ``replacement_module`` is already added to the gm.
+    """
+    insert_node = original_node.prev
+    add_node = original_node.next
+    if has_residual_ops:
+        add_node = original_node.next.next
+    with gm.graph.inserting_after(insert_node):
+        new_node = gm.graph.call_module(replacement_module, args=(insert_node,))  # type: ignore
+        add_node.replace_all_uses_with(new_node)
+        gm.graph.erase_node(add_node)
+        if has_residual_ops:
+            gm.graph.erase_node(original_node.next)
+        gm.graph.erase_node(original_node)
+    gm.graph.lint()
+
+
+def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[GraphModule, int]:
+    """Detect and replace residual patterns with their stochastic equivalent.
 
     Arguments:
-        gm (GraphModule): The source FX-traced graph.
+        gm (GraphModule): The source FX-traced graph. It can be the whole model symbolically traced.
 
     Returns:
-        GraphModule: Modified GraphModule.
+        GraphModule: Modified GraphModule that has stochastic residual connections.
     """
-    raise NotImplementedError('replace_residual_with_stochastic is currently not implemented.')
+    if not isinstance(gm, GraphModule):
+        raise ValueError(
+            f'Input to apply_stochastic_residual should be an instance of GraphModule. Received {type(gm)}')
+    all_tags, count = _tag_residual_nodes(gm)
+    split_gm = split_by_tags(gm, all_tags)
+    for node in split_gm.graph.nodes:
+        if node.op != 'call_module':
+            continue
+
+        main_submod, residual_submod, idx = _get_residual_modules(split_gm, node)
+        if main_submod:
+            residual_st_instance = BlockStochasticModule(main_submod, residual_submod, drop_rate)
+            split_gm.add_submodule(f'resi_st_{idx}', residual_st_instance)  # type: ignore
+            _replace_residual_pattern(split_gm, node, f'resi_st_{idx}', residual_submod is not None)
+    split_gm.recompile()
+    return split_gm, count
 
 
 def _can_linears_be_fused(linear_nodes: List[Node], all_modules: Mapping[str, nn.Module]) -> bool:
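Taken together, these helpers tag every residual add in a traced graph, split the graph into main/residual/add submodules with split_by_tags, and then swap each group for a BlockStochasticModule. A minimal usage sketch (not part of this commit; resnet50 and its block count are illustrative assumptions) could look like:

import torch
from torch.fx import symbolic_trace
from torchvision.models import resnet50

from composer.utils.fx_utils import apply_stochastic_residual

traced = symbolic_trace(resnet50())

# resnet50 stacks 3 + 4 + 6 + 3 = 16 Bottleneck blocks, so 16 residual
# patterns would be expected to be detected and replaced.
stochastic_model, residual_count = apply_stochastic_residual(traced, drop_rate=0.2)

stochastic_model.train()
out = stochastic_model(torch.randn(2, 3, 224, 224))  # main paths dropped at random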

tests/utils/test_fx_utils.py

Lines changed: 27 additions & 1 deletion
@@ -8,8 +8,9 @@
 from torch import nn
 from torch.fx import symbolic_trace
 from torch.fx.graph_module import GraphModule
+from torchvision import models
 
-from composer.utils.fx_utils import count_op_instances, fuse_parallel_linears, replace_op
+from composer.utils.fx_utils import apply_stochastic_residual, count_op_instances, fuse_parallel_linears, replace_op
 
 
 class MyTestModel(nn.Module):
@@ -153,3 +154,28 @@ def test_fuse_parallel_linears(model_cls, before_count, after_count):
     fuse_parallel_linears(traced)
 
     assert count_op_instances(traced, nn.Linear) == after_count
+
+
+@pytest.mark.parametrize(
+    'model_cls, block_count',
+    [(models.resnet18, 8)],
+)
+@pytest.mark.filterwarnings(
+    r'ignore:Attempted to insert a call_module Node with no underlying reference in the owning GraphModule!.*:UserWarning'
+)
+@pytest.mark.timeout(15)
+def test_stochastic_depth(model_cls, block_count):
+    model = model_cls()
+    traced = symbolic_trace(model)
+
+    assert isinstance(traced, GraphModule)
+
+    inp = torch.randn(1, 3, 224, 224)
+
+    traced_st_depth_no_drop, residual_count = apply_stochastic_residual(traced, 0.0)
+
+    out_traced = traced(inp)
+    out_traced_st_depth_no_drop = traced_st_depth_no_drop(inp)
+    assert torch.allclose(out_traced,
+                          out_traced_st_depth_no_drop), 'mismatch in outputs with 0 drop rate for stochastic modules'
+    assert residual_count == block_count
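Beyond the zero-drop-rate equivalence checked above, a rough sketch (not part of this commit) of what a nonzero drop rate does: in train() mode each block's main path is skipped at random, so repeated forward passes on the same input generally disagree, while in eval() mode every main path runs scaled by (1 - drop_rate) and the output is deterministic again.

import torch
from torch.fx import symbolic_trace
from torchvision import models

from composer.utils.fx_utils import apply_stochastic_residual

inp = torch.randn(1, 3, 224, 224)
st_model, _ = apply_stochastic_residual(symbolic_trace(models.resnet18()), 0.5)

st_model.train()
out_a, out_b = st_model(inp), st_model(inp)  # typically differ: blocks dropped independently

st_model.eval()
out_c, out_d = st_model(inp), st_model(inp)  # deterministic; main paths scaled by 0.5
assert torch.allclose(out_c, out_d)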
