Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions ml-agents-envs/mlagents_envs/rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,38 @@ def behavior_spec_from_proto(
return BehaviorSpec(observation_shape, action_type, action_shape)


class OffsetBytesIO:
    """
    Simple file-like class that wraps a bytes object and allows moving its
    "start" position within those bytes. This is only used for reading
    concatenated PNGs, because Pillow always calls seek(0) at the start of
    reading.

    Positions accepted by seek() and reported by tell() are relative to
    ``offset``; original_tell() reports the absolute position in the wrapped
    bytes. Callers move the window by assigning to ``offset`` directly.
    """

    __slots__ = ["fp", "offset"]

    def __init__(self, data: bytes):
        """
        :param data: The bytes to wrap; reading starts at absolute position 0.
        """
        # Underlying stream; its positions are always absolute.
        self.fp = io.BytesIO(data)
        # Absolute index that is presented to readers as position 0.
        self.offset = 0

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """
        Move the stream position and return the new position relative to the
        current start offset.

        :param offset: Position (SEEK_SET) or displacement (SEEK_CUR, SEEK_END).
        :param whence: One of io.SEEK_SET, io.SEEK_CUR or io.SEEK_END.
        :return: The new position, relative to ``self.offset``.
        :raises ValueError: If whence is SEEK_SET and offset is negative.
        :raises NotImplementedError: If whence is any other value.
        """
        if whence == io.SEEK_SET:
            # BytesIO only rejects a negative *absolute* position, so once the
            # start has moved a negative relative offset would silently land
            # before the window. Reject it the same way BytesIO would.
            if offset < 0:
                raise ValueError(f"negative seek value {offset}")
            return self.fp.seek(offset + self.offset) - self.offset

        if whence in (io.SEEK_CUR, io.SEEK_END):
            # Current position and end of data are the same point in both
            # coordinate systems, so the displacement needs no adjustment.
            absolute = self.fp.seek(offset, whence)
            # BytesIO clamps relative seeks at absolute 0; clamp at the start
            # of the window instead so the relative position is never negative.
            if absolute < self.offset:
                absolute = self.fp.seek(self.offset)
            return absolute - self.offset

        raise NotImplementedError()

    def tell(self) -> int:
        """
        Returns the current position relative to the start offset
        """
        return self.fp.tell() - self.offset

    def read(self, size: int = -1) -> bytes:
        """
        Reads up to ``size`` bytes from the current position (everything that
        remains when size is -1).
        """
        return self.fp.read(size)

    def original_tell(self) -> int:
        """
        Returns the offset into the original byte array
        """
        return self.fp.tell()


@timed
def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
"""
Expand All @@ -54,12 +86,12 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:
:param expected_channels: Expected output channels
:return: processed numpy array of observation from environment
"""
image_bytearray = bytearray(image_bytes)
image_fp = OffsetBytesIO(image_bytes)

if expected_channels == 1:
# Convert to grayscale
with hierarchical_timer("image_decompress"):
image = Image.open(io.BytesIO(image_bytearray))
image = Image.open(image_fp)
# Normally Image loads lazily, load() forces it to do loading in the timer scope.
image.load()
s = np.array(image, dtype=np.float32) / 255.0
Expand All @@ -69,22 +101,17 @@ def process_pixels(image_bytes: bytes, expected_channels: int) -> np.ndarray:

image_arrays = []

bytes_read = 0
# Read the images back from the bytes (without knowing the sizes).
while True:
# TODO avoid creating a new array here. Unfortunately, Pillow doesn't respect the current state of the buffer
# and always starts with seek(0), but we should be able to wrap BytesIO with something that lets us adjust
# the "start" offset.
buffer = io.BytesIO(image_bytearray[bytes_read:])
with hierarchical_timer("image_decompress"):
image = Image.open(buffer)
image = Image.open(image_fp)
image.load()
image_arrays.append(np.array(image, dtype=np.float32) / 255.0)

# Look for the next header, starting from the current stream location
try:
offset = buffer.getvalue().index(PNG_HEADER, buffer.tell())
bytes_read += offset
new_offset = image_bytes.index(PNG_HEADER, image_fp.original_tell())
image_fp.offset = new_offset
except ValueError:
# Didn't find the header, so must be at the end.
break
Expand Down