@@ -100,17 +100,18 @@ def apply_gated_linear_units(model: torch.nn.Module,
 
     # get the activation functions used
     act_fns = {module.intermediate_act_fn for module in intermediate_modules}
-    if len(act_fns) == 0:
+    num_act_fns = len({type(act_fn) for act_fn in act_fns})
+    if num_act_fns == 0:
         raise ValueError('Tried to get the activation function from the model, but none were found. '
                          'Please specify `act_fn` manually to use Gated Linear Units.')
-    elif len(act_fns) > 1:
+    elif num_act_fns > 1:
         raise ValueError('Tried to get the activation function from the model, but multiple different '
                          'functions are used. This is currently unsupported with Gated Linear Units. '
                          'Please either use one activation function in BertIntermediate modules or '
                          'specify `act_fn` to manually override activation functions.')
 
-    # since our set is of 1, let's extract the only activation function remaining.
-    (act_fn,) = act_fns  #type: ignore will fail below if None
+    # since our set is of 1, let's extract the activation function
+    act_fn = next(iter(act_fns))  # type: ignore will fail below if None
 
     if act_fn is None:
         raise ValueError(
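The change counts distinct activation function *types* rather than distinct instances, then takes an arbitrary element from the set instead of tuple-unpacking it. A minimal standalone sketch of why the instance-based check broke (not from the PR; torch.nn.GELU stands in for whatever activation the BertIntermediate modules actually hold):

import torch

# Hypothetical setup: two layers that each construct their own activation
# module, as HuggingFace BertIntermediate layers can.
act_a = torch.nn.GELU()
act_b = torch.nn.GELU()

# nn.Module hashes by object identity, so the set keeps both instances.
act_fns = {act_a, act_b}
print(len(act_fns))  # 2, even though only one kind of activation is in use

# Counting by type collapses equivalent instances, matching the fixed check.
num_act_fns = len({type(act_fn) for act_fn in act_fns})
print(num_act_fns)  # 1

# Tuple unpacking `(act_fn,) = act_fns` would raise ValueError here because
# the set still holds two elements; taking an arbitrary one works instead.
act_fn = next(iter(act_fns))

Under the old code, the size-2 instance set tripped the "multiple different functions" error (or the single-element unpack) even when every module used the same activation; counting types keeps the zero- and multiple-activation error paths intact while tolerating duplicate instances.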