Skip to content

Commit 41239d0

Browse files
authored
Update glu check (#1689)
* update glu check * separate out type check * fix test * update docs * type ignore
1 parent 6f4bbbf commit 41239d0

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

composer/algorithms/gated_linear_units/gated_linear_units.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -100,17 +100,18 @@ def apply_gated_linear_units(model: torch.nn.Module,
100100

101101
# get the activation functions used
102102
act_fns = {module.intermediate_act_fn for module in intermediate_modules}
103-
if len(act_fns) == 0:
103+
num_act_fns = len({type(act_fn) for act_fn in act_fns})
104+
if num_act_fns == 0:
104105
raise ValueError('Tried to get the activation function from the model, but none were found. '
105106
'Please specify `act_fn` manually to use Gated Linear Units.')
106-
elif len(act_fns) > 1:
107+
elif num_act_fns > 1:
107108
raise ValueError('Tried to get the activation function from the model, but multiple different '
108109
'functions are used. This is currently unsupported with Gated Linear Units. '
109110
'Please either use one activation function in BertIntermediate modules or '
110111
'specify `act_fn` to manually override activation functions.')
111112

112-
# since our set is of 1, let's extract the only activation function remaining.
113-
(act_fn,) = act_fns #type: ignore will fail below if None
113+
# since our set is of size 1, let's extract the activation function
114+
act_fn = next(iter(act_fns)) # type: ignore will fail below if None
114115

115116
if act_fn is None:
116117
raise ValueError(

0 commit comments

Comments (0)