diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py index 607c6895ba..40efb6de04 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone.py @@ -15,11 +15,11 @@ from keras import layers from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone @keras_nlp_export("keras_nlp.models.CSPDarkNetBackbone") -class CSPDarkNetBackbone(Backbone): +class CSPDarkNetBackbone(FeaturePyramidBackbone): """This class represents Keras Backbone of CSPDarkNet model. This class implements a CSPDarkNet backbone as described in @@ -65,12 +65,15 @@ def __init__( self, stackwise_num_filters, stackwise_depth, - include_rescaling, + include_rescaling=True, block_type="basic_block", - image_shape=(224, 224, 3), + image_shape=(None, None, 3), **kwargs, ): # === Functional Model === + channel_axis = ( + -1 if keras.config.image_data_format() == "channels_last" else 1 + ) apply_ConvBlock = ( apply_darknet_conv_block_depthwise if block_type == "depthwise_block" @@ -83,15 +86,22 @@ def __init__( if include_rescaling: x = layers.Rescaling(scale=1 / 255.0)(x) - x = apply_focus(name="stem_focus")(x) + x = apply_focus(channel_axis, name="stem_focus")(x) x = apply_darknet_conv_block( - base_channels, kernel_size=3, strides=1, name="stem_conv" + base_channels, + channel_axis, + kernel_size=3, + strides=1, + name="stem_conv", )(x) + + pyramid_outputs = {} for index, (channels, depth) in enumerate( zip(stackwise_num_filters, stackwise_depth) ): x = apply_ConvBlock( channels, + channel_axis, kernel_size=3, strides=2, name=f"dark{index + 2}_conv", @@ -100,17 +110,20 @@ def __init__( if index == len(stackwise_depth) - 1: x = apply_spatial_pyramid_pooling_bottleneck( channels, + channel_axis, hidden_filters=channels // 2, 
name=f"dark{index + 2}_spp", )(x) x = apply_cross_stage_partial( channels, + channel_axis, num_bottlenecks=depth, block_type="basic_block", residual=(index != len(stackwise_depth) - 1), name=f"dark{index + 2}_csp", )(x) + pyramid_outputs[f"P{index + 2}"] = x super().__init__(inputs=image_input, outputs=x, **kwargs) @@ -120,6 +133,7 @@ def __init__( self.include_rescaling = include_rescaling self.block_type = block_type self.image_shape = image_shape + self.pyramid_outputs = pyramid_outputs def get_config(self): config = super().get_config() @@ -135,7 +149,7 @@ def get_config(self): return config -def apply_focus(name=None): +def apply_focus(channel_axis, name=None): """A block used in CSPDarknet to focus information into channels of the image. @@ -151,7 +165,7 @@ def apply_focus(name=None): """ def apply(x): - return layers.Concatenate(name=name)( + return layers.Concatenate(axis=channel_axis, name=name)( [ x[..., ::2, ::2, :], x[..., 1::2, ::2, :], @@ -164,7 +178,13 @@ def apply(x): def apply_darknet_conv_block( - filters, kernel_size, strides, use_bias=False, activation="silu", name=None + filters, + channel_axis, + kernel_size, + strides, + use_bias=False, + activation="silu", + name=None, ): """ The basic conv block used in Darknet. Applies Conv2D followed by a @@ -193,11 +213,12 @@ def apply(inputs): kernel_size, strides, padding="same", + data_format=keras.config.image_data_format(), use_bias=use_bias, name=name + "_conv", )(inputs) - x = layers.BatchNormalization(name=name + "_bn")(x) + x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x) if activation == "silu": x = layers.Lambda(lambda x: keras.activations.silu(x))(x) @@ -212,7 +233,7 @@ def apply(inputs): def apply_darknet_conv_block_depthwise( - filters, kernel_size, strides, activation="silu", name=None + filters, channel_axis, kernel_size, strides, activation="silu", name=None ): """ The depthwise conv block used in CSPDarknet. 
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise( def apply(inputs): x = layers.DepthwiseConv2D( - kernel_size, strides, padding="same", use_bias=False + kernel_size, + strides, + padding="same", + data_format=keras.config.image_data_format(), + use_bias=False, )(inputs) - x = layers.BatchNormalization()(x) + x = layers.BatchNormalization(axis=channel_axis)(x) if activation == "silu": x = layers.Lambda(lambda x: keras.activations.swish(x))(x) @@ -248,7 +273,11 @@ def apply(inputs): x = layers.LeakyReLU(0.1)(x) x = apply_darknet_conv_block( - filters, kernel_size=1, strides=1, activation=activation + filters, + channel_axis, + kernel_size=1, + strides=1, + activation=activation, )(x) return x @@ -258,6 +287,7 @@ def apply(inputs): def apply_spatial_pyramid_pooling_bottleneck( filters, + channel_axis, hidden_filters=None, kernel_sizes=(5, 9, 13), activation="silu", @@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck( def apply(x): x = apply_darknet_conv_block( hidden_filters, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -304,13 +335,15 @@ def apply(x): kernel_size, strides=1, padding="same", + data_format=keras.config.image_data_format(), name=f"{name}_maxpool_{kernel_size}", )(x[0]) ) - x = layers.Concatenate(name=f"{name}_concat")(x) + x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x) x = apply_darknet_conv_block( filters, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -324,6 +357,7 @@ def apply(x): def apply_cross_stage_partial( filters, + channel_axis, num_bottlenecks, residual=True, block_type="basic_block", @@ -361,6 +395,7 @@ def apply(inputs): x1 = apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -369,6 +404,7 @@ def apply(inputs): x2 = apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -379,6 +415,7 @@ def apply(inputs): residual_x = x1 x1 = 
apply_darknet_conv_block( hidden_channels, + channel_axis, kernel_size=1, strides=1, activation=activation, @@ -386,6 +423,7 @@ def apply(inputs): )(x1) x1 = ConvBlock( hidden_channels, + channel_axis, kernel_size=3, strides=1, activation=activation, @@ -399,6 +437,7 @@ def apply(inputs): - x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")([x1, x2]) x = apply_darknet_conv_block( filters, + channel_axis, kernel_size=1, strides=1, activation=activation, diff --git a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py index 857e06039d..ed6dc7b525 100644 --- a/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py +++ b/keras_nlp/src/models/csp_darknet/csp_darknet_backbone_test.py @@ -24,21 +24,26 @@ class CSPDarkNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "stackwise_num_filters": [32, 64, 128, 256], + "stackwise_num_filters": [2, 4, 6, 8], "stackwise_depth": [1, 3, 3, 1], - "include_rescaling": False, "block_type": "basic_block", - "image_shape": (224, 224, 3), + "image_shape": (32, 32, 3), } - self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) def test_backbone_basics(self): - self.run_backbone_test( + self.run_vision_backbone_test( cls=CSPDarkNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 7, 7, 256), + expected_output_shape=(2, 1, 1, 8), + expected_pyramid_output_keys=["P2", "P3", "P4", "P5"], + expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)], run_mixed_precision_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/densenet/densenet_backbone.py b/keras_nlp/src/models/densenet/densenet_backbone.py index 60a5b28849..13e3d8597f 100644 ---
b/keras_nlp/src/models/densenet/densenet_backbone.py @@ -14,14 +14,13 @@ import keras from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -BN_AXIS = 3 BN_EPSILON = 1.001e-5 @keras_nlp_export("keras_nlp.models.DenseNetBackbone") -class DenseNetBackbone(Backbone): +class DenseNetBackbone(FeaturePyramidBackbone): """Instantiates the DenseNet architecture. This class implements a DenseNet backbone as described in @@ -35,7 +34,7 @@ class DenseNetBackbone(Backbone): include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (None, None, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. growth_rate: int, number of filters added by each dense block, @@ -62,12 +61,14 @@ def __init__( self, stackwise_num_repeats, include_rescaling=True, - image_shape=(224, 224, 3), + image_shape=(None, None, 3), compression_ratio=0.5, growth_rate=32, **kwargs, ): # === Functional Model === + data_format = keras.config.image_data_format() + channel_axis = -1 if data_format == "channels_last" else 1 image_input = keras.layers.Input(shape=image_shape) x = image_input @@ -75,37 +76,47 @@ def __init__( x = keras.layers.Rescaling(1 / 255.0)(x) x = keras.layers.Conv2D( - 64, 7, strides=2, use_bias=False, padding="same", name="conv1_conv" + 64, + 7, + strides=2, + use_bias=False, + padding="same", + data_format=data_format, + name="conv1_conv", )(x) x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name="conv1_bn" + axis=channel_axis, epsilon=BN_EPSILON, name="conv1_bn" )(x) x = keras.layers.Activation("relu", name="conv1_relu")(x) x = keras.layers.MaxPooling2D( - 3, strides=2, padding="same", 
name="pool1" + 3, strides=2, padding="same", data_format=data_format, name="pool1" )(x) + pyramid_outputs = {} for stack_index in range(len(stackwise_num_repeats) - 1): index = stack_index + 2 x = apply_dense_block( x, + channel_axis, stackwise_num_repeats[stack_index], growth_rate, name=f"conv{index}", ) + pyramid_outputs[f"P{index}"] = x x = apply_transition_block( - x, compression_ratio, name=f"pool{index}" + x, channel_axis, compression_ratio, name=f"pool{index}" ) x = apply_dense_block( x, + channel_axis, stackwise_num_repeats[-1], growth_rate, name=f"conv{len(stackwise_num_repeats) + 1}", ) - + pyramid_outputs[f"P{len(stackwise_num_repeats) + 1}"] = x x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name="bn" + axis=channel_axis, epsilon=BN_EPSILON, name="bn" )(x) x = keras.layers.Activation("relu", name="relu")(x) @@ -117,6 +128,7 @@ def __init__( self.compression_ratio = compression_ratio self.growth_rate = growth_rate self.image_shape = image_shape + self.pyramid_outputs = pyramid_outputs def get_config(self): config = super().get_config() @@ -132,7 +144,7 @@ def get_config(self): return config -def apply_dense_block(x, num_repeats, growth_rate, name=None): +def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None): """A dense block. Args: @@ -145,11 +157,13 @@ def apply_dense_block(x, num_repeats, growth_rate, name=None): name = f"dense_block_{keras.backend.get_uid('dense_block')}" for i in range(num_repeats): - x = apply_conv_block(x, growth_rate, name=f"{name}_block_{i}") + x = apply_conv_block( + x, channel_axis, growth_rate, name=f"{name}_block_{i}" + ) return x -def apply_transition_block(x, compression_ratio, name=None): +def apply_transition_block(x, channel_axis, compression_ratio, name=None): """A transition block. Args: @@ -157,24 +171,28 @@ def apply_transition_block(x, compression_ratio, name=None): compression_ratio: float, compression rate at transition layers. name: string, block label. 
""" + data_format = keras.config.image_data_format() if name is None: name = f"transition_block_{keras.backend.get_uid('transition_block')}" x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_relu")(x) x = keras.layers.Conv2D( - int(x.shape[BN_AXIS] * compression_ratio), + int(x.shape[channel_axis] * compression_ratio), 1, use_bias=False, + data_format=data_format, name=f"{name}_conv", )(x) - x = keras.layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + x = keras.layers.AveragePooling2D( + 2, strides=2, data_format=data_format, name=f"{name}_pool" + )(x) return x -def apply_conv_block(x, growth_rate, name=None): +def apply_conv_block(x, channel_axis, growth_rate, name=None): """A building block for a dense block. Args: @@ -182,19 +200,24 @@ def apply_conv_block(x, growth_rate, name=None): growth_rate: int, number of filters added by each dense block. name: string, block label. 
""" + data_format = keras.config.image_data_format() if name is None: name = f"conv_block_{keras.backend.get_uid('conv_block')}" shortcut = x x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_0_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_0_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x) x = keras.layers.Conv2D( - 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + 4 * growth_rate, + 1, + use_bias=False, + data_format=data_format, + name=f"{name}_1_conv", )(x) x = keras.layers.BatchNormalization( - axis=BN_AXIS, epsilon=BN_EPSILON, name=f"{name}_1_bn" + axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_1_bn" )(x) x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x) x = keras.layers.Conv2D( @@ -202,9 +225,10 @@ def apply_conv_block(x, growth_rate, name=None): 3, padding="same", use_bias=False, + data_format=data_format, name=f"{name}_2_conv", )(x) - x = keras.layers.Concatenate(axis=BN_AXIS, name=f"{name}_concat")( + x = keras.layers.Concatenate(axis=channel_axis, name=f"{name}_concat")( [shortcut, x] ) return x diff --git a/keras_nlp/src/models/densenet/densenet_backbone_test.py b/keras_nlp/src/models/densenet/densenet_backbone_test.py index 63f358035c..7720411319 100644 --- a/keras_nlp/src/models/densenet/densenet_backbone_test.py +++ b/keras_nlp/src/models/densenet/densenet_backbone_test.py @@ -22,21 +22,26 @@ class DenseNetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "stackwise_num_repeats": [6, 12, 24, 16], - "include_rescaling": True, + "stackwise_num_repeats": [2, 4, 6, 4], "compression_ratio": 0.5, - "growth_rate": 32, - "image_shape": (224, 224, 3), + "growth_rate": 2, + "image_shape": (32, 32, 3), } - self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + self.input_size = 32 + self.input_data = np.ones( + (2, self.input_size, self.input_size, 3), dtype="float32" + ) def test_backbone_basics(self): - self.run_backbone_test( + 
self.run_vision_backbone_test( cls=DenseNetBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 7, 7, 1024), + expected_output_shape=(2, 1, 1, 24), + expected_pyramid_output_keys=["P2", "P3", "P4", "P5"], + expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)], run_mixed_precision_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py index 2cfe7f6761..35c5f7fd5a 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone.py @@ -37,7 +37,7 @@ def __init__( patch_sizes, strides, include_rescaling=True, - image_shape=(224, 224, 3), + image_shape=(None, None, 3), hidden_dims=None, **kwargs, ): @@ -63,7 +63,7 @@ def __init__( include_rescaling: bool, whether to rescale the inputs. If set to `True`, inputs will be passed through a `Rescaling(1/255.0)` layer. Defaults to `True`. - image_shape: optional shape tuple, defaults to (224, 224, 3). + image_shape: optional shape tuple, defaults to (None, None, 3). hidden_dims: the embedding dims per hierarchical layer, used as the levels of the feature pyramid. patch_sizes: list of integers, the patch_size to apply for each layer. 
diff --git a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py index 4f1955297f..280adca065 100644 --- a/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_nlp/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -14,7 +14,6 @@ import numpy as np import pytest -from keras import models from keras_nlp.src.models.mix_transformer.mix_transformer_backbone import ( MiTBackbone, @@ -42,30 +41,18 @@ def setUp(self): ) def test_backbone_basics(self): - self.run_backbone_test( + self.run_vision_backbone_test( cls=MiTBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, 2, 2, 8), + expected_pyramid_output_keys=["P1", "P2"], + expected_pyramid_image_sizes=[(4, 4), (2, 2)], run_quantization_check=False, run_mixed_precision_check=False, + run_data_format_check=False, ) - def test_pyramid_output_format(self): - init_kwargs = self.init_kwargs - backbone = MiTBackbone(**init_kwargs) - model = models.Model(backbone.inputs, backbone.pyramid_outputs) - output_data = model(self.input_data) - - self.assertIsInstance(output_data, dict) - self.assertEqual( - list(output_data.keys()), list(backbone.pyramid_outputs.keys()) - ) - self.assertEqual(list(output_data.keys()), ["P1", "P2"]) - for k, v in output_data.items(): - size = self.input_size // (2 ** (int(k[1:]) + 1)) - self.assertEqual(tuple(v.shape[:3]), (2, size, size)) - @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index a6a30362cd..fdf52cab73 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -14,7 +14,6 @@ import pytest from absl.testing import parameterized -from keras import models from keras import ops from 
keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone @@ -51,26 +50,10 @@ def test_backbone_basics(self, use_pre_activation, block_type): expected_output_shape=( (2, 64) if block_type == "basic_block" else (2, 256) ), + expected_pyramid_output_keys=["P2", "P3", "P4"], + expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)], ) - def test_pyramid_output_format(self): - init_kwargs = self.init_kwargs.copy() - init_kwargs.update( - {"block_type": "basic_block", "use_pre_activation": False} - ) - backbone = ResNetBackbone(**init_kwargs) - model = models.Model(backbone.inputs, backbone.pyramid_outputs) - output_data = model(self.input_data) - - self.assertIsInstance(output_data, dict) - self.assertEqual( - list(output_data.keys()), list(backbone.pyramid_outputs.keys()) - ) - self.assertEqual(list(output_data.keys()), ["P2", "P3", "P4"]) - for k, v in output_data.items(): - size = self.input_size // (2 ** int(k[1:])) - self.assertEqual(tuple(v.shape[:3]), (2, size, size)) - @parameterized.named_parameters( ("v1_basic", False, "basic_block"), ("v1_bottleneck", False, "bottleneck_block"), diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 8e63bc19d9..6dad75b84d 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -465,6 +465,8 @@ def run_vision_backbone_test( init_kwargs, input_data, expected_output_shape, + expected_pyramid_output_keys=None, + expected_pyramid_image_sizes=None, variable_length_data=None, run_mixed_precision_check=True, run_quantization_check=True, @@ -491,6 +493,25 @@ def run_vision_backbone_test( run_mixed_precision_check=run_mixed_precision_check, run_quantization_check=run_quantization_check, ) + if expected_pyramid_output_keys: + backbone = cls(**init_kwargs) + model = keras.models.Model( + backbone.inputs, backbone.pyramid_outputs + ) + output_data = model(input_data) + + self.assertIsInstance(output_data, dict) + self.assertEqual( + list(output_data.keys()), 
list(backbone.pyramid_outputs.keys()) + ) + self.assertEqual( + list(output_data.keys()), expected_pyramid_output_keys + ) + # check height and width of each level. + for i, (k, v) in enumerate(output_data.items()): + self.assertEqual( + tuple(v.shape[1:3]), expected_pyramid_image_sizes[i] + ) # Check data_format. We assume that `input_data` is in "channels_last" # format.