|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 | import json
|
| 5 | +import tempfile |
| 6 | +import contextlib |
5 | 7 | from pathlib import Path
|
6 |
| -from typing import List, Optional, Union |
| 8 | +from typing import List, Optional, Union, Generator |
7 | 9 | from PIL import Image
|
8 | 10 | from loguru import logger
|
9 | 11 |
|
10 | 12 | import time
|
| 13 | +from vlmrun.common.utils import remote_image |
11 | 14 | from vlmrun.common.image import encode_image, _open_image_with_exif
|
12 | 15 | from vlmrun.client.base_requestor import APIRequestor
|
13 | 16 | from vlmrun.types.abstract import VLMRunProtocol
|
|
25 | 28 | from cachetools.keys import hashkey
|
26 | 29 |
|
27 | 30 |
|
@contextlib.contextmanager
def image_path_ctx(
    image: Image.Image | None = None,
    url: str | None = None,
) -> Generator[Path, None, None]:
    """Materialize an image as a temporary JPEG file for the duration of a ``with`` block.

    Exactly one of ``image`` or ``url`` must be supplied. When ``url`` is
    given, the image is fetched with ``remote_image`` first. The image is
    written as a JPEG (quality 98) to a uniquely named ``.jpg`` temporary
    file, whose path is yielded. The file is removed when the block exits,
    whether normally or through an exception.

    Args:
        image: PIL Image object to write out.
        url: URL of the image to download and write out.

    Yields:
        Path: Location of the temporary JPEG file. Only valid inside the
        ``with`` block.

    Raises:
        ValueError: If neither or both of ``image`` and ``url`` are provided.
    """
    has_image = image is not None
    if not url and not has_image:
        raise ValueError("Either `image` or `url` must be provided")
    if url and has_image:
        raise ValueError("Cannot provide both `image` and `url`")

    # Download the image from the URL if provided
    if url:
        image = remote_image(url)

    # The JPEG encoder rejects modes such as RGBA, LA and P; convert those to
    # RGB so saving does not fail on images with alpha or a palette.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    # Reserve a unique path but close the handle straight away: a file that is
    # still held open by NamedTemporaryFile cannot be reopened by name on
    # Windows, which would break both the save below and the consumer's read.
    handle = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    temp_path = Path(handle.name)
    handle.close()
    try:
        image.save(str(temp_path), format="JPEG", quality=98)
        yield temp_path
    finally:
        # Best-effort cleanup; the consumer may already have removed the file.
        with contextlib.suppress(FileNotFoundError):
            temp_path.unlink()
| 58 | + |
| 59 | + |
28 | 60 | @cachetools.cached(
|
29 | 61 | cache=cachetools.TTLCache(maxsize=100, ttl=3600),
|
30 | 62 | key=lambda _client, domain, config: hashkey(
|
@@ -292,6 +324,22 @@ def generate(
|
292 | 324 | raise ValueError("Either `images` or `urls` must be provided")
|
293 | 325 | if images and urls:
|
294 | 326 | raise ValueError("Only one of `images` or `urls` can be provided")
|
| 327 | + if batch and len(images) > 1: |
| 328 | + raise ValueError("Batch mode only supports one image") |
| 329 | + |
| 330 | + if batch: |
| 331 | + assert len(images) == 1, "Batch mode only supports one image" |
| 332 | + with image_path_ctx(image=images[0]) as image_path: |
| 333 | + return self._client.document.generate( |
| 334 | + file=image_path, |
| 335 | + model=model, |
| 336 | + domain=domain, |
| 337 | + batch=batch, |
| 338 | + config=config, |
| 339 | + metadata=metadata, |
| 340 | + callback_url=callback_url, |
| 341 | + autocast=autocast, |
| 342 | + ) |
295 | 343 |
|
296 | 344 | if images:
|
297 | 345 | # Check if all images are of the same type
|
|
0 commit comments