Skip to content

Commit 3b26d3a

Browse files
committed
Subclass ImageConverter and override its call method for the object detection task
1 parent 05fdefe commit 3b26d3a

File tree

4 files changed

+73
-34
lines changed

4 files changed

+73
-34
lines changed

keras_hub/src/models/image_object_detector_preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ def __init__(
7070
@preprocessing_function
def call(self, x, y=None, sample_weight=None):
    """Preprocess a batch of images and (optionally) their box labels.

    Delegates to the attached image converter, which for object
    detection transforms both the images `x` and the labels `y`
    (e.g. bounding-box format conversion), then packs the result
    into a Keras `(x, y, sample_weight)` tuple.
    """
    converter = self.image_converter
    if converter:
        x, y = converter(x, y)
    return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,63 @@
1+
from keras import ops
2+
13
from keras_hub.src.api_export import keras_hub_export
4+
from keras_hub.src.bounding_box.converters import convert_format
25
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
36
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
7+
from keras_hub.src.utils.keras_utils import standardize_data_format
8+
from keras_hub.src.utils.tensor_utils import preprocessing_function
49

510

611
@keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
class RetinaNetImageConverter(ImageConverter):
    """Image converter for RetinaNet that also converts bounding boxes.

    Unlike the base `ImageConverter`, `call` accepts the labels `y`
    (bounding boxes) alongside the images and converts them from
    `ground_truth_bounding_box_format` to `target_bounding_box_format`.

    Args:
        ground_truth_bounding_box_format: Format of the incoming
            ground-truth boxes (e.g. `"rel_yxyx"`).
        target_bounding_box_format: Format the boxes are converted to
            before being returned (e.g. `"yxyx"`).
        image_size: Optional `(height, width)` to resize images to.
        scale: Optional per-channel (or scalar) divisor applied to the
            pixel values after `offset` is subtracted.
        offset: Optional per-channel (or scalar) value subtracted from
            the pixel values.
        crop_to_aspect_ratio: Whether resizing should crop to preserve
            the aspect ratio. Defaults to `True`.
        interpolation: Interpolation method used when resizing.
            Defaults to `"bilinear"`.
        data_format: `"channels_first"` / `"channels_last"`; defaults to
            the Keras global image data format.
        **kwargs: Forwarded to `ImageConverter`.
    """

    backbone_cls = RetinaNetBackbone

    def __init__(
        self,
        ground_truth_bounding_box_format,
        target_bounding_box_format,
        image_size=None,
        scale=None,
        offset=None,
        crop_to_aspect_ratio=True,
        interpolation="bilinear",
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ground_truth_bounding_box_format = ground_truth_bounding_box_format
        self.target_bounding_box_format = target_bounding_box_format
        self.image_size = image_size
        self.scale = scale
        self.offset = offset
        self.crop_to_aspect_ratio = crop_to_aspect_ratio
        self.interpolation = interpolation
        self.data_format = standardize_data_format(data_format)

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None, **kwargs):
        # NOTE(review): `self.resizing` is assumed to be provided by the
        # base `ImageConverter` even though `image_size` is not forwarded
        # to `super().__init__` here — confirm the base class builds it.
        if self.image_size is not None:
            x = self.resizing(x)
        # Normalize pixel values as (x - offset) / scale.
        if self.offset is not None:
            x -= self._expand_non_channel_dims(self.offset, x)
        if self.scale is not None:
            x /= self._expand_non_channel_dims(self.scale, x)
        # Only convert labels when `y` is a dense tensor of boxes; other
        # label structures (e.g. dicts) pass through unchanged.
        if y is not None and ops.is_tensor(y):
            y = convert_format(
                y,
                source=self.ground_truth_bounding_box_format,
                target=self.target_bounding_box_format,
                images=x,
            )
        return x, y

    def get_config(self):
        config = super().get_config()
        # Fix: serialize every `__init__` argument so that
        # `from_config(get_config())` round-trips. The original omitted
        # `image_size`, `scale`, `offset`, `crop_to_aspect_ratio` and
        # `interpolation`, which are stored on `self` but never forwarded
        # to the base class, so `super().get_config()` cannot report them.
        config.update(
            {
                "ground_truth_bounding_box_format": (
                    self.ground_truth_bounding_box_format
                ),
                "target_bounding_box_format": self.target_bounding_box_format,
                "image_size": self.image_size,
                "scale": self.scale,
                "offset": self.offset,
                "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
                "interpolation": self.interpolation,
            }
        )
        return config

keras_hub/src/models/retinanet/retinanet_object_detector.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,7 @@ class RetinaNetObjectDetector(ImageObjectDetector):
3232
`RetinaNetObjectDetector` training targets.
3333
anchor_generator: A `keras_hub.layers.AnchorGenerator`.
3434
num_classes: The number of object classes to be detected.
35-
ground_truth_bounding_box_format: Ground truth bounding box format.
36-
Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
37-
Ensure that ground truth boxes follow one of the following formats.
38-
- `rel_xyxy`
39-
- `rel_yxyx`
40-
- `rel_xywh`
41-
target_bounding_box_format: Target bounding box format.
35+
bounding_box_format: The format of the bounding boxes in the input dataset.
4236
Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
4337
preprocessor: Optional. An instance of the
4438
`RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor.
@@ -60,24 +54,13 @@ def __init__(
6054
label_encoder,
6155
anchor_generator,
6256
num_classes,
63-
ground_truth_bounding_box_format,
64-
target_bounding_box_format,
57+
bounding_box_format,
6558
preprocessor=None,
6659
activation=None,
6760
dtype=None,
6861
prediction_decoder=None,
6962
**kwargs,
7063
):
71-
if "rel" not in ground_truth_bounding_box_format:
72-
raise ValueError(
73-
f"Only relative bounding box formats are supported "
74-
f"Received ground_truth_bounding_box_format="
75-
f"`{ground_truth_bounding_box_format}`. "
76-
f"Please provide a `ground_truth_bounding_box_format` from one of "
77-
f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. "
78-
f"Ensure that the provided ground truth bounding boxes are "
79-
f"normalized and relative to the image size. "
80-
)
8164
# === Layers ===
8265
image_input = keras.layers.Input(backbone.image_shape, name="images")
8366
head_dtype = dtype or backbone.dtype_policy
@@ -131,8 +114,7 @@ def __init__(
131114
)
132115

133116
# === Config ===
134-
self.ground_truth_bounding_box_format = ground_truth_bounding_box_format
135-
self.target_bounding_box_format = target_bounding_box_format
117+
self.bounding_box_format = bounding_box_format
136118
self.num_classes = num_classes
137119
self.backbone = backbone
138120
self.preprocessor = preprocessor
@@ -143,13 +125,13 @@ def __init__(
143125
self.classification_head = classification_head
144126
self._prediction_decoder = prediction_decoder or NonMaxSuppression(
145127
from_logits=(activation != keras.activations.sigmoid),
146-
bounding_box_format=self.target_bounding_box_format,
128+
bounding_box_format=self.bounding_box_format,
147129
)
148130

149131
def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
150132
y_for_label_encoder = convert_format(
151133
y,
152-
source=self.ground_truth_bounding_box_format,
134+
source=self.bounding_box_format,
153135
target=self.label_encoder.bounding_box_format,
154136
images=x,
155137
)
@@ -255,14 +237,14 @@ def decode_predictions(self, predictions, data):
255237
anchors=anchor_boxes,
256238
boxes_delta=box_pred,
257239
anchor_format=self.anchor_generator.bounding_box_format,
258-
box_format=self.target_bounding_box_format,
240+
box_format=self.bounding_box_format,
259241
variance=BOX_VARIANCE,
260242
image_shape=image_shape,
261243
)
262-
# box_pred is now in "self.target_bounding_box_format" format
244+
# box_pred is now in "self.bounding_box_format" format
263245
box_pred = convert_format(
264246
box_pred,
265-
source=self.target_bounding_box_format,
247+
source=self.bounding_box_format,
266248
target=self.prediction_decoder.bounding_box_format,
267249
image_shape=image_shape,
268250
)
@@ -272,7 +254,7 @@ def decode_predictions(self, predictions, data):
272254
y_pred["boxes"] = convert_format(
273255
y_pred["boxes"],
274256
source=self.prediction_decoder.bounding_box_format,
275-
target=self.target_bounding_box_format,
257+
target=self.bounding_box_format,
276258
image_shape=image_shape,
277259
)
278260
return y_pred
@@ -282,8 +264,7 @@ def get_config(self):
282264
config.update(
283265
{
284266
"num_classes": self.num_classes,
285-
"ground_truth_bounding_box_format": self.ground_truth_bounding_box_format,
286-
"target_bounding_box_format": self.target_bounding_box_format,
267+
"bounding_box_format": self.bounding_box_format,
287268
"anchor_generator": keras.layers.serialize(
288269
self.anchor_generator
289270
),

keras_hub/src/models/retinanet/retinanet_object_detector_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import numpy as np
22
import pytest
33

4-
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
54
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
65
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
76
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
7+
from keras_hub.src.models.retinanet.retinanet_image_converter import (
8+
RetinaNetImageConverter,
9+
)
810
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
911
RetinaNetLabelEncoder,
1012
)
@@ -50,8 +52,10 @@ def setUp(self):
5052
bounding_box_format="yxyx", anchor_generator=anchor_generator
5153
)
5254

53-
image_converter = ImageConverter(
55+
image_converter = RetinaNetImageConverter(
5456
image_size=(256, 256),
57+
ground_truth_bounding_box_format="rel_yxyx",
58+
target_bounding_box_format="yxyx",
5559
)
5660

5761
preprocessor = RetinaNetObjectDetectorPreprocessor(
@@ -63,8 +67,7 @@ def setUp(self):
6367
"anchor_generator": anchor_generator,
6468
"label_encoder": label_encoder,
6569
"num_classes": 10,
66-
"ground_truth_bounding_box_format": "rel_yxyx",
67-
"target_bounding_box_format": "xywh",
70+
"bounding_box_format": "yxyx",
6871
"preprocessor": preprocessor,
6972
}
7073

0 commit comments

Comments
 (0)