@@ -32,13 +32,7 @@ class RetinaNetObjectDetector(ImageObjectDetector):
32
32
`RetinaNetObjectDetector` training targets.
33
33
anchor_generator: A `keras_Hub.layers.AnchorGenerator`.
34
34
num_classes: The number of object classes to be detected.
35
- ground_truth_bounding_box_format: Ground truth bounding box format.
36
- Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
37
- Ensure that ground truth boxes follow one of the following formats.
38
- - `rel_xyxy`
39
- - `rel_yxyx`
40
- - `rel_xywh`
41
- target_bounding_box_format: Target bounding box format.
35
+ bounding_box_format: The format of bounding boxes of input dataset.
42
36
Refer TODO: https://github.com/keras-team/keras-hub/issues/1907
43
37
preprocessor: Optional. An instance of the
44
38
`RetinaNetObjectDetectorPreprocessor` class or a custom preprocessor.
@@ -60,24 +54,13 @@ def __init__(
60
54
label_encoder ,
61
55
anchor_generator ,
62
56
num_classes ,
63
- ground_truth_bounding_box_format ,
64
- target_bounding_box_format ,
57
+ bounding_box_format ,
65
58
preprocessor = None ,
66
59
activation = None ,
67
60
dtype = None ,
68
61
prediction_decoder = None ,
69
62
** kwargs ,
70
63
):
71
- if "rel" not in ground_truth_bounding_box_format :
72
- raise ValueError (
73
- f"Only relative bounding box formats are supported "
74
- f"Received ground_truth_bounding_box_format="
75
- f"`{ ground_truth_bounding_box_format } `. "
76
- f"Please provide a `ground_truth_bounding_box_format` from one of "
77
- f"the following `rel_xyxy` or `rel_yxyx` or `rel_xywh`. "
78
- f"Ensure that the provided ground truth bounding boxes are "
79
- f"normalized and relative to the image size. "
80
- )
81
64
# === Layers ===
82
65
image_input = keras .layers .Input (backbone .image_shape , name = "images" )
83
66
head_dtype = dtype or backbone .dtype_policy
@@ -131,8 +114,7 @@ def __init__(
131
114
)
132
115
133
116
# === Config ===
134
- self .ground_truth_bounding_box_format = ground_truth_bounding_box_format
135
- self .target_bounding_box_format = target_bounding_box_format
117
+ self .bounding_box_format = bounding_box_format
136
118
self .num_classes = num_classes
137
119
self .backbone = backbone
138
120
self .preprocessor = preprocessor
@@ -143,13 +125,13 @@ def __init__(
143
125
self .classification_head = classification_head
144
126
self ._prediction_decoder = prediction_decoder or NonMaxSuppression (
145
127
from_logits = (activation != keras .activations .sigmoid ),
146
- bounding_box_format = self .target_bounding_box_format ,
128
+ bounding_box_format = self .bounding_box_format ,
147
129
)
148
130
149
131
def compute_loss (self , x , y , y_pred , sample_weight , ** kwargs ):
150
132
y_for_label_encoder = convert_format (
151
133
y ,
152
- source = self .ground_truth_bounding_box_format ,
134
+ source = self .bounding_box_format ,
153
135
target = self .label_encoder .bounding_box_format ,
154
136
images = x ,
155
137
)
@@ -255,14 +237,14 @@ def decode_predictions(self, predictions, data):
255
237
anchors = anchor_boxes ,
256
238
boxes_delta = box_pred ,
257
239
anchor_format = self .anchor_generator .bounding_box_format ,
258
- box_format = self .target_bounding_box_format ,
240
+ box_format = self .bounding_box_format ,
259
241
variance = BOX_VARIANCE ,
260
242
image_shape = image_shape ,
261
243
)
262
- # box_pred is now in "self.target_bounding_box_format " format
244
+ # box_pred is now in "self.bounding_box_format " format
263
245
box_pred = convert_format (
264
246
box_pred ,
265
- source = self .target_bounding_box_format ,
247
+ source = self .bounding_box_format ,
266
248
target = self .prediction_decoder .bounding_box_format ,
267
249
image_shape = image_shape ,
268
250
)
@@ -272,7 +254,7 @@ def decode_predictions(self, predictions, data):
272
254
y_pred ["boxes" ] = convert_format (
273
255
y_pred ["boxes" ],
274
256
source = self .prediction_decoder .bounding_box_format ,
275
- target = self .target_bounding_box_format ,
257
+ target = self .bounding_box_format ,
276
258
image_shape = image_shape ,
277
259
)
278
260
return y_pred
@@ -282,8 +264,7 @@ def get_config(self):
282
264
config .update (
283
265
{
284
266
"num_classes" : self .num_classes ,
285
- "ground_truth_bounding_box_format" : self .ground_truth_bounding_box_format ,
286
- "target_bounding_box_format" : self .target_bounding_box_format ,
267
+ "bounding_box_format" : self .bounding_box_format ,
287
268
"anchor_generator" : keras .layers .serialize (
288
269
self .anchor_generator
289
270
),
0 commit comments