Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions keras/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,29 @@
)


def rgb_to_grayscale(image, data_format="channels_last"):
if data_format == "channels_first":
if len(image.shape) == 4:
image = jnp.transpose(image, (0, 2, 3, 1))
elif len(image.shape) == 3:
image = jnp.transpose(image, (1, 2, 0))
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
red, green, blue = image[..., 0], image[..., 1], image[..., 2]
grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue
grayscale_image = jnp.expand_dims(grayscale_image, axis=-1)
if data_format == "channels_first":
if len(image.shape) == 4:
grayscale_image = jnp.transpose(grayscale_image, (0, 3, 1, 2))
elif len(image.shape) == 3:
grayscale_image = jnp.transpose(grayscale_image, (2, 0, 1))
return jnp.array(grayscale_image)


def resize(
image,
size,
Expand Down
23 changes: 23 additions & 0 deletions keras/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@
)


def rgb_to_grayscale(image, data_format="channels_last"):
if data_format == "channels_first":
if len(image.shape) == 4:
image = np.transpose(image, (0, 2, 3, 1))
elif len(image.shape) == 3:
image = np.transpose(image, (1, 2, 0))
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
red, green, blue = image[..., 0], image[..., 1], image[..., 2]
grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue
grayscale_image = np.expand_dims(grayscale_image, axis=-1)
if data_format == "channels_first":
if len(image.shape) == 4:
grayscale_image = np.transpose(grayscale_image, (0, 3, 1, 2))
elif len(image.shape) == 3:
grayscale_image = np.transpose(grayscale_image, (2, 0, 1))
return np.array(grayscale_image)


def resize(
image,
size,
Expand Down
21 changes: 21 additions & 0 deletions keras/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@
)


def rgb_to_grayscale(image, data_format="channels_last"):
if data_format == "channels_first":
if len(image.shape) == 4:
image = tf.transpose(image, (0, 2, 3, 1))
elif len(image.shape) == 3:
image = tf.transpose(image, (1, 2, 0))
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
grayscale_image = tf.image.rgb_to_grayscale(image)
if data_format == "channels_first":
if len(image.shape) == 4:
grayscale_image = tf.transpose(grayscale_image, (0, 3, 1, 2))
elif len(image.shape) == 3:
grayscale_image = tf.transpose(grayscale_image, (2, 0, 1))
return tf.cast(grayscale_image, image.dtype)


def resize(
image,
size,
Expand Down
32 changes: 32 additions & 0 deletions keras/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,38 @@
)


def rgb_to_grayscale(image, data_format="channel_last"):
try:
import torchvision
except:
raise ImportError(
"The torchvision package is necessary to use "
"`rgb_to_grayscale` with the torch backend. "
"Please install torchvision. "
)
image = convert_to_tensor(image)
if data_format == "channels_last":
if image.ndim == 4:
image = image.permute((0, 3, 1, 2))
elif image.ndim == 3:
image = image.permute((2, 0, 1))
else:
raise ValueError(
"Invalid input rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)
grayscale_image = torchvision.transforms.functional.rgb_to_grayscale(
img=image,
)
if data_format == "channels_last":
if len(image.shape) == 4:
grayscale_image = grayscale_image.permute((0, 2, 3, 1))
elif len(image.shape) == 3:
grayscale_image = grayscale_image.permute((1, 2, 0))
return grayscale_image


def resize(
image,
size,
Expand Down
97 changes: 97 additions & 0 deletions keras/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,103 @@
from keras.ops.operation_utils import compute_conv_output_shape


class RGBToGrayscale(Operation):
def __init__(
self,
data_format="channels_last",
):
super().__init__()
self.data_format = data_format

def call(self, image):
return backend.image.rgb_to_grayscale(
image,
data_format=self.data_format,
)

def compute_output_spec(self, image):
if len(image.shape) not in (3, 4):
raise ValueError(
"Invalid image rank: expected rank 3 (single image) "
"or rank 4 (batch of images). Received input with shape: "
f"image.shape={image.shape}"
)

if len(image.shape) == 3:
if self.data_format == "channels_last":
return KerasTensor(image.shape[:-1] + (1,), dtype=image.dtype)
else:
return KerasTensor((1,) + image.shape[1:], dtype=image.dtype)
elif len(image.shape) == 4:
if self.data_format == "channels_last":
return KerasTensor(
(image.shape[0],) + image.shape[1:-1] + (1,),
dtype=image.dtype,
)
else:
return KerasTensor(
(
image.shape[0],
1,
)
+ image.shape[2:],
dtype=image.dtype,
)


@keras_export("keras.ops.image.rgb_to_grayscale")
def rgb_to_grayscale(
image,
data_format="channels_last",
):
"""Convert RGB images to grayscale.

This function converts RGB images to grayscale images. It supports both
3D and 4D tensors, where the last dimension represents channels.

Args:
image: Input RGB image or batch of RGB images. Must be a 3D tensor
with shape `(height, width, channels)` or a 4D tensor with shape
`(batch, height, width, channels)`.
data_format: A string specifying the data format of the input tensor.
It can be either `"channels_last"` or `"channels_first"`.
`"channels_last"` corresponds to inputs with shape
`(batch, height, width, channels)`, while `"channels_first"`
corresponds to inputs with shape `(batch, channels, height, width)`.
Defaults to `"channels_last"`.

Returns:
Grayscale image or batch of grayscale images.

Examples:

>>> import numpy as np
>>> from keras import ops
>>> x = np.random.random((2, 4, 4, 3))
>>> y = ops.image.rgb_to_grayscale(x)
>>> y.shape
(2, 4, 4, 1)

>>> x = np.random.random((4, 4, 3)) # Single RGB image
>>> y = ops.image.rgb_to_grayscale(x)
>>> y.shape
(4, 4, 1)

>>> x = np.random.random((2, 3, 4, 4))
>>> y = ops.image.rgb_to_grayscale(x, data_format="channels_first")
>>> y.shape
(2, 1, 4, 4)
"""
if any_symbolic_tensors((image,)):
return RGBToGrayscale(
data_format=data_format,
).symbolic_call(image)
return backend.image.rgb_to_grayscale(
image,
data_format=data_format,
)


class Resize(Operation):
def __init__(
self,
Expand Down
55 changes: 55 additions & 0 deletions keras/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@


class ImageOpsDynamicShapeTest(testing.TestCase):
def test_rgb_to_grayscale(self):
x = KerasTensor([None, 20, 20, 3])
out = kimage.rgb_to_grayscale(x)
self.assertEqual(out.shape, (None, 20, 20, 1))

def test_resize(self):
x = KerasTensor([None, 20, 20, 3])
out = kimage.resize(x, size=(15, 15))
Expand Down Expand Up @@ -62,6 +67,11 @@ def test_crop_images(self):


class ImageOpsStaticShapeTest(testing.TestCase):
def test_rgb_to_grayscale(self):
x = KerasTensor([20, 20, 3])
out = kimage.rgb_to_grayscale(x)
self.assertEqual(out.shape, (20, 20, 1))

def test_resize(self):
x = KerasTensor([20, 20, 3])
out = kimage.resize(x, size=(15, 15))
Expand Down Expand Up @@ -187,6 +197,51 @@ def _fixed_map_coordinates(


class ImageOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
[
("channels_last"),
("channels_first"),
]
)
def test_rgb_to_grayscale(self, data_format):
# Unbatched case
if data_format == "channels_first":
x = np.random.random((3, 50, 50)) * 255
else:
x = np.random.random((50, 50, 3)) * 255
out = kimage.rgb_to_grayscale(
x,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (1, 2, 0))
ref_out = tf.image.rgb_to_grayscale(
x,
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (2, 0, 1))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)

# Batched case
if data_format == "channels_first":
x = np.random.random((2, 3, 50, 50)) * 255
else:
x = np.random.random((2, 50, 50, 3)) * 255
out = kimage.rgb_to_grayscale(
x,
data_format=data_format,
)
if data_format == "channels_first":
x = np.transpose(x, (0, 2, 3, 1))
ref_out = tf.image.rgb_to_grayscale(
x,
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (0, 3, 1, 2))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)

@parameterized.parameters(
[
("bilinear", True, "channels_last"),
Expand Down