File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed
composer/algorithms/gated_linear_units Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -100,17 +100,18 @@ def apply_gated_linear_units(model: torch.nn.Module,
100
100
101
101
# get the activation functions used
102
102
act_fns = {module .intermediate_act_fn for module in intermediate_modules }
103
- if len (act_fns ) == 0 :
103
+ num_act_fns = len ({type (act_fn ) for act_fn in act_fns })
104
+ if num_act_fns == 0 :
104
105
raise ValueError ('Tried to get the activation function from the model, but none were found. '
105
106
'Please specify `act_fn` manually to use Gated Linear Units.' )
106
- elif len ( act_fns ) > 1 :
107
+ elif num_act_fns > 1 :
107
108
raise ValueError ('Tried to get the activation function from the model, but multiple different '
108
109
'functions are used. This is currently unsupported with Gated Linear Units. '
109
110
'Please either use one activation function in BertIntermediate modules or '
110
111
'specify `act_fn` to manually override activation functions.' )
111
112
112
- # since our set is of 1, let's extract the only activation function remaining.
113
- ( act_fn ,) = act_fns #type: ignore will fail below if None
113
+ # since our set has exactly one element, extract the activation function
114
+ act_fn = next ( iter ( act_fns )) # type: ignore will fail below if None
114
115
115
116
if act_fn is None :
116
117
raise ValueError (
You can’t perform that action at this time.
0 commit comments