Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion vlmrun/client/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from __future__ import annotations
import json
import tempfile
import contextlib
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, Generator
from PIL import Image
from loguru import logger

import time
from vlmrun.common.utils import remote_image
from vlmrun.common.image import encode_image, _open_image_with_exif
from vlmrun.client.base_requestor import APIRequestor
from vlmrun.types.abstract import VLMRunProtocol
Expand All @@ -25,6 +28,35 @@
from cachetools.keys import hashkey


@contextlib.contextmanager
def image_path_ctx(
image: Image.Image | None = None,
url: str | None = None,
) -> Generator[Path, None, None]:
"""Context manager to handle temporary image paths.

Args:
image: PIL Image object
url: URL of the image

Yields:
str: Path to the temporary image file
"""
if not url and not image:
raise ValueError("Either `image` or `url` must be provided")
if url and image:
raise ValueError("Cannot provide both `image` and `url`")

# Download the image from the URL if provided
if url:
image: Image.Image = remote_image(url)

# Save the image to a temporary file, and yield the path
with tempfile.NamedTemporaryFile(suffix=".jpg") as temp_file:
image.save(temp_file.name, format="JPEG", quality=98)
yield Path(temp_file.name)


@cachetools.cached(
cache=cachetools.TTLCache(maxsize=100, ttl=3600),
key=lambda _client, domain, config: hashkey(
Expand Down Expand Up @@ -292,6 +324,22 @@ def generate(
raise ValueError("Either `images` or `urls` must be provided")
if images and urls:
raise ValueError("Only one of `images` or `urls` can be provided")
if batch and len(images) > 1:
raise ValueError("Batch mode only supports one image")

if batch:
assert len(images) == 1, "Batch mode only supports one image"
with image_path_ctx(image=images[0]) as image_path:
return self._client.document.generate(
file=image_path,
model=model,
domain=domain,
batch=batch,
config=config,
metadata=metadata,
callback_url=callback_url,
autocast=autocast,
)

if images:
# Check if all images are of the same type
Expand Down
2 changes: 1 addition & 1 deletion vlmrun/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.19"
__version__ = "0.2.20"