Skip to content
10 changes: 10 additions & 0 deletions keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ def test_application_pooling(self, app, last_dim):
)
self.assertShapeEqual(output_shape, (None, last_dim))

@parameterized.parameters(MODEL_LIST)
def test_application_classifier_activation(self, app, _):
    """Check that `classifier_activation` is applied to the last layer.

    Each application is built with its classification top and an explicit
    softmax activation; the activation attached to the final layer is then
    compared by name.
    """
    # Skipped per the message below: RegNet models do not support the
    # classifier activation argument.
    if "RegNet" in app.__name__:
        self.skipTest("RegNet models do not support classifier activation")

    build_kwargs = {
        "weights": None,
        "include_top": True,
        "classifier_activation": "softmax",
    }
    top_layer = app(**build_kwargs).layers[-1]
    self.assertEqual(top_layer.activation.__name__, "softmax")

@parameterized.parameters(*MODEL_LIST_NO_NASNET)
def test_application_variable_input_channels(self, app, last_dim):
if backend.image_data_format() == "channels_first":
Expand Down
15 changes: 12 additions & 3 deletions keras/applications/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,12 @@ def apply(x):
return apply


def Head(num_classes=1000, name=None):
def Head(num_classes=1000, classifier_activation=None, name=None):
"""Implementation of classification head of ConvNeXt.

Args:
num_classes: number of classes for Dense layer
classifier_activation: activation function for the Dense layer
name: name prefix

Returns:
Expand All @@ -342,7 +343,11 @@ def apply(x):
x = layers.LayerNormalization(
epsilon=1e-6, name=name + "_head_layernorm"
)(x)
x = layers.Dense(num_classes, name=name + "_head_dense")(x)
x = layers.Dense(
num_classes,
activation=classifier_activation,
name=name + "_head_dense",
)(x)
return x

return apply
Expand Down Expand Up @@ -522,8 +527,12 @@ def ConvNeXt(
cur += depths[i]

if include_top:
x = Head(num_classes=classes, name=model_name)(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = Head(
num_classes=classes,
classifier_activation=classifier_activation,
name=model_name,
)(x)

else:
if pooling == "avg":
Expand Down