Commit 5d97d1a

[RetinaNet] Image Converter and ObjectDetector (#1906)
* Rebased phase 1 changes
* Rebased phase 1 changes
* nit
* Retina Phase 2
* nit
* Expose Anchor Generator as layer, docstring correction and test correction
* nit
* Add missing args for prediction heads
* Use FeaturePyramidBackbone cls for RetinaNet backbone; correct test cases
* fix decoding error
* Add ground truth arg for RetinaNet model and remove source and target format from preprocessor
* nit
* Subclass Imageconverter and overload call method for object detection method
* Revert "Subclass Imageconverter and overload call method for object detection method" (reverts commit 3b26d3a)
* add names to layers
* correct fpn coarser level as per torch retinanet model
* nit
* Polish prediction head and fpn layers to include flags and norm layers
* nit
* nit
* add prior probability flag for prediction head, to use it for the classification head and be user friendly
* compute_shape seems redundant here; correct layers for channels_first
* keep compute_output_shape for fpn
* nit
* Change AnchorGen implementation as per torch
* correct the source format of anchors format
* use plain rescaling and normalization, no resizing, for OD models, as resizing can affect the bounding boxes and the ops are backend-framework dependent
* use single bbox format for model
* Add arg for encoding format; add required docstrings; use `center_xywh` encoding for RetinaNet as per torch weights
* make anchor generator optional
* init anchor generator and label encoder as layers, and add one more arg for prediction head configuration
* nit
* only consider levels from min level to backbone max level for feature extraction from image encoder
* nit
* nit
* update resizing as per new keras3 resizing layer for bboxes
* Revert "update resizing as per new keras3 resizing layer for bboxes" (reverts commit eb555ca)
* Add TODOs for keras bounding box ops
* Use keras layers to rescale and normalize
* check with plain values
* use convert_preprocessing_inputs function for basic operations, as the backend causes some GPU misplacement
* use keras for init variables
* modify task test for cases when the test runs on GPU
* modify the order of steps
* fix tensor device placement error for torch backend
* this should fix the error for both the image-size-given and not-given cases
* use numpy arrays
* make `yxyx` the default bbox format, and some nit
* use image_size argument so that we don't break presets
* Add retinanet_resnet50_fpn_coco preset
* register retinanet presets
1 parent 5a7ecb6 commit 5d97d1a
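
Taken together, these changes add an end-to-end object detection workflow. A minimal usage sketch (the preset name comes from the commit message above; the input shape and values are illustrative assumptions, not requirements):

```python
import numpy as np

import keras_hub

# Load the detector from the preset registered by this commit.
detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)

# A batch of RGB images; values in [0, 255], shape is illustrative.
images = np.random.uniform(0, 255, size=(1, 640, 640, 3)).astype("float32")
predictions = detector.predict(images)
```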

24 files changed (+1504, −230 lines)

keras_hub/api/layers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -51,6 +51,10 @@
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder

keras_hub/api/models/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -185,6 +185,10 @@
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
@@ -252,6 +256,13 @@
 from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
     ResNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_object_detector import (
+    RetinaNetObjectDetector,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
 from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
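
Together with the `keras_hub/api/layers/__init__.py` diff above, the net effect is that the new symbols resolve from the public API; a quick import check:

```python
# Layers exposed by this commit.
from keras_hub.layers import AnchorGenerator, RetinaNetImageConverter

# Models and task base classes exposed by this commit.
from keras_hub.models import (
    ImageObjectDetector,
    ImageObjectDetectorPreprocessor,
    RetinaNetBackbone,
    RetinaNetObjectDetector,
    RetinaNetObjectDetectorPreprocessor,
)
```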
keras_hub/src/bounding_box/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# TODO: Once all bounding boxes are moved to keras repository remove the
+# bounding box folder.

keras_hub/src/bounding_box/converters.py

Lines changed: 102 additions & 12 deletions
@@ -20,29 +20,74 @@ class RequiresImagesException(Exception):
 ALL_AXES = 4
 
 
-def _encode_box_to_deltas(
+def encode_box_to_deltas(
     anchors,
     boxes,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoding_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from `center_yxhw` to delta format."""
+    """Encodes bounding boxes relative to anchors as deltas.
+
+    This function calculates the deltas that represent the difference between
+    bounding boxes and provided anchors. Deltas encode the offsets and scaling
+    factors to apply to anchors to obtain the target boxes.
+
+    Boxes and anchors are first converted to the specified `encoding_format`
+    (defaulting to `center_yxhw`) for consistent delta representation.
+
+    Args:
+        anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
+            number of anchors.
+        boxes: `Tensors`. Bounding boxes to encode. Boxes can be of shape
+            `(B, N, 4)` or `(N, 4)`.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., "xyxy", "xywh", etc.).
+        box_format: str. The format of the input `boxes`
+            (e.g., "xyxy", "xywh", etc.).
+        encoding_format: str. The intermediate format to which boxes and
+            anchors are converted before delta calculation. Defaults to
+            "center_yxhw".
+        variance: `List[float]`. A 4-element array/tensor representing variance
+            factors to scale the box deltas. If provided, the calculated deltas
+            are divided by the variance. Defaults to None.
+        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
+            When using a relative bounding box format for `box_format`, the
+            `image_shape` is used for normalization.
+
+    Returns:
+        Encoded box deltas. The return type matches the `encoding_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoding_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
 
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoding_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            "`encoding_format` should be one of 'center_xywh' or "
+            f"'center_yxhw', got {encoding_format}"
+        )
+
     encoded_anchors = convert_format(
         anchors,
         source=anchor_format,
-        target="center_yxhw",
+        target=encoding_format,
         image_shape=image_shape,
     )
     boxes = convert_format(
-        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
+        boxes,
+        source=box_format,
+        target=encoding_format,
+        image_shape=image_shape,
     )
     anchor_dimensions = ops.maximum(
         encoded_anchors[..., 2:], keras.backend.epsilon()
@@ -61,27 +106,72 @@ def _encode_box_to_deltas(
     return boxes_delta
 
 
-def _decode_deltas_to_boxes(
+def decode_deltas_to_boxes(
     anchors,
     boxes_delta,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoded_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from delta format to `center_yxhw`."""
+    """Converts bounding boxes from delta format to the specified `box_format`.
+
+    This function decodes bounding box deltas relative to anchors to obtain the
+    final bounding box coordinates. The boxes are encoded in a specific
+    `encoded_format` (center_yxhw by default) during the decoding process.
+    This allows flexibility in how the deltas are applied to the anchors.
+
+    Args:
+        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
+            indices and values are corresponding anchor boxes.
+            The shape of the array/tensor should be `(N, 4)` where N is the
+            number of anchors.
+        boxes_delta: Can be `Tensors` or `Dict[Tensors]`. Bounding box deltas
+            must have the same type and structure as `anchors`. The
+            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N
+            is the number of boxes.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., `"xyxy"`, `"xywh"`, etc.).
+        box_format: str. The desired format for the output boxes
+            (e.g., `"xyxy"`, `"xywh"`, etc.).
+        encoded_format: str. Raw output format from regression head. Defaults
+            to `"center_yxhw"`.
+        variance: `List[float]`. A 4-element array/tensor representing
+            variance factors to scale the box deltas. If provided, the deltas
+            are multiplied by the variance before being applied to the anchors.
+            Defaults to None.
+        image_shape: The shape of the image (height, width). This is needed
+            if normalization to image size is required when converting between
+            formats. Defaults to None.
+
+    Returns:
+        Decoded box coordinates. The return type matches the `box_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoded_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
 
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
 
+    if encoded_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
+            f"but got '{encoded_format}'."
+        )
+
     def decode_single_level(anchor, box_delta):
         encoded_anchor = convert_format(
             anchor,
             source=anchor_format,
-            target="center_yxhw",
+            target=encoded_format,
             image_shape=image_shape,
         )
         if variance is not None:
@@ -97,7 +187,7 @@ def decode_single_level(anchor, box_delta):
         )
         box = convert_format(
             box,
-            source="center_yxhw",
+            source=encoded_format,
             target=box_format,
             image_shape=image_shape,
         )
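
To make the new `encoding_format`/`encoded_format` arguments concrete, here is a small round-trip sketch with synthetic boxes (the `center_xywh` choice mirrors the torch-style encoding noted in the commit message):

```python
import numpy as np

from keras_hub.src.bounding_box.converters import (
    decode_deltas_to_boxes,
    encode_box_to_deltas,
)

# One anchor and one ground-truth box, both in "xyxy" format.
anchors = np.array([[10.0, 10.0, 50.0, 50.0]], dtype="float32")
boxes = np.array([[12.0, 8.0, 48.0, 54.0]], dtype="float32")

deltas = encode_box_to_deltas(
    anchors,
    boxes,
    anchor_format="xyxy",
    box_format="xyxy",
    encoding_format="center_xywh",
)
decoded = decode_deltas_to_boxes(
    anchors,
    deltas,
    anchor_format="xyxy",
    box_format="xyxy",
    encoded_format="center_xywh",
)
# `decoded` recovers `boxes` up to floating-point error.
```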

keras_hub/src/layers/preprocessing/image_converter.py

Lines changed: 5 additions & 0 deletions
@@ -164,6 +164,11 @@ def _expand_non_channel_dims(self, value, inputs):
         # If inputs are not a tensor type, return a numpy array.
         # This might happen when running under tf.data.
         if ops.is_tensor(inputs):
+            # The preprocessing decorator moves tensors to CPU in the torch
+            # backend: values are processed on CPU, then converted back to
+            # the appropriate device (potentially GPU) after preprocessing.
+            if keras.backend.backend() == "torch" and self.image_size is None:
+                return ops.expand_dims(value, broadcast_dims).cpu()
             return ops.expand_dims(value, broadcast_dims)
         else:
             return np.expand_dims(value, broadcast_dims)
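
For illustration, a minimal sketch of the guarded branch in isolation; `value` and `broadcast_dims` here are assumed stand-ins for the per-channel constants the method actually expands:

```python
import keras
from keras import ops

value = [0.485, 0.456, 0.406]  # assumed per-channel constant
broadcast_dims = [0, 1, 2]  # expand to rank 4 for a channels-last batch

expanded = ops.expand_dims(value, broadcast_dims)
if keras.backend.backend() == "torch":
    # Keep the constant on CPU: the preprocessing decorator runs the call
    # on CPU under torch and moves results back to the right device later.
    expanded = expanded.cpu()
```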
keras_hub/src/models/image_object_detector.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetector")
+class ImageObjectDetector(Task):
+    """Base class for all image object detection tasks.
+
+    `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    object detection. `ImageObjectDetector` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is an image and `y` is a dictionary with `boxes` and
+    `classes`.
+
+    All `ImageObjectDetector` tasks include a `from_preset()` constructor which
+    can be used to load a pre-trained config and weights.
+    """
+
+    def compile(
+        self,
+        optimizer="auto",
+        box_loss="auto",
+        classification_loss="auto",
+        metrics=None,
+        **kwargs,
+    ):
+        """Configures the `ImageObjectDetector` task for training.
+
+        The `ImageObjectDetector` task extends the default compilation
+        signature of `keras.Model.compile` with defaults for `optimizer`,
+        `loss`, and `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for the given model and task. See
+                `keras.Model.compile` and `keras.optimizers` for more info on
+                possible `optimizer` values.
+            box_loss: `"auto"`, a loss name, or a `keras.losses.Loss`
+                instance. Defaults to `"auto"`, where a `keras.losses.Huber`
+                loss will be applied for the object detector task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            classification_loss: `"auto"`, a loss name, or a
+                `keras.losses.Loss` instance. Defaults to `"auto"`, where a
+                `keras.losses.BinaryFocalCrossentropy` loss will be applied
+                for the object detector task. See `keras.Model.compile` and
+                `keras.losses` for more info on possible `loss` values.
+            metrics: a list of metrics to be evaluated by the model during
+                training and testing. Defaults to `None`. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if box_loss == "auto":
+            box_loss = keras.losses.Huber(reduction="sum")
+        if classification_loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.sigmoid
+            classification_loss = keras.losses.BinaryFocalCrossentropy(
+                from_logits=from_logits, reduction="sum"
+            )
+        if metrics is not None:
+            raise ValueError("User metrics not yet supported")
+
+        losses = {
+            "bbox_regression": box_loss,
+            "cls_logits": classification_loss,
+        }
+
+        super().compile(
+            optimizer=optimizer,
+            loss=losses,
+            metrics=metrics,
+            **kwargs,
+        )
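
Concretely, the `"auto"` defaults expand as below; a sketch assuming a preset-built detector with no sigmoid activation (hence `from_logits=True`):

```python
import keras

import keras_hub

detector = keras_hub.models.ImageObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"  # preset name from this commit
)

# Default compilation: Adam(5e-5), summed Huber loss on `bbox_regression`,
# and summed binary focal cross-entropy on `cls_logits`.
detector.compile()

# Equivalent explicit form:
detector.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    box_loss=keras.losses.Huber(reduction="sum"),
    classification_loss=keras.losses.BinaryFocalCrossentropy(
        from_logits=True, reduction="sum"
    ),
)
```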
keras_hub/src/models/image_object_detector_preprocessor.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor")
+class ImageObjectDetectorPreprocessor(Preprocessor):
+    """Base class for object detector preprocessing layers.
+
+    `ImageObjectDetectorPreprocessor` wraps a
+    `keras_hub.layers.Preprocessor` to create a preprocessing layer for
+    object detection tasks. It is intended to be paired with a
+    `keras_hub.models.ImageObjectDetector` task.
+
+    All `ImageObjectDetectorPreprocessor` instances take three inputs: `x`,
+    `y`, and `sample_weight`. `x`, the first input, should always be
+    included. It can be an image or a batch of images. See examples below.
+    `y` and `sample_weight` are optional inputs that will be passed through
+    unaltered. Usually, `y` will be a dict of
+    `{"boxes": Tensor(batch_size, num_boxes, 4),
+    "classes": (batch_size, num_boxes)}`.
+
+    The layer will return either `x`, an `(x, y)` tuple if labels were
+    provided, or an `(x, y, sample_weight)` tuple if labels and sample
+    weight were provided. `x` will be the input images after all model
+    preprocessing has been applied.
+
+    All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
+    constructor which can be used to load a pre-trained config and
+    vocabularies. You can call the `from_preset()` constructor directly on
+    this base class, in which case the correct class for your model will be
+    automatically instantiated.
+
+    Args:
+        image_converter: Preprocessing pipeline for images.
+
+    Examples:
+    ```python
+    preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
+        "retinanet_resnet50",
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_converter=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_converter = image_converter
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        if self.image_converter:
+            x = self.image_converter(x)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
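
A usage sketch matching the docstring above (the preset name is the one added by this commit; shapes and dtypes are illustrative assumptions):

```python
import numpy as np

import keras_hub

preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
    "retinanet_resnet50_fpn_coco"
)

images = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
labels = {
    "boxes": np.zeros((2, 8, 4), dtype="float32"),  # (batch, num_boxes, 4)
    "classes": np.zeros((2, 8), dtype="int32"),  # (batch, num_boxes)
}

# `y` (and `sample_weight`) pass through unaltered; only `x` is converted.
x, y = preprocessor(images, labels)
```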
keras_hub/src/models/retinanet/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, RetinaNetBackbone)
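
Because this module runs at import time, `register_presets` makes the preset (named in the commit message) resolvable from the backbone class; a quick check:

```python
import keras_hub

# Resolves via the registration above; downloads weights on first use.
backbone = keras_hub.models.RetinaNetBackbone.from_preset(
    "retinanet_resnet50_fpn_coco"
)
```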
