@@ -27,16 +27,23 @@ class ResNetBackbone(FeaturePyramidBackbone):
27
27
This class implements a ResNet backbone as described in [Deep Residual
28
28
Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
29
29
CVPR 2016), [Identity Mappings in Deep Residual Networks](
30
- https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
30
+ https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
31
31
improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
32
- NeurIPS 2021 Workshop).
32
+ NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
33
+ Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)(CVPR 2019).
33
34
34
35
The difference between ResNet and ResNetV2 lies in the structure of their
35
36
individual building blocks. In ResNetV2, the batch normalization and
36
37
ReLU activation precede the convolution layers, as opposed to ResNet where
37
38
the batch normalization and ReLU activation are applied after the
38
39
convolution layers.
39
40
41
+ ResNetVd introduces two key modifications to the standard ResNet. First,
42
+ the initial convolutional layer is replaced by a series of three
43
+ successive convolutional layers. Second, shortcut connections use an
44
+ additional pooling operation rather than performing downsampling within
45
+ the convolutional layers themselves.
46
+
40
47
Note that `ResNetBackbone` expects the inputs to be images with a value
41
48
range of `[0, 255]` when `include_rescaling=True`.
42
49
@@ -51,6 +58,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
51
58
Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
52
59
use_pre_activation: boolean. Whether to use pre-activation or not.
53
60
`True` for ResNetV2, `False` for ResNet.
61
+ use_vd_pooling: boolean. Whether to use the ResNetVd enhancements (a three-convolution stem and pooled downsampling shortcuts). Defaults to `False`.
54
62
include_rescaling: boolean. If `True`, rescale the input using
55
63
`Rescaling` and `Normalization` layers. If `False`, do nothing.
56
64
Defaults to `True`.
@@ -106,6 +114,7 @@ def __init__(
106
114
stackwise_num_strides ,
107
115
block_type ,
108
116
use_pre_activation = False ,
117
+ use_vd_pooling = False ,
109
118
include_rescaling = True ,
110
119
input_image_shape = (None , None , 3 ),
111
120
pooling = "avg" ,
@@ -133,7 +142,12 @@ def __init__(
133
142
'`block_type` must be either `"basic_block"` or '
134
143
f'`"bottleneck_block"`. Received block_type={ block_type } .'
135
144
)
136
- version = "v1" if not use_pre_activation else "v2"
145
+ if use_vd_pooling :
146
+ version = "vd"
147
+ elif use_pre_activation :
148
+ version = "v2"
149
+ else :
150
+ version = "v1"
137
151
data_format = standardize_data_format (data_format )
138
152
bn_axis = - 1 if data_format == "channels_last" else 1
139
153
num_stacks = len (stackwise_num_filters )
@@ -155,21 +169,21 @@ def __init__(
155
169
# The padding between torch and tensorflow/jax differs when `strides>1`.
156
170
# Therefore, we need to manually pad the tensor.
157
171
x = layers .ZeroPadding2D (
158
- 3 ,
172
+ 1 if use_vd_pooling else 3 ,
159
173
data_format = data_format ,
160
174
dtype = dtype ,
161
175
name = "conv1_pad" ,
162
176
)(x )
163
- x = layers . Conv2D (
164
- 64 ,
165
- 7 ,
166
- strides = 2 ,
167
- data_format = data_format ,
168
- use_bias = False ,
169
- dtype = dtype ,
170
- name = "conv1_conv" ,
171
- )( x )
172
- if not use_pre_activation :
177
+ if use_vd_pooling :
178
+ x = layers . Conv2D (
179
+ 32 ,
180
+ 3 ,
181
+ strides = 2 ,
182
+ data_format = data_format ,
183
+ use_bias = False ,
184
+ dtype = dtype ,
185
+ name = "conv1_conv" ,
186
+ )( x )
173
187
x = layers .BatchNormalization (
174
188
axis = bn_axis ,
175
189
epsilon = 1e-5 ,
@@ -178,6 +192,57 @@ def __init__(
178
192
name = "conv1_bn" ,
179
193
)(x )
180
194
x = layers .Activation ("relu" , dtype = dtype , name = "conv1_relu" )(x )
195
+ x = layers .Conv2D (
196
+ 32 ,
197
+ 3 ,
198
+ strides = 1 ,
199
+ padding = "same" ,
200
+ data_format = data_format ,
201
+ use_bias = False ,
202
+ dtype = dtype ,
203
+ name = "conv2_conv" ,
204
+ )(x )
205
+ x = layers .BatchNormalization (
206
+ axis = bn_axis ,
207
+ epsilon = 1e-5 ,
208
+ momentum = 0.9 ,
209
+ dtype = dtype ,
210
+ name = "conv2_bn" ,
211
+ )(x )
212
+ x = layers .Activation ("relu" , dtype = dtype , name = "conv2_relu" )(x )
213
+ x = layers .Conv2D (
214
+ 64 ,
215
+ 3 ,
216
+ strides = 1 ,
217
+ padding = "same" ,
218
+ data_format = data_format ,
219
+ use_bias = False ,
220
+ dtype = dtype ,
221
+ name = "conv3_conv" ,
222
+ )(x )
223
+ else :
224
+ x = layers .Conv2D (
225
+ 64 ,
226
+ 7 ,
227
+ strides = 2 ,
228
+ data_format = data_format ,
229
+ use_bias = False ,
230
+ dtype = dtype ,
231
+ name = "conv1_conv" ,
232
+ )(x )
233
+ if not use_pre_activation :
234
+ x = layers .BatchNormalization (
235
+ axis = bn_axis ,
236
+ epsilon = 1e-5 ,
237
+ momentum = 0.9 ,
238
+ dtype = dtype ,
239
+ name = "conv3_bn" if use_vd_pooling else "conv1_bn" ,
240
+ )(x )
241
+ x = layers .Activation (
242
+ "relu" ,
243
+ dtype = dtype ,
244
+ name = "conv3_relu" if use_vd_pooling else "conv1_relu" ,
245
+ )(x )
181
246
182
247
if use_pre_activation :
183
248
# A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,8 +275,11 @@ def __init__(
210
275
stride = stackwise_num_strides [stack_index ],
211
276
block_type = block_type ,
212
277
use_pre_activation = use_pre_activation ,
278
+ use_vd_pooling = use_vd_pooling ,
213
279
first_shortcut = (
214
- block_type == "bottleneck_block" or stack_index > 0
280
+ block_type == "bottleneck_block"
281
+ or stack_index > 0
282
+ or use_vd_pooling
215
283
),
216
284
data_format = data_format ,
217
285
dtype = dtype ,
@@ -253,6 +321,7 @@ def __init__(
253
321
self .stackwise_num_strides = stackwise_num_strides
254
322
self .block_type = block_type
255
323
self .use_pre_activation = use_pre_activation
324
+ self .use_vd_pooling = use_vd_pooling
256
325
self .include_rescaling = include_rescaling
257
326
self .input_image_shape = input_image_shape
258
327
self .pooling = pooling
@@ -267,6 +336,7 @@ def get_config(self):
267
336
"stackwise_num_strides" : self .stackwise_num_strides ,
268
337
"block_type" : self .block_type ,
269
338
"use_pre_activation" : self .use_pre_activation ,
339
+ "use_vd_pooling" : self .use_vd_pooling ,
270
340
"include_rescaling" : self .include_rescaling ,
271
341
"input_image_shape" : self .input_image_shape ,
272
342
"pooling" : self .pooling ,
@@ -282,6 +352,7 @@ def apply_basic_block(
282
352
stride = 1 ,
283
353
conv_shortcut = False ,
284
354
use_pre_activation = False ,
355
+ use_vd_pooling = False ,
285
356
data_format = None ,
286
357
dtype = None ,
287
358
name = None ,
@@ -299,6 +370,7 @@ def apply_basic_block(
299
370
`False`.
300
371
use_pre_activation: boolean. Whether to use pre-activation or not.
301
372
`True` for ResNetV2, `False` for ResNet. Defaults to `False`.
373
+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements. Defaults to `False`.
302
374
data_format: `None` or str. the ordering of the dimensions in the
303
375
inputs. Can be `"channels_last"`
304
376
(`(batch_size, height, width, channels)`) or `"channels_first"`
@@ -327,16 +399,27 @@ def apply_basic_block(
327
399
)(x_preact )
328
400
329
401
if conv_shortcut :
330
- x = x_preact if x_preact is not None else x
402
+ if x_preact is not None :
403
+ shortcut = x_preact
404
+ elif use_vd_pooling and stride > 1 :
405
+ shortcut = layers .AveragePooling2D (
406
+ 2 ,
407
+ strides = stride ,
408
+ data_format = data_format ,
409
+ dtype = dtype ,
410
+ padding = "same" ,
411
+ )(x )
412
+ else :
413
+ shortcut = x
331
414
shortcut = layers .Conv2D (
332
415
filters ,
333
416
1 ,
334
- strides = stride ,
417
+ strides = 1 if use_vd_pooling else stride ,
335
418
data_format = data_format ,
336
419
use_bias = False ,
337
420
dtype = dtype ,
338
421
name = f"{ name } _0_conv" ,
339
- )(x )
422
+ )(shortcut )
340
423
if not use_pre_activation :
341
424
shortcut = layers .BatchNormalization (
342
425
axis = bn_axis ,
@@ -407,6 +490,7 @@ def apply_bottleneck_block(
407
490
stride = 1 ,
408
491
conv_shortcut = False ,
409
492
use_pre_activation = False ,
493
+ use_vd_pooling = False ,
410
494
data_format = None ,
411
495
dtype = None ,
412
496
name = None ,
@@ -424,6 +508,7 @@ def apply_bottleneck_block(
424
508
`False`.
425
509
use_pre_activation: boolean. Whether to use pre-activation or not.
426
510
`True` for ResNetV2, `False` for ResNet. Defaults to `False`.
511
+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements. Defaults to `False`.
427
512
data_format: `None` or str. the ordering of the dimensions in the
428
513
inputs. Can be `"channels_last"`
429
514
(`(batch_size, height, width, channels)`) or `"channels_first"`
@@ -452,16 +537,27 @@ def apply_bottleneck_block(
452
537
)(x_preact )
453
538
454
539
if conv_shortcut :
455
- x = x_preact if x_preact is not None else x
540
+ if x_preact is not None :
541
+ shortcut = x_preact
542
+ elif use_vd_pooling and stride > 1 :
543
+ shortcut = layers .AveragePooling2D (
544
+ 2 ,
545
+ strides = stride ,
546
+ data_format = data_format ,
547
+ dtype = dtype ,
548
+ padding = "same" ,
549
+ )(x )
550
+ else :
551
+ shortcut = x
456
552
shortcut = layers .Conv2D (
457
553
4 * filters ,
458
554
1 ,
459
- strides = stride ,
555
+ strides = 1 if use_vd_pooling else stride ,
460
556
data_format = data_format ,
461
557
use_bias = False ,
462
558
dtype = dtype ,
463
559
name = f"{ name } _0_conv" ,
464
- )(x )
560
+ )(shortcut )
465
561
if not use_pre_activation :
466
562
shortcut = layers .BatchNormalization (
467
563
axis = bn_axis ,
@@ -548,6 +644,7 @@ def apply_stack(
548
644
stride ,
549
645
block_type ,
550
646
use_pre_activation ,
647
+ use_vd_pooling = False ,
551
648
first_shortcut = True ,
552
649
data_format = None ,
553
650
dtype = None ,
@@ -565,6 +662,7 @@ def apply_stack(
565
662
Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
566
663
use_pre_activation: boolean. Whether to use pre-activation or not.
567
664
`True` for ResNetV2, `False` for ResNet and ResNeXt.
665
+ use_vd_pooling: boolean. Whether to use ResNetVd enhancements. Defaults to `False`.
568
666
first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
569
667
use an identity or pooling shortcut based on the stride. Defaults to
570
668
`True`.
@@ -580,7 +678,12 @@ def apply_stack(
580
678
Output tensor for the stacked blocks.
581
679
"""
582
680
if name is None :
583
- version = "v1" if not use_pre_activation else "v2"
681
+ if use_vd_pooling :
682
+ version = "vd"
683
+ elif use_pre_activation :
684
+ version = "v2"
685
+ else :
686
+ version = "v1"
584
687
name = f"{ version } _stack"
585
688
586
689
if block_type == "basic_block" :
@@ -605,6 +708,7 @@ def apply_stack(
605
708
stride = stride ,
606
709
conv_shortcut = conv_shortcut ,
607
710
use_pre_activation = use_pre_activation ,
711
+ use_vd_pooling = use_vd_pooling ,
608
712
data_format = data_format ,
609
713
dtype = dtype ,
610
714
name = f"{ name } _block{ str (i )} " ,
0 commit comments