
Commit 1034522

Add ResNet_vd to ResNet backbone
1 parent ececd14 commit 1034522

2 files changed, 148 additions and 33 deletions

2 files changed

+148
-33
lines changed

keras_nlp/src/models/resnet/resnet_backbone.py

Lines changed: 126 additions & 22 deletions
@@ -27,16 +27,23 @@ class ResNetBackbone(FeaturePyramidBackbone):
     This class implements a ResNet backbone as described in [Deep Residual
     Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
     CVPR 2016), [Identity Mappings in Deep Residual Networks](
-    https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
+    https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
     improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
-    NeurIPS 2021 Workshop).
+    NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
+    Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).

     The difference in ResNet and ResNetV2 rests in the structure of their
     individual building blocks. In ResNetV2, the batch normalization and
     ReLU activation precede the convolution layers, as opposed to ResNet where
     the batch normalization and ReLU activation are applied after the
     convolution layers.

+    ResNetVd introduces two key modifications to the standard ResNet. First,
+    the initial convolutional layer is replaced by a series of three
+    successive convolutional layers. Second, shortcut connections use an
+    additional pooling operation rather than performing downsampling within
+    the convolutional layers themselves.
+
     Note that `ResNetBackbone` expects the inputs to be images with a value
     range of `[0, 255]` when `include_rescaling=True`.

@@ -51,6 +58,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
             Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet.
+        use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
         include_rescaling: boolean. If `True`, rescale the input using
             `Rescaling` and `Normalization` layers. If `False`, do nothing.
             Defaults to `True`.
@@ -106,6 +114,7 @@ def __init__(
         stackwise_num_strides,
         block_type,
         use_pre_activation=False,
+        use_vd_pooling=False,
         include_rescaling=True,
         input_image_shape=(None, None, 3),
         pooling="avg",
@@ -133,7 +142,12 @@ def __init__(
                 '`block_type` must be either `"basic_block"` or '
                 f'`"bottleneck_block"`. Received block_type={block_type}.'
             )
-        version = "v1" if not use_pre_activation else "v2"
+        if use_vd_pooling:
+            version = "vd"
+        elif use_pre_activation:
+            version = "v2"
+        else:
+            version = "v1"
         data_format = standardize_data_format(data_format)
         bn_axis = -1 if data_format == "channels_last" else 1
         num_stacks = len(stackwise_num_filters)
@@ -155,21 +169,21 @@ def __init__(
         # The padding between torch and tensorflow/jax differs when `strides>1`.
         # Therefore, we need to manually pad the tensor.
         x = layers.ZeroPadding2D(
-            3,
+            1 if use_vd_pooling else 3,
             data_format=data_format,
             dtype=dtype,
             name="conv1_pad",
         )(x)
-        x = layers.Conv2D(
-            64,
-            7,
-            strides=2,
-            data_format=data_format,
-            use_bias=False,
-            dtype=dtype,
-            name="conv1_conv",
-        )(x)
-        if not use_pre_activation:
+        if use_vd_pooling:
+            x = layers.Conv2D(
+                32,
+                3,
+                strides=2,
+                data_format=data_format,
+                use_bias=False,
+                dtype=dtype,
+                name="conv1_conv",
+            )(x)
             x = layers.BatchNormalization(
                 axis=bn_axis,
                 epsilon=1e-5,
@@ -178,6 +192,57 @@ def __init__(
                 name="conv1_bn",
             )(x)
             x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
+            x = layers.Conv2D(
+                32,
+                3,
+                strides=1,
+                padding="same",
+                data_format=data_format,
+                use_bias=False,
+                dtype=dtype,
+                name="conv2_conv",
+            )(x)
+            x = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name="conv2_bn",
+            )(x)
+            x = layers.Activation("relu", dtype=dtype, name="conv2_relu")(x)
+            x = layers.Conv2D(
+                64,
+                3,
+                strides=1,
+                padding="same",
+                data_format=data_format,
+                use_bias=False,
+                dtype=dtype,
+                name="conv3_conv",
+            )(x)
+        else:
+            x = layers.Conv2D(
+                64,
+                7,
+                strides=2,
+                data_format=data_format,
+                use_bias=False,
+                dtype=dtype,
+                name="conv1_conv",
+            )(x)
+        if not use_pre_activation:
+            x = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name="conv3_bn" if use_vd_pooling else "conv1_bn",
+            )(x)
+            x = layers.Activation(
+                "relu",
+                dtype=dtype,
+                name="conv3_relu" if use_vd_pooling else "conv1_relu",
+            )(x)

         if use_pre_activation:
             # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,8 +275,11 @@ def __init__(
                 stride=stackwise_num_strides[stack_index],
                 block_type=block_type,
                 use_pre_activation=use_pre_activation,
+                use_vd_pooling=use_vd_pooling,
                 first_shortcut=(
-                    block_type == "bottleneck_block" or stack_index > 0
+                    block_type == "bottleneck_block"
+                    or stack_index > 0
+                    or use_vd_pooling
                 ),
                 data_format=data_format,
                 dtype=dtype,
@@ -253,6 +321,7 @@ def __init__(
         self.stackwise_num_strides = stackwise_num_strides
         self.block_type = block_type
         self.use_pre_activation = use_pre_activation
+        self.use_vd_pooling = use_vd_pooling
         self.include_rescaling = include_rescaling
         self.input_image_shape = input_image_shape
         self.pooling = pooling
@@ -267,6 +336,7 @@ def get_config(self):
                 "stackwise_num_strides": self.stackwise_num_strides,
                 "block_type": self.block_type,
                 "use_pre_activation": self.use_pre_activation,
+                "use_vd_pooling": self.use_vd_pooling,
                 "include_rescaling": self.include_rescaling,
                 "input_image_shape": self.input_image_shape,
                 "pooling": self.pooling,
@@ -282,6 +352,7 @@ def apply_basic_block(
     stride=1,
     conv_shortcut=False,
     use_pre_activation=False,
+    use_vd_pooling=False,
     data_format=None,
     dtype=None,
     name=None,
@@ -299,6 +370,7 @@ def apply_basic_block(
             `False`.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
         data_format: `None` or str. the ordering of the dimensions in the
             inputs. Can be `"channels_last"`
             (`(batch_size, height, width, channels)`) or`"channels_first"`
@@ -327,16 +399,27 @@ def apply_basic_block(
         )(x_preact)

     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        elif use_vd_pooling and stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             filters,
             1,
-            strides=stride,
+            strides=1 if use_vd_pooling else stride,
             data_format=data_format,
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
-        )(x)
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
@@ -407,6 +490,7 @@ def apply_bottleneck_block(
     stride=1,
     conv_shortcut=False,
     use_pre_activation=False,
+    use_vd_pooling=False,
     data_format=None,
     dtype=None,
     name=None,
@@ -424,6 +508,7 @@ def apply_bottleneck_block(
             `False`.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
         data_format: `None` or str. the ordering of the dimensions in the
             inputs. Can be `"channels_last"`
             (`(batch_size, height, width, channels)`) or`"channels_first"`
@@ -452,16 +537,27 @@ def apply_bottleneck_block(
         )(x_preact)

     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        elif use_vd_pooling and stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             4 * filters,
             1,
-            strides=stride,
+            strides=1 if use_vd_pooling else stride,
             data_format=data_format,
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
-        )(x)
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
@@ -548,6 +644,7 @@ def apply_stack(
     stride,
     block_type,
     use_pre_activation,
+    use_vd_pooling=False,
     first_shortcut=True,
     data_format=None,
     dtype=None,
@@ -565,6 +662,7 @@ def apply_stack(
             Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet and ResNeXt.
+        use_vd_pooling: boolean. Whether to use ResNetVd enhancements.
         first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
             use an identity or pooling shortcut based on the stride. Defaults to
             `True`.
@@ -580,7 +678,12 @@ def apply_stack(
         Output tensor for the stacked blocks.
     """
     if name is None:
-        version = "v1" if not use_pre_activation else "v2"
+        if use_vd_pooling:
+            version = "vd"
+        elif use_pre_activation:
+            version = "v2"
+        else:
+            version = "v1"
         name = f"{version}_stack"

     if block_type == "basic_block":
@@ -605,6 +708,7 @@ def apply_stack(
             stride=stride,
             conv_shortcut=conv_shortcut,
             use_pre_activation=use_pre_activation,
+            use_vd_pooling=use_vd_pooling,
             data_format=data_format,
             dtype=dtype,
             name=f"{name}_block{str(i)}",

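A minimal usage sketch (not part of this commit) of the flag added above. It only uses constructor arguments visible in this diff, except for `stackwise_num_blocks`, which is assumed from the unchanged part of the signature; the stackwise values are illustrative, not a preset.

# Hedged sketch: building a small ResNet-vd style backbone with the new flag.
# `stackwise_num_blocks` and the values below are assumptions for illustration.
import numpy as np
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone

backbone = ResNetBackbone(
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[2, 2, 2, 2],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="basic_block",
    use_pre_activation=False,
    use_vd_pooling=True,  # deep 3x3 stem + average-pool shortcuts
    include_rescaling=True,
    input_image_shape=(None, None, 3),
)

# Inputs are expected in the [0, 255] range when include_rescaling=True.
images = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")
features = backbone(images)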
keras_nlp/src/models/resnet/resnet_backbone_test.py

Lines changed: 22 additions & 11 deletions
@@ -34,15 +34,23 @@ def setUp(self):
         self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

     @parameterized.named_parameters(
-        ("v1_basic", False, "basic_block"),
-        ("v1_bottleneck", False, "bottleneck_block"),
-        ("v2_basic", True, "basic_block"),
-        ("v2_bottleneck", True, "bottleneck_block"),
+        ("v1_basic", False, False, "basic_block"),
+        ("v1_bottleneck", False, False, "bottleneck_block"),
+        ("v2_basic", True, False, "basic_block"),
+        ("v2_bottleneck", True, False, "bottleneck_block"),
+        ("vd_basic", False, True, "basic_block"),
+        ("vd_bottleneck", False, True, "bottleneck_block"),
     )
-    def test_backbone_basics(self, use_pre_activation, block_type):
+    def test_backbone_basics(
+        self, use_pre_activation, use_vd_pooling, block_type
+    ):
         init_kwargs = self.init_kwargs.copy()
         init_kwargs.update(
-            {"block_type": block_type, "use_pre_activation": use_pre_activation}
+            {
+                "block_type": block_type,
+                "use_pre_activation": use_pre_activation,
+                "use_vd_pooling": use_vd_pooling,
+            }
         )
         self.run_vision_backbone_test(
             cls=ResNetBackbone,
@@ -72,18 +80,21 @@ def test_pyramid_output_format(self):
             self.assertEqual(tuple(v.shape[:3]), (2, size, size))

     @parameterized.named_parameters(
-        ("v1_basic", False, "basic_block"),
-        ("v1_bottleneck", False, "bottleneck_block"),
-        ("v2_basic", True, "basic_block"),
-        ("v2_bottleneck", True, "bottleneck_block"),
+        ("v1_basic", False, False, "basic_block"),
+        ("v1_bottleneck", False, False, "bottleneck_block"),
+        ("v2_basic", True, False, "basic_block"),
+        ("v2_bottleneck", True, False, "bottleneck_block"),
+        ("vd_basic", False, True, "basic_block"),
+        ("vd_bottleneck", False, True, "bottleneck_block"),
     )
     @pytest.mark.large
-    def test_saved_model(self, use_pre_activation, block_type):
+    def test_saved_model(self, use_pre_activation, use_vd_pooling, block_type):
         init_kwargs = self.init_kwargs.copy()
         init_kwargs.update(
             {
                 "block_type": block_type,
                 "use_pre_activation": use_pre_activation,
+                "use_vd_pooling": use_vd_pooling,
                 "input_image_shape": (None, None, 3),
             }
         )

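For intuition about the second ResNetVd modification, here is a standalone sketch (plain Keras, channels_last assumed; not taken from this commit) contrasting the standard projection shortcut with the vd-style shortcut that the blocks above build when `use_vd_pooling=True` and `stride > 1`.

# Illustrative comparison of the two shortcut styles; shapes are arbitrary.
import keras
from keras import layers

inputs = keras.Input(shape=(56, 56, 64))
filters, stride = 128, 2

# Standard ResNet projection shortcut: the strided 1x1 convolution does the
# downsampling itself and only sees one out of every four spatial positions.
standard = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(inputs)

# ResNet-vd shortcut: average-pool first, then a stride-1 1x1 convolution,
# so every input activation contributes to the downsampled shortcut.
pooled = layers.AveragePooling2D(2, strides=stride, padding="same")(inputs)
vd = layers.Conv2D(filters, 1, strides=1, use_bias=False)(pooled)

model = keras.Model(inputs, [standard, vd])  # both outputs: (None, 28, 28, 128)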