first commit

This commit is contained in:
Hermes Agent
2026-05-10 13:52:46 +08:00
commit ccc63d1e70
4583 changed files with 584341 additions and 0 deletions

View File

@@ -0,0 +1 @@
# sn-image-base scripts

View File

@@ -0,0 +1,313 @@
from __future__ import annotations
import contextlib
import os
import warnings
from pathlib import Path
from typing import Annotated, Literal, get_args, get_origin, get_type_hints
from urllib.parse import urlparse
SCRIPT_DIR = Path(__file__).absolute().parent
# "skills" directory that contains "sn-*" skills (e.g. "sn-image-base", "sn-infographic", etc.)
SKILLS_DIR = SCRIPT_DIR.parents[1]
def prepare_env() -> None:
    """Load ``.env`` files into ``os.environ`` using python-dotenv, if available.

    Sources, from highest to lowest priority:

    1. ``.env`` inside the agent's config directory: ``~/.openclaw/.env`` when
       ``OPENCLAW_SHELL`` is present in the environment, ``~/.hermes/.env``
       otherwise.
    2. The ``.env`` file found by ``load_dotenv()`` when called without a path.
    3. Variables that were already set in the process environment.

    The files are loaded from lowest to highest priority, each time with
    ``override=True``, so a later load replaces values from an earlier one.
    If python-dotenv cannot be imported a warning is issued and nothing is
    loaded.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        warnings.warn("python-dotenv is not installed, `.env` files will be ignored", stacklevel=2)
        return

    # Priority 2: overrides the plain process environment (priority 3).
    # NOTE(review): without an explicit path python-dotenv locates the file via
    # find_dotenv(), which presumably searches upward from the calling file's
    # directory rather than the current working directory -- confirm that this
    # matches the intent.
    load_dotenv(override=True)

    # Priority 1: the agent-specific file is loaded last so that it wins.
    # OPENCLAW_SHELL is checked after the first load on purpose: it may have
    # been defined by that file.
    config_home = "~/.openclaw" if "OPENCLAW_SHELL" in os.environ else "~/.hermes"
    agent_env_file = Path(config_home).expanduser() / ".env"
    if agent_env_file.exists():
        load_dotenv(agent_env_file, override=True)


prepare_env()
class Field:
    """Annotation marker that binds a config attribute to env var names.

    The names are candidates in priority order: the first one that is present
    in ``os.environ`` supplies the value (even if that value is empty).
    ``required`` and ``secret`` are plain flags stored for the consumers of
    the marker.
    """

    __slots__ = ("env_names", "required", "secret")

    def __init__(self, *env_names: str, required: bool = False, secret: bool = False) -> None:
        # No names at all is stored as None rather than as an empty tuple.
        self.env_names: tuple[str, ...] | None = tuple(env_names) or None
        self.required = required
        self.secret = secret

    def resolve(self, target_type: type | None = None) -> str | int | float | None:
        """Look up the first env var that is set and convert its value.

        Args:
            target_type: ``int`` or ``float`` to get a converted number; any
                other value (``str``, a ``Literal``, ``None``, ...) yields the
                raw string.

        Returns:
            The (possibly converted) value, or None when the marker has no
            names or none of them is present in the environment.

        Raises:
            ValueError: If a numeric conversion of the raw string fails.
        """
        for name in self.env_names or ():
            raw = os.environ.get(name)
            if raw is None:
                continue
            if target_type is int or target_type is float:
                return target_type(raw)
            # Every other annotation keeps the text exactly as found.
            return raw
        return None
class Configs:
    """Central registry of env var names and built-in defaults.

    Attributes annotated with ``Annotated[<type>, Field(...)]`` are resolved in
    ``__init__``: the env vars named by the ``Field`` marker are tried in
    order; if none is set, the class-level default is kept.
    """

    # global defaults shared by all SN capabilities.
    SN_API_KEY: Annotated[str, Field("SN_API_KEY", secret=True)] = ""
    SN_BASE_URL: Annotated[str, Field("SN_BASE_URL")] = ""
    # image-generate
    # Each image-generation setting first tries its specific env var and then
    # the shared SN_* one.
    SN_IMAGE_GEN_API_KEY: Annotated[
        str, Field("SN_IMAGE_GEN_API_KEY", "SN_API_KEY", required=True, secret=True)
    ] = ""
    SN_IMAGE_GEN_BASE_URL: Annotated[
        str, Field("SN_IMAGE_GEN_BASE_URL", "SN_BASE_URL", required=True)
    ] = "https://token.sensenova.cn/v1"
    SN_IMAGE_GEN_MODEL_TYPE: Annotated[
        Literal["sensenova", "nano-banana", "openai-image"], Field("SN_IMAGE_GEN_MODEL_TYPE")
    ] = "sensenova"
    SN_IMAGE_GEN_MODEL: Annotated[str, Field("SN_IMAGE_GEN_MODEL")] = "sensenova-u1-fast"
    # chat runtime shared by text and vision commands; command-specific
    # SN_TEXT_* / SN_VISION_* values override these defaults.
    SN_CHAT_API_KEY: Annotated[str, Field("SN_CHAT_API_KEY", "SN_API_KEY", secret=True)] = ""
    SN_CHAT_BASE_URL: Annotated[str, Field("SN_CHAT_BASE_URL", "SN_BASE_URL")] = (
        "https://token.sensenova.cn/v1"
    )
    SN_CHAT_TYPE: Annotated[
        Literal["anthropic-messages", "openai-completions"], Field("SN_CHAT_TYPE")
    ] = "openai-completions"
    SN_CHAT_MODEL: Annotated[str, Field("SN_CHAT_MODEL")] = "sensenova-6.7-flash-lite"
    # text runtime: the base URL and type default to "" here;
    # validate_configs() accepts the SN_CHAT_* value as a substitute for them.
    SN_TEXT_API_KEY: Annotated[
        str, Field("SN_TEXT_API_KEY", "SN_CHAT_API_KEY", "SN_API_KEY", secret=True)
    ] = ""
    SN_TEXT_BASE_URL: Annotated[
        str, Field("SN_TEXT_BASE_URL", "SN_CHAT_BASE_URL", "SN_BASE_URL")
    ] = ""
    SN_TEXT_TYPE: Annotated[
        Literal["anthropic-messages", "openai-completions"],
        Field("SN_TEXT_TYPE", "SN_CHAT_TYPE"),
    ] = ""
    SN_TEXT_MODEL: Annotated[str, Field("SN_TEXT_MODEL", "SN_CHAT_MODEL")] = (
        "sensenova-6.7-flash-lite"
    )
    # vision runtime: same layout as the text runtime above.
    SN_VISION_API_KEY: Annotated[
        str, Field("SN_VISION_API_KEY", "SN_CHAT_API_KEY", "SN_API_KEY", secret=True)
    ] = ""
    SN_VISION_BASE_URL: Annotated[
        str, Field("SN_VISION_BASE_URL", "SN_CHAT_BASE_URL", "SN_BASE_URL")
    ] = ""
    SN_VISION_TYPE: Annotated[
        Literal["anthropic-messages", "openai-completions"],
        Field("SN_VISION_TYPE", "SN_CHAT_TYPE"),
    ] = ""
    SN_VISION_MODEL: Annotated[str, Field("SN_VISION_MODEL", "SN_CHAT_MODEL")] = (
        "sensenova-6.7-flash-lite"
    )
def __init__(self) -> None:
    """Replace class-level defaults with values found in the environment.

    Every annotated attribute of the class is inspected. For those carrying
    a ``Field`` marker the marker is asked to resolve its env vars, converted
    according to the annotated base type; a non-None result is stored on the
    instance, otherwise the class default stays in effect.
    """
    annotations = get_type_hints(type(self), include_extras=True)
    for attr_name, annotation in annotations.items():
        hint_args = get_args(annotation)
        marker = next((arg for arg in hint_args if isinstance(arg, Field)), None)
        if marker is None:
            continue
        # For Annotated[T, ...] the first argument is the base type T.
        if get_origin(annotation) is Annotated:
            base_type = hint_args[0]
        else:
            base_type = annotation
        resolved = marker.resolve(base_type)
        if resolved is not None:
            setattr(self, attr_name, resolved)
def to_string(self, mask_secrets: bool = True) -> str:
    """Render all annotated attributes as ``NAME: value`` lines.

    Args:
        mask_secrets: When True, non-empty values of attributes whose
            ``Field`` marker has ``secret=True`` are partially replaced by
            asterisks: longer than 10 characters keeps the first 6 and the
            last 4, 5 to 10 characters keeps the first 4, shorter values are
            fully masked.

    Returns:
        One line per annotated attribute, joined with newlines.
    """
    lines = []
    for attr_name, annotation in get_type_hints(type(self), include_extras=True).items():
        marker = next((arg for arg in get_args(annotation) if isinstance(arg, Field)), None)
        text = str(getattr(self, attr_name, None))
        if mask_secrets and text and marker is not None and marker.secret:
            size = len(text)
            if size > 10:
                text = text[:6] + "*" * (size - 10) + text[-4:]
            elif size > 4:
                text = text[:4] + "*" * (size - 4)
            else:
                text = "*" * size
        lines.append(f"{attr_name}: {text}")
    return "\n".join(lines)
def validate_configs(self) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Check the resolved configuration and report problems.

    Returns:
        A pair ``(errors, warnings)``. Both are lists of
        ``(field_name, message)`` tuples. Errors cover required fields that
        are empty, a missing image-generation model and malformed base URLs;
        warnings cover text/vision runtime settings for which neither the
        specific field nor its listed fallback has a value.
    """
    env_names_by_field: dict[str, tuple[str, ...] | str] = {}
    errors: list[tuple[str, str]] = []

    for attr_name, annotation in get_type_hints(type(self), include_extras=True).items():
        marker = next((arg for arg in get_args(annotation) if isinstance(arg, Field)), None)
        if marker is None:
            continue
        names = marker.env_names
        if names:
            # A single name is stored as a bare string, several as the tuple.
            env_names_by_field[attr_name] = names if len(names) > 1 else names[0]
        if getattr(self, attr_name, None) or not marker.required:
            continue
        if attr_name == "SN_IMAGE_GEN_API_KEY":
            problem = (
                "Image generation API key is not set; configure SN_API_KEY, "
                "or configure SN_IMAGE_GEN_API_KEY only for an image-generation-specific override"
            )
        else:
            problem = f"Field '{attr_name}' is required but not set; try setting the environment variable(s) {marker.env_names}"
        errors.append((attr_name, problem))

    # Check fields combination rules:
    if not self.SN_IMAGE_GEN_MODEL:
        errors.append((
            "SN_IMAGE_GEN_MODEL",
            f"SN_IMAGE_GEN_MODEL is required when SN_IMAGE_GEN_MODEL_TYPE is {self.SN_IMAGE_GEN_MODEL_TYPE!r}",
        ))

    def describe(key: str) -> str:
        """Return the env var name(s) known for ``key``, or ``key`` itself."""
        known = env_names_by_field.get(key)
        if isinstance(known, tuple):
            return ", ".join(known)
        return str(env_names_by_field.get(key, key))

    notices: list[tuple[str, str]] = []
    # For each runtime setting: the fields that may supply it. The first
    # entry is the one named in the warning.
    runtime_checks = {
        "text": {
            "api_key": ("SN_TEXT_API_KEY",),
            "base_url": ("SN_TEXT_BASE_URL", "SN_CHAT_BASE_URL"),
            "model": ("SN_TEXT_MODEL",),
            "type": ("SN_TEXT_TYPE", "SN_CHAT_TYPE"),
        },
        "vision": {
            "api_key": ("SN_VISION_API_KEY",),
            "base_url": ("SN_VISION_BASE_URL", "SN_CHAT_BASE_URL"),
            "model": ("SN_VISION_MODEL",),
            "type": ("SN_VISION_TYPE", "SN_CHAT_TYPE"),
        },
    }
    for runtime, checks in runtime_checks.items():
        for field_kind, keys in checks.items():
            if not any(getattr(self, key) for key in keys):
                env_help = " / ".join(describe(key) for key in keys)
                notices.append((
                    keys[0],
                    f"{keys[0]} is not set; {runtime} {field_kind} may be unavailable. Try setting: {env_help}",
                ))

    # check urls (empty values are skipped; they are handled above)
    url_fields = (
        "SN_CHAT_BASE_URL",
        "SN_TEXT_BASE_URL",
        "SN_VISION_BASE_URL",
        "SN_BASE_URL",
        "SN_IMAGE_GEN_BASE_URL",
    )
    for key in url_fields:
        url = getattr(self, key)
        if url and not is_valid_base_url(url):
            errors.append((key, f"{key} is not a valid base URL: {url}"))

    return errors, notices
def get_annotated_field(self, field_name: str) -> Field | None:
    """Return the ``Field`` marker attached to an attribute's annotation.

    Args:
        field_name: Name of the configuration attribute.

    Returns:
        The first ``Field`` instance found among the annotation's arguments,
        or None when the attribute is not annotated or carries no marker.
    """
    try:
        annotation = get_type_hints(type(self), include_extras=True)[field_name]
    except KeyError:
        return None
    for arg in get_args(annotation):
        if isinstance(arg, Field):
            return arg
    return None
def get_env_var_help(self, field_name: str) -> str:
    """Describe which environment variables control a configuration field.

    The text also shows the field's current value. For fields whose ``Field``
    marker has ``secret=True`` the value is masked with the same scheme as
    ``to_string`` so that help and error messages never contain a full secret.

    Args:
        field_name: The name of the configuration field (e.g., "SN_CHAT_API_KEY").

    Returns:
        A help string naming the environment variable(s) for the field, or an
        explanatory message if the field does not exist, has no ``Field``
        marker, or its marker lists no environment variables.
    """
    if not hasattr(type(self), field_name):
        return f"Field '{field_name}' does not exist in Configs."
    field_inst = self.get_annotated_field(field_name)
    # A marker created without names stores env_names=None; such a field is
    # just as unconfigurable as one without a marker.
    env_names = list(field_inst.env_names or ()) if field_inst is not None else []
    if not env_names:
        return f"Field '{field_name}' is not configurable via environment variables."

    current_value = getattr(self, field_name)
    if field_inst.secret and isinstance(current_value, str) and current_value:
        size = len(current_value)
        if size > 10:
            current_value = f"{current_value[:6]}{'*' * (size - 10)}{current_value[-4:]}"
        elif size > 4:
            current_value = f"{current_value[:4]}{'*' * (size - 4)}"
        else:
            current_value = "*" * size

    if len(env_names) == 1:
        return (
            f"To set '{field_name}', configure the environment variable: {env_names[0]}\n"
            f"Current value: {current_value!r}"
        )
    env_list = ", ".join(env_names)
    return (
        f"To set '{field_name}', configure one of these environment variables: {env_list}\n"
        f"They are tried in order; the first set value is used.\n"
        f"Current value: {current_value!r}"
    )
def is_valid_base_url(url: str) -> bool:
    """Tell whether ``url`` has both a scheme and a network location.

    Args:
        url: The candidate base URL.

    Returns:
        True when parsing succeeds and yields a non-empty scheme and netloc;
        False otherwise, including when the parser rejects the string with a
        ValueError.
    """
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    return bool(parts.scheme and parts.netloc)
def reload_env() -> None:
    """Re-read the ``.env`` files and rebuild the module-level ``global_configs``.

    Calls ``prepare_env()`` and then rebinds ``global_configs`` to a fresh
    ``Configs`` instance. Any exception raised while doing so (including one
    from the confirmation ``print``) is turned into a warning instead of
    propagating.
    """
    # NOTE(review): this rebinds the module attribute only. Code that obtained
    # the object earlier via `from <this module> import global_configs`
    # presumably keeps referring to the previous instance -- confirm how
    # callers pick up the reloaded values.
    global global_configs
    prepare_env()
    try:
        global_configs = Configs()
        print("✅ Reloaded global_configs")
    except Exception as e:
        warnings.warn(f"Failed to reload global_configs: {e}", stacklevel=2)


# Shared configuration instance, built once when the module is imported.
global_configs = Configs()

View File

@@ -0,0 +1,39 @@
"""Shared exceptions for sn-image-base."""
from __future__ import annotations
class U1BaseError(Exception):
    """Root of the sn-image-base exception hierarchy.

    Subclasses change the fallback text by overriding ``DEFAULT_MESSAGE``.
    """

    DEFAULT_MESSAGE = "An error occurred in the sn-image-base skill."

    def __init__(self, message: str | None = None) -> None:
        """Create the error with ``message``, or ``DEFAULT_MESSAGE`` if it is None."""
        # Only None selects the fallback; an explicit empty string is kept.
        super().__init__(self.DEFAULT_MESSAGE if message is None else message)
class BadConfigurationError(U1BaseError):
    """Raised when the configuration is invalid.

    Also serves as the parent class of the more specific
    ``MissingApiKeyError`` and ``InvalidBaseUrlError``.
    """

    # Used by U1BaseError.__init__ when no explicit message is given.
    DEFAULT_MESSAGE = "The configuration is invalid."
class MissingApiKeyError(BadConfigurationError):
    """Raised when API key is not provided via CLI argument or environment variable."""

    # Fallback text used when the error is raised without an explicit message;
    # it lists the ways an API key can be supplied.
    DEFAULT_MESSAGE = (
        "API key is required but was not provided. "
        "Set SN_API_KEY, or set SN_IMAGE_GEN_API_KEY only for an image-generation-specific "
        "override, or pass --api-key explicitly."
    )
class InvalidBaseUrlError(BadConfigurationError):
    """Raised when the base URL is missing or is not a valid URL.

    The default message only describes the missing case; code that rejects a
    malformed URL passes its own message.
    """

    # Fallback text used when the error is raised without an explicit message.
    DEFAULT_MESSAGE = (
        "Base URL is required but was not provided. "
        "Set SN_IMAGE_GEN_BASE_URL or SN_BASE_URL, or pass --base-url explicitly."
    )

View File

@@ -0,0 +1,9 @@
from .nano_banana import NanoBananaText2ImageClient
from .openai_image import OpenAIImageGenerationClient
from .sensenova import SensenovaText2ImageClient
__all__ = [
"NanoBananaText2ImageClient",
"OpenAIImageGenerationClient",
"SensenovaText2ImageClient",
]

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from pathlib import Path
def ensure_output_path(path: Path) -> Path:
    """Create any missing ancestor directories of ``path``.

    Only the directory that will contain the file is created (together with
    its own missing parents); the file itself is not touched.

    Args:
        path (Path):
            Location of a file that is about to be written.

    Returns:
        Path:
            The same ``path`` object that was passed in.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    return path

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import typing
from abc import ABC, abstractmethod
from typing import Any
from sn_image_base.utils.error_utils import U1HttpResponseParseError
from sn_image_base.utils.httpx_client import (
create_async_httpx_client,
httpx_response_raise_for_status_code,
)
if typing.TYPE_CHECKING:
import httpx
DEFAULT_POLL_INTERVAL = 5.0
DEFAULT_HTTP_REQUEST_TIMEOUT = 300.0
DEFAULT_MAX_CONNECTIONS = 100
class T2IBaseClient(ABC):
    """Abstract base for asynchronous text-to-image HTTP clients.

    Stores the credentials and connection settings, creates the shared
    ``httpx.AsyncClient`` on first use and offers a default JSON response
    parser. Concrete clients implement ``generate``, ``get_api_url``,
    ``build_payload`` and ``headers``.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        *,
        model: str | None = None,
        max_connections: int = DEFAULT_MAX_CONNECTIONS,
        timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
        ssl_verify: bool = True,
        **kwargs: Any,
    ) -> None:
        """Store the settings; no network resources are created here.

        Args:
            api_key: API key exposed through the ``api_key`` property.
            base_url: Base URL exposed through the ``base_url`` property.
            model: Default model name, kept as ``self.model``.
            max_connections: Passed to the HTTP client factory.
            timeout: Passed to the HTTP client factory.
            ssl_verify: Passed to the HTTP client factory as ``verify``.
            **kwargs: Accepted for signature compatibility and not used.
        """
        # Credentials and model.
        self._api_key = api_key
        self._base_url = base_url
        self.model = model
        # Connection settings, consumed when the client is first needed.
        self._max_connections = max_connections
        self._timeout = timeout
        self._ssl_verify = ssl_verify
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on the first call."""
        client = self._client
        if client is None:
            client = create_async_httpx_client(
                self.headers,
                timeout=self._timeout,
                max_connections=self._max_connections,
                verify=self._ssl_verify,
            )
            self._client = client
        return client

    async def aclose(self) -> None:
        """Close the HTTP client if one was created and forget it."""
        client = self._client
        if client is None:
            return
        await client.aclose()
        self._client = None

    @property
    def api_key(self) -> str | None:
        """The API key given at construction time."""
        return self._api_key

    @property
    def base_url(self) -> str | None:
        """The base URL given at construction time."""
        return self._base_url

    @abstractmethod
    async def generate(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def get_api_url(self, *args: Any, **kwargs: Any) -> str: ...

    @abstractmethod
    def build_payload(self, *args: Any, **kwargs: Any) -> Any: ...

    @property
    @abstractmethod
    def headers(self) -> dict[str, str]: ...

    def parse_response(self, response: httpx.Response) -> dict:
        """Check the status code of ``response`` and decode its JSON body.

        Raises:
            U1HttpResponseParseError: If the body cannot be decoded as JSON.
            Whatever ``httpx_response_raise_for_status_code`` raises for an
            unsuccessful status code.
        """
        httpx_response_raise_for_status_code(response)
        try:
            return response.json()
        except ValueError as exc:
            raise U1HttpResponseParseError(
                detail=f"Failed to parse HTTP response. {response.request.url}. Response content: {response.content}",
                code=response.status_code,
            ) from exc

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import base64
import time
from pathlib import Path
from typing import Any, Literal
import httpx
from typing_extensions import override
from sn_image_base.configs import global_configs, is_valid_base_url
from sn_image_base.utils.error_utils import U1HttpErrorBase
from .core import ensure_output_path
from .core.client_base import (
DEFAULT_HTTP_REQUEST_TIMEOUT,
DEFAULT_MAX_CONNECTIONS,
T2IBaseClient,
)
DEFAULT_MODEL_SIZE: Literal["1K", "2K", "4K"] = "2K"
DEFAULT_ASPECT_RATIO = "16:9"
DEFAULT_POLL_INTERVAL = 5.0
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
class NanoBananaText2ImageClient(T2IBaseClient):
    """Async client for Google Nano Banana API."""

    # URL path template that get_api_url() appends to the base URL; the
    # `{model}` placeholder is filled in with str.format().
    DEFAULT_API_PATH = "/v1beta/models/{model}:generateContent"
def __init__(
    self,
    api_key: str,
    base_url: str | None = None,
    *,
    model: str | None = None,
    max_connections: int = DEFAULT_MAX_CONNECTIONS,
    timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
    ssl_verify: bool = True,
    **kwargs: Any,
) -> None:
    """Initialize the NanoBananaText2ImageClient.

    All arguments are handed to ``T2IBaseClient.__init__`` unchanged; nothing
    is validated at construction time.

    Args:
        api_key (str):
            API key for authentication. When empty, the ``api_key`` property
            falls back to ``global_configs.SN_IMAGE_GEN_API_KEY``.
        base_url (str | None, optional):
            API base URL. When None or empty, the ``base_url`` property falls
            back to ``global_configs.SN_IMAGE_GEN_BASE_URL``.
        model (str | None, optional):
            Default model name, used by ``generate`` and ``get_api_url`` when
            no per-call model is given.
        max_connections (int, optional):
            Maximum number of connections. Defaults to DEFAULT_MAX_CONNECTIONS.
        timeout (float, optional):
            Timeout in seconds for HTTP requests.
            Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
        ssl_verify (bool, optional):
            If True, enable TLS verification. Defaults to True.
        **kwargs:
            Forwarded to the base class.
    """
    super().__init__(
        api_key=api_key,
        base_url=base_url,
        model=model,
        max_connections=max_connections,
        timeout=timeout,
        ssl_verify=ssl_verify,
        **kwargs,
    )
@override
async def generate(
    self,
    prompt: str,
    negative_prompt: str = "",
    *,
    model: str | None = None,
    image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
    aspect_ratio: str = DEFAULT_ASPECT_RATIO,
    output_path: Path | None = None,
    **kwargs: Any,
) -> dict:
    """Generate an image from a text prompt and save it to disk.

    Args:
        prompt (str):
            Text prompt for image generation.
        negative_prompt (str, optional):
            Negative prompt, passed on to ``build_payload``. Defaults to "".
        model (str | None, optional):
            Model name override. Falls back to the client's default model and
            then to ``global_configs.SN_IMAGE_GEN_MODEL``.
        image_size (str, optional):
            Image size preset ("1K", "2K", "4K"); upper-cased before use.
            Defaults to DEFAULT_MODEL_SIZE.
        aspect_ratio (str, optional):
            Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
        output_path (Path | None, optional):
            Output path for the generated image. When None, a timestamped
            file name inside OUTPUT_DIR is used. The suffix is replaced
            according to the MIME type of the returned image.
        **kwargs:
            Additional arguments reserved for backend compatibility; they are
            not used by this method.

    Returns:
        dict:
            On success: ``status="ok"``, ``output`` (saved path) and
            ``message``. On failure: ``status="failed"``, ``error`` and
            ``message``.

    Raises:
        ValueError: If no model name, API key or valid base URL can be
            determined before the request is sent.
    """
    # Same fallback chain as the other image clients; without it a missing
    # model would be formatted into the URL as the text "None".
    model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
    if not model:
        raise ValueError(
            f"Model is not set. {global_configs.get_env_var_help('SN_IMAGE_GEN_MODEL')}"
        )
    # Normalize image_size to uppercase for NanoBanana API
    image_size = image_size.upper()  # type: ignore[assignment]
    payload = self.build_payload(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image_size=image_size,
        aspect_ratio=aspect_ratio,
    )
    headers = self.headers
    api_url = self.get_api_url(model)
    if output_path is None:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
    output_path = ensure_output_path(output_path)
    client = await self._get_client()

    # One try block covers the request, the parsing and the file write, so
    # that transport errors raised by the POST are reported as a failed
    # result just like decoding and file-system errors.
    try:
        create_response = await client.post(
            api_url,
            json=payload,
            headers=headers,
        )
        data = self.parse_response(create_response)
        images = data["images"]
        if not images:
            reasons = ", ".join(str(r) for r in data.get("finish_reasons") or [])
            message = "The response contained no inline image data"
            if reasons:
                message += f" (finish reasons: {reasons})"
            return {
                "status": "failed",
                "error": "No image generated from the model",
                "message": message,
            }
        image, mime_type = images[-1]
        image_bytes = base64.b64decode(image)
        suffix = mime_type_to_suffix(mime_type)
        saved_path = output_path.with_suffix(suffix)
        saved_path.write_bytes(image_bytes)
        return {
            "status": "ok",
            "output": str(saved_path),
            "message": "Image generated successfully",
        }
    except U1HttpErrorBase as exc:
        details = exc.detail or ""
        # 404 usually points at the base URL, 401 at the API key: add a hint
        # naming the environment variables that control the suspect field.
        field_name = None
        if exc.code == 404:
            field_name = "SN_IMAGE_GEN_BASE_URL"
        elif exc.code == 401:
            field_name = "SN_IMAGE_GEN_API_KEY"
        if field_name is not None:
            field_hint = global_configs.get_annotated_field(field_name)
            env_names = list(field_hint.env_names or ()) if field_hint is not None else []
            if len(env_names) == 1:
                details += f"\nIs the environment variable `{env_names[0]}` set correctly?"
            elif env_names:
                env_names_str = ", ".join([f"`{n}`" for n in env_names])
                details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
        return {
            "status": "failed",
            "error": f"HTTP {exc.code}: {exc.message}",
            "message": details,
        }
    except httpx.HTTPStatusError as exc:
        return {
            "status": "failed",
            "error": f"HTTP {exc.response.status_code}",
            "message": f"http error: {exc.response.status_code} {exc.response.text}",
        }
    except (httpx.HTTPError, OSError, ValueError) as exc:
        return {
            "status": "failed",
            "error": type(exc).__name__,
            "message": f"request error: {exc}",
        }
@property
@override
def api_key(self) -> str:
    """API key to authenticate with.

    Uses the key given to the constructor and otherwise
    ``global_configs.SN_IMAGE_GEN_API_KEY``.

    Raises:
        ValueError: If neither source provides a non-empty key.
    """
    resolved = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
    if resolved:
        return resolved
    help_text = global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
    raise ValueError(f"API key is missing: {help_text}")
@property
@override
def base_url(self) -> str:
    """Base URL of the API.

    Uses the URL given to the constructor and otherwise
    ``global_configs.SN_IMAGE_GEN_BASE_URL``.

    Raises:
        ValueError: If no URL is available, or if the URL fails
            ``is_valid_base_url``.
    """
    candidate = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
    if not candidate:
        help_text = global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
        raise ValueError(f"Base URL is missing: {help_text}")
    if is_valid_base_url(candidate):
        return candidate
    raise ValueError(
        f"Base URL is not a valid base URL: {candidate}. "
        f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
    )
@override
def get_api_url(self, model: str | None = None) -> str:
    """Build the full request URL for ``model``.

    Args:
        model: Model name. Falls back to the client's default model and then
            to ``global_configs.SN_IMAGE_GEN_MODEL``.

    Returns:
        The base URL joined with ``DEFAULT_API_PATH`` by exactly one slash,
        with the model name substituted into the path.

    Raises:
        ValueError: If no model name can be determined, or if the
            ``base_url`` property rejects the configured base URL.
    """
    model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
    if not model:
        # Formatting None into the path would yield ".../models/None:..."
        # and surface later as a confusing 404.
        raise ValueError(
            f"Model is not set. {global_configs.get_env_var_help('SN_IMAGE_GEN_MODEL')}"
        )
    path = self.DEFAULT_API_PATH.format(model=model).lstrip("/")
    return f"{self.base_url.rstrip('/')}/{path}"
@override
def build_payload(
    self,
    prompt: str,
    negative_prompt: str = "",
    *,
    image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
    aspect_ratio: str = DEFAULT_ASPECT_RATIO,
    max_output_tokens: int = 8192,
    **kwargs: Any,
) -> dict:
    """Assemble the JSON body of a generateContent request.

    Args:
        prompt: Text prompt, sent as the first content part.
        negative_prompt: Accepted for interface compatibility; it is not
            included in the payload.
        image_size: Value for ``generationConfig.imageConfig.imageSize``.
        aspect_ratio: Value for ``generationConfig.imageConfig.aspectRatio``.
        max_output_tokens: Value for ``generationConfig.maxOutputTokens``.
        **kwargs: If both ``image_b64`` and ``image_mime_type`` are given and
            non-empty, the image is attached as a second, inline content part.

    Returns:
        The request payload.

    Raises:
        ValueError: If an inline image has a MIME type other than
            image/jpeg or image/png.
    """
    parts: list[dict] = [{"text": prompt}]

    image_b64 = kwargs.get("image_b64")
    image_mime_type = kwargs.get("image_mime_type") if image_b64 else None
    if image_b64 and image_mime_type:
        if image_mime_type not in ("image/jpeg", "image/png"):
            raise ValueError(
                f"Unsupported image MIME type: {image_mime_type}. "
                "Supported types: image/jpeg, image/png"
            )
        inline_image = {"mime_type": image_mime_type, "data": image_b64}
        parts.append({"inline_data": inline_image})

    # Every listed harm category is sent with its threshold set to "OFF".
    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    generation_config = {
        "imageConfig": {"aspectRatio": aspect_ratio, "imageSize": image_size},
        "maxOutputTokens": max_output_tokens,
    }
    return {
        "contents": [{"role": "USER", "parts": parts}],
        "generationConfig": generation_config,
        "safetySettings": [
            {"category": category, "threshold": "OFF"} for category in harm_categories
        ],
    }
@property
@override
def headers(self) -> dict[str, str]:
    """HTTP headers for API requests: the API key header and a JSON content type.

    Reading this property resolves ``api_key`` and therefore raises if no
    key is configured.
    """
    request_headers: dict[str, str] = {"x-goog-api-key": self.api_key}
    request_headers["Content-Type"] = "application/json"
    return request_headers
@override
def parse_response(self, response: httpx.Response) -> dict:
    """Extract inline images and finish reasons from a generateContent response.

    The status check and JSON decoding are done by the base class.

    Args:
        response: The HTTP response of the generation request.

    Returns:
        A dict with ``images``, a list of ``(base64_data, mime_type)`` tuples
        in response order, and ``finish_reasons``, the finish reasons reported
        for the candidates.
    """
    raw_data = super().parse_response(response)
    images: list[tuple[str, str]] = []
    finish_reasons: list[str] = []
    candidates: list[dict] = raw_data.get("candidates") or []
    for c in candidates:
        content: dict[str, Any] = c.get("content") or {}
        parts: list[dict[str, Any]] = content.get("parts") or []
        # finishReason is a field of the candidate itself; the lookup inside
        # `content` is kept only as a fallback for the previous behaviour.
        if f_reason := (c.get("finishReason") or content.get("finishReason")):
            finish_reasons.append(f_reason)
        for p in parts:
            inline_data: dict[str, Any] = p.get("inlineData") or {}
            data = inline_data.get("data")
            if not data:
                # Text parts and empty inline parts carry no image; passing
                # them on would make the caller's base64 decoding fail.
                continue
            mime_type: str = inline_data.get("mimeType")  # pyright: ignore[reportAssignmentType]
            images.append((data, mime_type))
    return {
        "images": images,
        "finish_reasons": finish_reasons,
    }
def mime_type_to_suffix(mime_type: str) -> str:
    """Map an image MIME type to a file name suffix.

    Args:
        mime_type: MIME type such as "image/jpeg".

    Returns:
        str: ".jpg", ".png" or ".webp" for the matching type; ".png" for
        anything else.
    """
    known_suffixes = (
        ("image/jpeg", ".jpg"),
        ("image/png", ".png"),
        ("image/webp", ".webp"),
    )
    for known_type, suffix in known_suffixes:
        if mime_type == known_type:
            return suffix
    # Unknown types are treated like PNG.
    return ".png"

View File

@@ -0,0 +1,366 @@
from __future__ import annotations
import base64
import math
import re
import time
from pathlib import Path
from typing import Any, Literal
import httpx
from typing_extensions import override
from sn_image_base.configs import global_configs, is_valid_base_url
from sn_image_base.exceptions import BadConfigurationError
from sn_image_base.utils.error_utils import U1HttpErrorBase
from .core import ensure_output_path
from .core.client_base import (
DEFAULT_HTTP_REQUEST_TIMEOUT,
DEFAULT_MAX_CONNECTIONS,
T2IBaseClient,
)
DEFAULT_RESOLUTION: Literal["1K", "2K"] = "2K"
DEFAULT_ASPECT_RATIO = "16:9"
DEFAULT_POLL_INTERVAL = 5.0
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
B64_PARSE_PATTERN = re.compile(r"^data:([a-zA-Z0-9/]+?);base64,([+-/_A-Za-z0-9]+=*)$")
class OpenAIImageGenerationClient(T2IBaseClient):
    """Async client for OpenAI Image Generation API."""

    # URL path that get_api_url() appends to the base URL.
    DEFAULT_API_PATH = "/images/generations"
def __init__(
    self,
    api_key: str,
    base_url: str | None = None,
    *,
    model: str | None = None,
    max_connections: int = DEFAULT_MAX_CONNECTIONS,
    timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
    ssl_verify: bool = True,
    **kwargs: Any,
) -> None:
    """Initialize the OpenAIImageGenerationClient.

    All arguments are handed to ``T2IBaseClient.__init__`` unchanged; nothing
    is validated at construction time.

    Args:
        api_key (str):
            API key for authentication. When empty, the ``api_key`` property
            falls back to ``global_configs.SN_IMAGE_GEN_API_KEY``.
        base_url (str | None, optional):
            API base URL. When None or empty, the ``base_url`` property falls
            back to ``global_configs.SN_IMAGE_GEN_BASE_URL``.
        model (str | None, optional):
            Default model name. When None, ``generate`` falls back to
            ``global_configs.SN_IMAGE_GEN_MODEL``.
        max_connections (int, optional):
            Maximum number of connections. Defaults to DEFAULT_MAX_CONNECTIONS.
        timeout (float, optional):
            Timeout in seconds for HTTP requests.
            Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
        ssl_verify (bool, optional):
            If True, enable TLS verification. Defaults to True.
        **kwargs:
            Forwarded to the base class.
    """
    super().__init__(
        api_key=api_key,
        base_url=base_url,
        model=model,
        max_connections=max_connections,
        timeout=timeout,
        ssl_verify=ssl_verify,
        **kwargs,
    )
@override
async def generate(
    self,
    prompt: str,
    *,
    model: str | None = None,
    image_size: Literal["1K", "2K", "1k", "2k"] | None = None,
    aspect_ratio: str | None = DEFAULT_ASPECT_RATIO,
    output_path: Path | None = None,
    **kwargs: Any,
) -> dict:
    """Generate an image from a text prompt and save it to disk.

    Args:
        prompt (str):
            Text prompt for image generation.
        model (str | None, optional):
            Model name override. Falls back to the client's default model and
            then to ``global_configs.SN_IMAGE_GEN_MODEL``.
        image_size (str, optional):
            Image size preset ("1K", "2K"). Defaults to DEFAULT_RESOLUTION.
        aspect_ratio (str | None, optional):
            Aspect ratio as "W:H" with integer parts (e.g. "16:9", "1:1").
            Defaults to DEFAULT_ASPECT_RATIO. When None, no pixel size is
            computed and ``build_payload`` falls back to "auto".
        output_path (Path | None, optional):
            Output path for the generated image. When None, a timestamped
            file name inside OUTPUT_DIR is used. The suffix is replaced
            according to the MIME type of the returned image.
        **kwargs:
            Additional arguments reserved for backend compatibility; they are
            not used by this method.

    Returns:
        dict:
            On success: ``status="ok"``, ``output`` (saved path) and
            ``message``. On failure: ``status="failed"``, ``error`` and
            ``message``.

    Raises:
        BadConfigurationError: If no model name can be determined.
        ValueError: If ``aspect_ratio`` or ``image_size`` is malformed or out
            of range, or if no API key / valid base URL is configured.
    """
    model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
    if not model:
        raise BadConfigurationError(
            f"Model is not set. {global_configs.get_env_var_help('SN_IMAGE_GEN_MODEL')}"
        )
    image_size = image_size or DEFAULT_RESOLUTION
    if aspect_ratio is None:
        size = None
    else:
        rw, _, rh = aspect_ratio.partition(":")
        try:
            aspect_ratio_val: float = float(int(rw) / int(rh))
        except (ValueError, ZeroDivisionError) as e:
            raise ValueError(f"Invalid aspect ratio: {aspect_ratio}") from e
        size = self._resolve_size(
            resolution=image_size,
            aspect_ratio_val=aspect_ratio_val,
        )
    payload = self.build_payload(
        prompt=prompt,
        model=model,
        size=size,
    )
    headers = self.headers
    api_url = self.get_api_url(model)
    if output_path is None:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
    output_path = ensure_output_path(output_path)
    client = await self._get_client()

    # One try block covers the request, the parsing and the file write, so
    # that transport errors raised by the POST and decoding errors raised by
    # parse_response are reported as a failed result like file-system errors.
    try:
        create_response = await client.post(
            api_url,
            json=payload,
            headers=headers,
        )
        data = self.parse_response(create_response)
        images = data["images"]
        if not images:
            return {
                "status": "failed",
                "error": "No image generated from the model",
                "message": "The response contained no base64 image data",
            }
        image_bytes, mime_type = images[-1]
        suffix = mime_type_to_suffix(mime_type)
        saved_path = output_path.with_suffix(suffix)
        saved_path.write_bytes(image_bytes)
        return {
            "status": "ok",
            "output": str(saved_path),
            "message": "Image generated successfully",
        }
    except U1HttpErrorBase as exc:
        details = exc.detail or ""
        # 404 usually points at the base URL, 401 at the API key: add a hint
        # naming the environment variables that control the suspect field.
        field_name = None
        if exc.code == 404:
            field_name = "SN_IMAGE_GEN_BASE_URL"
        elif exc.code == 401:
            field_name = "SN_IMAGE_GEN_API_KEY"
        if field_name is not None:
            field_hint = global_configs.get_annotated_field(field_name)
            env_names = list(field_hint.env_names or ()) if field_hint is not None else []
            if len(env_names) == 1:
                details += f"\nIs the environment variable `{env_names[0]}` set correctly?"
            elif env_names:
                env_names_str = ", ".join([f"`{n}`" for n in env_names])
                details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
        return {
            "status": "failed",
            "error": f"HTTP {exc.code}: {exc.message}",
            "message": details,
        }
    except httpx.HTTPStatusError as exc:
        return {
            "status": "failed",
            "error": f"HTTP {exc.response.status_code}",
            "message": f"http error: {exc.response.status_code} {exc.response.text}",
        }
    except (httpx.HTTPError, OSError, ValueError) as exc:
        return {
            "status": "failed",
            "error": type(exc).__name__,
            "message": f"request error: {exc}",
        }
@property
@override
def api_key(self) -> str:
    """API key to authenticate with.

    Uses the key given to the constructor and otherwise
    ``global_configs.SN_IMAGE_GEN_API_KEY``.

    Raises:
        ValueError: If neither source provides a non-empty key.
    """
    resolved = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
    if resolved:
        return resolved
    help_text = global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
    raise ValueError(f"API key is missing: {help_text}")
@property
@override
def base_url(self) -> str:
    """Base URL of the API.

    Uses the URL given to the constructor and otherwise
    ``global_configs.SN_IMAGE_GEN_BASE_URL``.

    Raises:
        ValueError: If no URL is available, or if the URL fails
            ``is_valid_base_url``.
    """
    candidate = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
    if not candidate:
        help_text = global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
        raise ValueError(f"Base URL is missing: {help_text}")
    if is_valid_base_url(candidate):
        return candidate
    raise ValueError(
        f"Base URL is not a valid base URL: {candidate}. "
        f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
    )
@override
def get_api_url(self, model: str | None = None) -> str:
    """Build the full request URL.

    Args:
        model: Model name, offered to ``DEFAULT_API_PATH`` as the ``{model}``
            format argument (falling back to the client's default model).

    Returns:
        The base URL joined with the API path by exactly one slash.
    """
    chosen_model = model or self.model
    relative_path = self.DEFAULT_API_PATH.format(model=chosen_model).lstrip("/")
    root = self.base_url.rstrip("/")
    return f"{root}/{relative_path}"
@override
def build_payload(
    self,
    prompt: str,
    model: str,
    *,
    n: int = 1,
    size: str | None = None,
    **kwargs: Any,
) -> dict:
    """Assemble the JSON body of an image generation request.

    Args:
        prompt: Text prompt.
        model: Model name.
        n: Number of images to request.
        size: Pixel size such as "1024x1024"; "auto" is sent when it is
            None or empty.
        **kwargs: Extra payload entries. They are merged last and therefore
            replace the standard entries on a key clash.

    Returns:
        The request payload; ``response_format`` is "b64_json" unless
        overridden through ``kwargs``.

    Example:
        {
            "model": "dall-e-3",
            "prompt": "A cat wearing sunglasses drinking coffee on a street in a cyberpunk city, cel-shaded style",
            "n": 1,
            "size": "1024x1024",
            "response_format": "b64_json",
        }
    """
    payload: dict = {
        "model": model,
        "prompt": prompt,
        "n": n,
        "size": size or "auto",
        "response_format": "b64_json",
    }
    payload.update(kwargs)
    return payload
@property
@override
def headers(self) -> dict[str, str]:
    """HTTP headers for API requests: bearer authorization and a JSON content type.

    Reading this property resolves ``api_key`` and therefore raises if no
    key is configured.
    """
    request_headers: dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
    request_headers["Content-Type"] = "application/json"
    return request_headers
@override
def parse_response(self, response: httpx.Response) -> dict:
    """Decode the base64 images contained in an image generation response.

    The status check and JSON decoding are done by the base class. Each
    entry of ``data`` with a non-empty ``b64_json`` is decoded; the value
    may be bare base64 (treated as PNG) or a ``data:<mime>;base64,...`` URL.

    Returns:
        A dict with ``images``, a list of ``(image_bytes, mime_type)`` tuples
        in response order.

    Raises:
        ValueError: If a data URL does not match the expected pattern or the
            base64 text cannot be decoded.

    Example:
        {
            "data": [{
                "b64_json": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOYA3Q..."
            }],
            "created": 1776789055
            "usage": {
                "input_tokens":773,
                "output_tokens":765,
                "total_tokens":1538,
                "input_tokens_details": {
                    "text_tokens":8,
                    "image_tokens":765
                }
            }
        }
    """
    raw_data = super().parse_response(response)
    images: list[tuple[bytes, str]] = []
    for item in raw_data.get("data") or []:
        encoded: str = item.get("b64_json") or ""
        if not encoded:
            continue

        if encoded.startswith("data:"):
            match = B64_PARSE_PATTERN.match(encoded)
            if match is None:
                raise ValueError(
                    f"Invalid base64 data in response: {encoded[:100]}... (truncated)"
                )
            mime_type, b64_data = match.group(1), match.group(2)
        else:
            # fallback to png
            mime_type, b64_data = "image/png", encoded

        try:
            decoded = base64.b64decode(b64_data)
        except Exception as e:
            raise ValueError(
                f"Failed to decode base64 data in response: {e}. b64_json: {encoded[:100]}... (truncated)"
            ) from e
        images.append((decoded, mime_type))

    return {"images": images}
@classmethod
def _resolve_size(
    cls,
    resolution: str,
    aspect_ratio_val: float | None,
) -> str:
    """Convert (resolution, aspect_ratio) to a pixel size string.

    Args:
        resolution: "1K" or "2K" in any letter case, giving a pixel budget
            of 1024*1024 or 2048*2048.
        aspect_ratio_val: Width divided by height; None or 0 means 1.

    Returns:
        "<width>x<height>", where both values are rounded so that their
        product is close to the pixel budget and their ratio close to the
        requested one.

    Raises:
        ValueError: If the resolution is unknown or the ratio lies outside
            the range [1/3, 3].
    """
    resolution = resolution.upper()
    pixel_budgets = {"1K": 1024**2, "2K": 2048**2}
    max_pixel = pixel_budgets.get(resolution)
    if max_pixel is None:
        raise ValueError(f"Unsupported resolution: {resolution}")

    aspect_ratio_val = aspect_ratio_val or 1
    if aspect_ratio_val < 1 / 3 or aspect_ratio_val > 3:
        raise ValueError(f"Aspect ratio value must be between [1/3, 3], got {aspect_ratio_val}")

    # width * height == max_pixel and width / height == ratio, before rounding.
    width = round(math.sqrt(max_pixel * aspect_ratio_val))
    height = round(math.sqrt(max_pixel / aspect_ratio_val))
    return f"{width}x{height}"
def mime_type_to_suffix(mime_type: str) -> str:
    """Map an image MIME type to a file name suffix.

    Args:
        mime_type: MIME type such as "image/jpeg".

    Returns:
        str: ".jpg", ".png" or ".webp" for the matching type; ".png" for
        anything else.
    """
    known_suffixes = (
        ("image/jpeg", ".jpg"),
        ("image/png", ".png"),
        ("image/webp", ".webp"),
    )
    for known_type, suffix in known_suffixes:
        if mime_type == known_type:
            return suffix
    # Unknown types are treated like PNG.
    return ".png"

View File

@@ -0,0 +1,508 @@
from __future__ import annotations
import os
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import httpx
from PIL import Image
from typing_extensions import override
from sn_image_base.configs import global_configs, is_valid_base_url
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
from sn_image_base.generation.core import ensure_output_path
from sn_image_base.generation.core.client_base import (
DEFAULT_HTTP_REQUEST_TIMEOUT,
DEFAULT_MAX_CONNECTIONS,
T2IBaseClient,
)
from sn_image_base.utils.error_utils import U1HttpErrorBase
if TYPE_CHECKING:
from collections.abc import Sequence
DEFAULT_RESOLUTION: Literal["1K", "2K", "4K"] = "2K"
DEFAULT_ASPECT_RATIO = "16:9"
DEFAULT_POLL_INTERVAL = 5.0
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
IMAGE_GEN_ENDPOINT = "/images/generations"
class SensenovaText2ImageClient(T2IBaseClient):
    """Async client for Sensenova text-to-image API."""

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        *,
        model: str | None = None,
        max_connections: int = DEFAULT_MAX_CONNECTIONS,
        timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
        ssl_verify: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the SensenovaText2ImageClient.

        Args:
            api_key (str):
                API key for authentication. When empty, falls back to
                ``global_configs.SN_IMAGE_GEN_API_KEY``.
            base_url (str | None, optional):
                API base URL. When empty, falls back to
                ``global_configs.SN_IMAGE_GEN_BASE_URL``.
            model (str | None, optional):
                Model name, forwarded to the base class unchanged. When no
                model is set, ``generate`` falls back to
                ``global_configs.SN_IMAGE_GEN_MODEL``.
            max_connections (int, optional):
                Maximum number of connections.
                Defaults to DEFAULT_MAX_CONNECTIONS.
            timeout (float, optional):
                Total timeout in seconds for HTTP requests.
                Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
            ssl_verify (bool, optional):
                If True, enable TLS verification. Defaults to True.
            **kwargs:
                Forwarded to the base class constructor.

        Raises:
            MissingApiKeyError: If no API key is given or configured.
            InvalidBaseUrlError: If the base URL is missing or rejected by
                ``is_valid_base_url``.
        """
        api_key = api_key or global_configs.SN_IMAGE_GEN_API_KEY
        if not api_key:
            raise MissingApiKeyError(
                "API key is missing: {}".format(
                    global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
                )
            )
        base_url = base_url or global_configs.SN_IMAGE_GEN_BASE_URL
        if not base_url:
            raise InvalidBaseUrlError(
                "Base URL is missing: {}".format(
                    global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
                )
            )
        if not is_valid_base_url(base_url):
            raise InvalidBaseUrlError(
                f"Base URL is not a valid base URL: {base_url}. "
                f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )
        super().__init__(
            api_key=api_key,
            base_url=base_url,
            model=model,
            max_connections=max_connections,
            timeout=timeout,
            ssl_verify=ssl_verify,
            **kwargs,
        )

    @override
    async def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        *,
        model: str | None = None,
        image_size: Literal["1K", "2K", "4K"] = DEFAULT_RESOLUTION,
        aspect_ratio: str = DEFAULT_ASPECT_RATIO,
        output_path: Path | None = None,
        **kwargs: Any,
    ) -> dict:
        """Generate an image from text prompt.

        Args:
            prompt (str):
                Text prompt for image generation.
            negative_prompt (str, optional):
                Negative prompt. Accepted for interface compatibility; it is
                not included in the request payload. Defaults to "".
            model (str | None, optional):
                Model name override. Defaults to None.
            image_size (str, optional):
                Image size preset, matched case-insensitively. Only "1K" and
                "2K" are resolvable by ``_resolve_size``; any other preset
                yields a failed result. Defaults to DEFAULT_RESOLUTION.
            aspect_ratio (str, optional):
                Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
            output_path (Path | None, optional):
                Output path for the generated image. When None, a timestamped
                file under OUTPUT_DIR is used. The saved file always gets a
                ".png" suffix.
            **kwargs:
                Additional arguments reserved for backend compatibility (ignored).

        Returns:
            dict:
                On success: ``status="ok"``, ``output`` (saved path) and
                ``message``. On failure: ``status="failed"``, ``error`` and,
                for most failures, ``message``.
        """
        model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
        # `_resolve_size` expects the uppercase preset spelling.
        image_size = image_size.upper()  # type: ignore[assignment]
        output_format = "png"

        # An unsupported preset or malformed aspect ratio is a caller error;
        # report it the same way as every other failure instead of raising.
        try:
            size = self._resolve_size(image_size, aspect_ratio)
        except ValueError as exc:
            return {
                "status": "failed",
                "error": type(exc).__name__,
                "message": f"invalid image size or aspect ratio: {exc}",
            }

        payload = self.build_payload(
            prompt=prompt,
            model=model,
            size=size,
            aspect_ratio=aspect_ratio,
            output_format=output_format,
        )
        headers = self.headers
        api_url = self.get_api_url(model)

        if output_path is None:
            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
        output_path = ensure_output_path(output_path)

        client = await self._get_client()
        try:
            create_response = await client.post(
                api_url,
                json=payload,
                headers=headers,
            )
            data = self.parse_response(create_response)
        except U1HttpErrorBase as exc:
            details = exc.detail or ""
            # 404 and 401 are mapped to the setting most likely to be wrong so
            # the message can name the environment variable(s) to check.
            field_name = None
            if exc.code == 404:
                field_name = "SN_IMAGE_GEN_BASE_URL"
            elif exc.code == 401:
                field_name = "SN_IMAGE_GEN_API_KEY"
            field_hint = (
                global_configs.get_annotated_field(field_name) if field_name is not None else None
            )
            env_names = (
                list(field_hint.env_names)
                if field_hint is not None and field_hint.env_names
                else []
            )
            if len(env_names) == 1:
                details += f"\nIs the environment variable `{env_names[0]}` set correctly?"
            elif env_names:
                env_names_str = ", ".join(f"`{n}`" for n in env_names)
                details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
            return {
                "status": "failed",
                "error": f"HTTP {exc.code}: {exc.message}",
                "message": details,
            }

        try:
            images_urls: list[str] = data["images_urls"]
            if not images_urls:
                return {
                    "status": "failed",
                    "error": "No image generated from the model",
                }
            # Only the last returned URL is downloaded.
            url = images_urls[-1]
            save_path = output_path.with_suffix(f".{output_format}")
            saved_path = await download_image(url, save_path)
            return {
                "status": "ok",
                "output": str(saved_path),
                "message": "Image generated successfully",
            }
        except httpx.HTTPStatusError as exc:
            return {
                "status": "failed",
                "error": f"HTTP {exc.response.status_code}",
                "message": f"http error: {exc.response.status_code} {exc.response.text}",
            }
        except (httpx.HTTPError, OSError, ValueError) as exc:
            return {
                "status": "failed",
                "error": type(exc).__name__,
                "message": f"request error: {exc}",
            }

    @property
    @override
    def api_key(self) -> str:
        """API key: the instance value, else ``global_configs.SN_IMAGE_GEN_API_KEY``.

        Raises:
            ValueError: If neither source provides a key.
        """
        api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
        if not api_key:
            raise ValueError(
                "API key is missing: {}".format(
                    global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
                )
            )
        return api_key

    @property
    @override
    def base_url(self) -> str:
        """Base URL: the instance value, else ``global_configs.SN_IMAGE_GEN_BASE_URL``.

        Raises:
            ValueError: If the URL is missing or rejected by ``is_valid_base_url``.
        """
        base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
        if not base_url:
            raise ValueError(
                "Base URL is missing: {}".format(
                    global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
                )
            )
        if not is_valid_base_url(base_url):
            raise ValueError(
                f"Base URL is not a valid base URL: {base_url}. "
                f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )
        return base_url

    @override
    def get_api_url(self, _model: str | None = None) -> str:
        """Return the image-generation URL; the model argument is not used."""
        base_url = self.base_url.rstrip("/")
        path = IMAGE_GEN_ENDPOINT.lstrip("/")
        return f"{base_url}/{path}"

    @override
    def build_payload(
        self,
        prompt: str,
        model: str,
        *,
        size: str | None = None,
        modalities: Sequence[str] = ("text", "image"),
        output_format: Literal["png"] = "png",
        response_format: Literal["url"] = "url",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload for the SenseNova image-generation endpoint.

        Args:
            prompt (str): The prompt to generate an image for.
            model (str): The model to use for generation.
            size (str | None): Pixel size string (for example, "1920x1920").
            modalities (Sequence[str]): Reserved for compatibility; currently not sent.
            output_format (Literal["png"]): The output format of the image. Defaults to "png".
            response_format (Literal["url"]): The response format of the image. Defaults to "url".
            **kwargs (Any, optional): Additional parameters merged into the
                payload last, so they override the keys above on collision.

        Returns:
            dict[str, Any]: JSON-serializable request body.

        Example:
            {
                "model": "sensenova-u1-fast",
                "prompt": "A cat wearing a hat",
                "size": "1024x1024",
                "response_format": "url",
                "output_format": "png",
            }
        """
        return {
            "model": model,
            "prompt": prompt,
            "size": size,
            "response_format": response_format,
            "output_format": output_format,
            **kwargs,
        }

    @property
    @override
    def headers(self) -> dict[str, str]:
        """Authorization and content-type headers for API requests."""
        if not self.api_key:
            raise MissingApiKeyError(
                "API key is missing: {}".format(
                    global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
                )
            )
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @classmethod
    def _resolve_size(
        cls,
        resolution: Literal["1K", "2K"] | str | None = None,
        aspect_ratio: ASPECT_RATIO_LITERALS | str | None = None,
    ) -> str | None:
        """Convert (resolution, aspect_ratio) to a pixel size string.

        The requested ratio is snapped to the closest entry of the bucket
        table for the preset (BUCKETS_1K / BUCKETS_2K).

        Args:
            resolution: "1K" or "2K" (case-insensitive). Defaults to "2K"
                when only an aspect ratio is given.
            aspect_ratio: "W:H" with integer parts. Defaults to "1:1" when
                only a resolution is given.

        Returns:
            str | None: "WxH", or None when both arguments are empty.

        Raises:
            ValueError: For an unknown preset, a malformed ratio, or a ratio
                wider than 16:9 or narrower than 9:21.
        """
        if not resolution and not aspect_ratio:
            return None
        resolution = resolution or "2K"
        aspect_ratio = aspect_ratio or "1:1"

        preset = resolution.upper() if isinstance(resolution, str) else resolution
        if preset == "1K":
            buckets = BUCKETS_1K
        elif preset == "2K":
            buckets = BUCKETS_2K
        else:
            raise ValueError(f"Unsupported resolution: {resolution!r}. Must be '1K' or '2K'.")

        try:
            ws, _, hs = aspect_ratio.strip().partition(":")
            ratio = int(ws) / int(hs)
        except (AttributeError, TypeError, ValueError, ZeroDivisionError) as e:
            raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}") from e

        if ratio > 16 / 9:
            raise ValueError(f"Aspect ratio {aspect_ratio!r} is too wide. Maximum is 16:9")
        if ratio < 9 / 21:
            raise ValueError(f"Aspect ratio {aspect_ratio!r} is too tall. Minimum is 9:21")

        w, h = _find_nearest_aspect_ratio(ratio, buckets)
        return f"{w}x{h}"

    @override
    def parse_response(self, response: httpx.Response) -> dict:
        """Parse the response from the SenseNova image-generation endpoint.

        Example response data:
        ```json
        {
            "data": [{
                "url": "https://cdn.sensenova.dev/gen/..."
            }]
        }
        ```

        Args:
            response: The HTTP response from the SenseNova image-generation endpoint.

        Returns:
            dict: Parsed data with key ``images_urls`` (items without a URL
            are skipped; the list may be empty).
        """
        raw_data = super().parse_response(response)
        # `or []` also covers an explicit `"data": null` in the body.
        items = raw_data.get("data") or []
        images_urls: list[str] = [url for item in items if (url := item.get("url"))]
        return {"images_urls": images_urls}
async def download_image(
    url: str,
    save_path: Path,
    timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
) -> Path:
    """Download an image from a URL and atomically move it to ``save_path``.

    The body is streamed into a temporary file next to ``save_path``,
    checked for completeness and decodability, and only then renamed over
    the destination, so a failed download never leaves a partial image at
    ``save_path``.

    Args:
        url: The URL of the image to download.
        save_path: Destination file. Its parent directory is created if needed.
        timeout: The timeout for the request.

    Returns:
        Path: The path to the downloaded image file (``save_path``).

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        OSError: If fewer or more bytes arrive than ``Content-Length`` announced.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path: Path | None = None
    received = 0
    expected_length: int | None = None
    try:
        with tempfile.NamedTemporaryFile(
            dir=save_path.parent,
            prefix=f".{save_path.name}.",
            suffix=".tmp",
            delete=False,
        ) as temp_file:
            temp_path = Path(temp_file.name)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("GET", url) as response:
                    response.raise_for_status()
                    content_length = response.headers.get("content-length")
                    if content_length is not None:
                        expected_length = int(content_length)
                    async for chunk in response.aiter_bytes():
                        temp_file.write(chunk)
                    # Content-Length describes the bytes on the wire, while
                    # `aiter_bytes()` yields the body after content decoding
                    # (gzip/deflate). Compare against the raw byte counter so
                    # compressed responses are not reported as incomplete.
                    received = response.num_bytes_downloaded
            temp_file.flush()
            os.fsync(temp_file.fileno())
        if expected_length is not None and received != expected_length:
            raise OSError(
                f"Downloaded image is incomplete: got {received} bytes, "
                f"expected {expected_length} bytes"
            )
        _validate_image_file(temp_path)
        temp_path.replace(save_path)
        return save_path
    except BaseException:
        # BaseException so the temp file is also removed when the task is
        # cancelled (asyncio.CancelledError is not an Exception subclass).
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
        raise
def _validate_image_file(image_path: Path) -> None:
    """Verify that the downloaded image can be decoded completely.

    Any exception raised by Pillow while opening, verifying or loading the
    file propagates to the caller.
    """
    # Pass 1: integrity check via `verify()`.
    with Image.open(image_path) as image:
        image.verify()
    # Pass 2: full decode via `load()` on a freshly opened handle.
    # NOTE(review): the file is opened twice, presumably because Pillow
    # requires reopening an image after `verify()` — confirm against the docs.
    with Image.open(image_path) as image:
        image.load()
def mime_type_to_suffix(mime_type: str) -> str:
    """Return the file suffix that corresponds to an image MIME type.

    Args:
        mime_type: MIME type such as ``"image/webp"``.

    Returns:
        str: Suffix with leading dot; ``".png"`` for any unrecognized type.
    """
    suffix_table = (
        ("image/jpeg", ".jpg"),
        ("image/png", ".png"),
        ("image/webp", ".webp"),
    )
    for candidate, suffix in suffix_table:
        if mime_type == candidate:
            return suffix
    return ".png"
# Aspect-ratio labels ("W:H") that have an entry in the bucket tables below.
ASPECT_RATIO_LITERALS = Literal[
    "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1", "16:9", "9:16", "9:21"
]
# (width, height) in pixels per aspect-ratio label; selected by
# `_resolve_size` when the resolution preset is "1K".
BUCKETS_1K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
    "2:3": (1088, 1632),
    "3:2": (1632, 1088),
    "3:4": (1152, 1536),
    "4:3": (1536, 1152),
    "4:5": (1184, 1472),
    "5:4": (1472, 1184),
    "1:1": (1344, 1344),
    "16:9": (1792, 992),
    "9:16": (992, 1792),
    "9:21": (864, 2048),
}
# (width, height) in pixels per aspect-ratio label; selected by
# `_resolve_size` when the resolution preset is "2K".
BUCKETS_2K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
    "2:3": (1664, 2496),
    "3:2": (2496, 1664),
    "3:4": (1760, 2368),
    "4:3": (2368, 1760),
    "4:5": (1824, 2272),
    "5:4": (2272, 1824),
    "1:1": (2048, 2048),
    "16:9": (2752, 1536),
    "9:16": (1536, 2752),
    "9:21": (1344, 3136),
}
def _find_nearest_aspect_ratio(
ratio: float,
buckets: dict[ASPECT_RATIO_LITERALS, tuple[int, int]],
) -> tuple[int, int]:
wh_pairs = sorted(
buckets.values(),
key=lambda wh: abs(wh[0] / wh[1] - ratio),
)
return wh_pairs[0]
if __name__ == "__main__":
    # Manual smoke test: generate one image using the globally configured
    # API key and base URL, then print the result dictionary.
    import asyncio

    async def main_async():
        """Run a single demo generation and print its result."""
        demo_client = SensenovaText2ImageClient(
            api_key=global_configs.SN_IMAGE_GEN_API_KEY,
            base_url=global_configs.SN_IMAGE_GEN_BASE_URL,
        )
        outcome = await demo_client.generate(
            prompt="A cat wearing a hat",
            image_size="1K",
            aspect_ratio="16:9",
        )
        print(outcome)

    asyncio.run(main_async())

View File

@@ -0,0 +1,5 @@
# llm module - Language Model (text only)
from .anthropic_adapter import AnthropicMessagesAdapter
from .chat_completions_adapter import OpenAIChatAdapter
__all__ = ["AnthropicMessagesAdapter", "OpenAIChatAdapter"]

View File

@@ -0,0 +1,161 @@
"""Anthropic Messages API adapter for text and vision."""
from __future__ import annotations
import logging
from typing import Any
import httpx
from sn_image_base.utils.error_utils import U1HttpResponseParseError
from sn_image_base.utils.httpx_client import httpx_response_raise_for_status_code
from sn_image_base.vlm.utils import image_to_base64
from sn_image_base.vlm.vlm_adapter import VlmAdapter
from .llm_adapter import LlmAdapter
logger = logging.getLogger(__name__)
# Timeout passed to the internally created httpx.AsyncClient when the caller
# gives no `timeout`.
DEFAULT_REQUEST_TIMEOUT = 150.0
# Value sent as `max_tokens` in each request unless overridden at construction.
DEFAULT_MAX_TOKENS = 4096
class AnthropicMessagesAdapter(LlmAdapter, VlmAdapter):
    """Anthropic Messages API adapter for text-only and vision calls."""

    def __init__(
        self,
        endpoint_url: str,
        api_key: str,
        model: str,
        *,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        timeout: float = DEFAULT_REQUEST_TIMEOUT,
        async_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            endpoint_url: Full URL that requests are POSTed to.
            api_key: Key sent in both the ``Authorization`` and ``x-api-key`` headers.
            model: Default model name used when a call passes none.
            max_tokens: Value sent as ``max_tokens`` with every request.
            timeout: Timeout for the internally created client (unused when
                ``async_client`` is supplied).
            async_client: Optional shared client. When given it is reused
                and never closed by this adapter.
        """
        self._url = endpoint_url
        self._api_key = api_key
        self._default_model = model
        self._max_tokens = max_tokens
        self._timeout = timeout
        # Remember whether the client was supplied so `aclose` knows who owns it.
        self._external_client = async_client
        self._client: httpx.AsyncClient | None = async_client
        logger.info(
            "AnthropicMessagesAdapter: endpoint=%s model=%s max_tokens=%s",
            self._url,
            self._default_model,
            self._max_tokens,
        )

    def _get_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating an owned one on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    @property
    def _headers(self) -> dict[str, str]:
        """Request headers; the key is sent as both bearer token and ``x-api-key``."""
        # NOTE(review): no `anthropic-version` header is sent — confirm the
        # target endpoint does not require one.
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "x-api-key": self._api_key,
        }

    @staticmethod
    def _build_vision_content(
        user_prompt: str,
        images: list[str | bytes],
    ) -> list[dict[str, Any]]:
        """Build content blocks: the prompt text followed by one base64 image block per image."""
        blocks: list[dict[str, Any]] = [{"type": "text", "text": user_prompt}]
        for image in images:
            mime, b64 = image_to_base64(image)
            blocks.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": mime,
                        "data": b64,
                    },
                }
            )
        return blocks

    def _build_payload(
        self,
        user_prompt: str,
        system_prompt: str,
        model: str | None,
        *,
        images: list[str | bytes] | None = None,
    ) -> dict[str, Any]:
        """Build the request body for one user turn.

        Args:
            user_prompt: Text of the user message.
            system_prompt: System instruction; omitted from the payload when empty.
            model: Model override; falls back to the default model.
            images: Optional images to attach to the user message.

        Returns:
            dict[str, Any]: Payload with ``model``, ``messages``,
            ``max_tokens`` and, when a system prompt is given, ``system``.
        """
        user_content: str | list[dict[str, Any]]
        if images:
            user_content = self._build_vision_content(user_prompt, images)
        else:
            user_content = user_prompt

        payload: dict[str, Any] = {
            "model": model or self._default_model,
            "messages": [{"role": "user", "content": user_content}],
            "max_tokens": self._max_tokens,
        }
        # The Messages API has no "system" role inside `messages`; the system
        # prompt belongs in the top-level `system` field. Sending it as an
        # extra user message would make the model treat it as user input.
        if system_prompt:
            payload["system"] = system_prompt
        return payload

    @staticmethod
    def _parse_response(data: dict[str, Any]) -> str:
        """Extract the reply text from a decoded response body.

        Returns the text of the first ``text`` content block. If there is
        none but the body has a top-level ``thinking`` value, that value is
        returned prefixed with ``"[Think] "``.

        Raises:
            RuntimeError: If neither is present.
        """
        content = data.get("content", [])
        if content:
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    return block.get("text", "")
        thinking = data.get("thinking")
        if thinking:
            return f"[Think] {thinking}"
        raise RuntimeError("Anthropic Messages response has no extractable content.")

    async def _post_payload(self, payload: dict[str, Any]) -> str:
        """POST the payload and return the parsed reply text.

        Raises:
            U1HttpResponseParseError: If the response body is not valid JSON.
        """
        resp = await self._get_client().post(self._url, json=payload, headers=self._headers)
        httpx_response_raise_for_status_code(resp)
        try:
            data = resp.json()
        except ValueError as exc:
            raise U1HttpResponseParseError(
                detail=f"Failed to parse HTTP response. {resp.request.url}. Response content: {resp.content}",
                code=resp.status_code,
            ) from exc
        return self._parse_response(data)

    async def text_completion(
        self,
        user_prompt: str,
        system_prompt: str = "",
        model: str | None = None,
    ) -> str:
        """Send a text-only prompt and return the model's reply."""
        payload = self._build_payload(user_prompt, system_prompt, model)
        return await self._post_payload(payload)

    async def vision_completion(
        self,
        user_prompt: str,
        images: list[str | bytes],
        system_prompt: str = "",
        model: str | None = None,
    ) -> str:
        """Send a prompt with attached images and return the model's reply."""
        payload = self._build_payload(
            user_prompt,
            system_prompt,
            model,
            images=images,
        )
        return await self._post_payload(payload)

    async def aclose(self) -> None:
        """Close the HTTP client if this adapter created it; otherwise do nothing."""
        if self._external_client is None and self._client is not None:
            await self._client.aclose()
            self._client = None

View File

@@ -0,0 +1,276 @@
"""OpenAI-compatible chat/completions adapter for text and vision."""
from __future__ import annotations
import json
import logging
import os
from typing import Any
import httpx
from sn_image_base.configs import is_valid_base_url
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
from sn_image_base.utils.error_utils import (
U1HttpBadResponseError,
U1HttpNotFoundError,
U1HttpResponseParseError,
error_type_to_error_class,
finish_reason_to_error_class,
sanitize_base64_in_data,
)
from sn_image_base.utils.httpx_client import httpx_response_raise_for_status_code
from sn_image_base.vlm.utils import image_to_data_url
from sn_image_base.vlm.vlm_adapter import VlmAdapter
from .llm_adapter import LlmAdapter
logger = logging.getLogger(__name__)
# Timeout passed to the internally created httpx.AsyncClient when the caller
# gives no `timeout`.
DEFAULT_REQUEST_TIMEOUT = 600.0
# Default for `max_completion_tokens` in `_build_payload`.
DEFAULT_MAX_COMPLETION_TOKENS = 8192
class OpenAIChatAdapter(LlmAdapter, VlmAdapter):
    """OpenAI-compatible ``/chat/completions`` adapter for text and vision."""

    def __init__(
        self,
        endpoint_url: str,
        api_key: str,
        model: str,
        *,
        timeout: float = DEFAULT_REQUEST_TIMEOUT,
        async_client: httpx.AsyncClient | None = None,
        reasoning_effort: str | None = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            endpoint_url: Full URL that requests are POSTed to.
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Default model name used when a call passes none.
            timeout: Timeout for the internally created client (unused when
                ``async_client`` is supplied).
            async_client: Optional shared client. When given it is reused
                and never closed by this adapter.
            reasoning_effort: Optional value sent as ``reasoning_effort``;
                an empty string is treated as unset.
        """
        self._url = endpoint_url
        self._api_key = api_key
        self._default_model = model
        self._timeout = timeout
        self._reasoning_effort = reasoning_effort or None
        # Remember whether the client was supplied so `aclose` knows who owns it.
        self._external_client = async_client
        self._client: httpx.AsyncClient | None = async_client
        logger.info(
            "OpenAIChatAdapter: endpoint=%s model=%s reasoning_effort=%s",
            self._url,
            self._default_model,
            self._reasoning_effort,
        )

    def _get_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating an owned one on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self._timeout)
        return self._client

    @property
    def _headers(self) -> dict[str, str]:
        """Bearer-auth and JSON content-type request headers."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_vision_content(
        user_prompt: str,
        images: list[str | bytes],
    ) -> list[dict[str, Any]]:
        """Build content parts: the prompt text followed by one ``image_url`` part per image."""
        content: list[dict[str, Any]] = [{"type": "text", "text": user_prompt}]
        content.extend(
            {"type": "image_url", "image_url": {"url": image_to_data_url(img)}} for img in images
        )
        return content

    def _build_payload(
        self,
        user_prompt: str,
        system_prompt: str,
        model: str,
        *,
        images: list[str | bytes] | None = None,
        max_completion_tokens: int | None = DEFAULT_MAX_COMPLETION_TOKENS,
    ) -> dict[str, Any]:
        """Build the chat/completions request body.

        Args:
            user_prompt: Text of the user message.
            system_prompt: System message; omitted when empty.
            model: Model name to send.
            images: Optional images to attach to the user message.
            max_completion_tokens: Sent as ``max_completion_tokens`` when truthy.

        Returns:
            dict[str, Any]: Payload with ``model`` and ``messages`` plus the
            optional ``reasoning_effort`` / ``max_completion_tokens`` keys.
        """
        messages: list[dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        user_content: str | list[dict[str, Any]]
        if images:
            user_content = self._build_vision_content(user_prompt, images)
        else:
            user_content = user_prompt
        messages.append({"role": "user", "content": user_content})
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
        }
        if self._reasoning_effort:
            payload["reasoning_effort"] = self._reasoning_effort
        if max_completion_tokens:
            payload["max_completion_tokens"] = max_completion_tokens
        return payload

    @staticmethod
    def _parse_response(data: dict[str, Any]) -> str:
        """Extract the concatenated reply text from a decoded response body.

        Text from every choice is joined; list-style ``content`` contributes
        its ``text`` blocks only.

        Raises:
            U1HttpErrorBase: A subclass chosen from the body's ``error.type``
                or from the ``finish_reason`` when no text is present;
                ``U1HttpBadResponseError`` when the body is not a JSON
                object, has no choices, or has empty content.
        """
        if not isinstance(data, dict):
            # Valid JSON that is not an object (e.g. a list or a string)
            # cannot be a chat/completions response.
            dumped = json.dumps(sanitize_base64_in_data(data), ensure_ascii=False)
            raise U1HttpBadResponseError(
                detail=f"chat/completions response is not a JSON object. Response: {dumped}",
            )

        error = data.get("error")
        if error:
            if isinstance(error, dict):
                error_message = error.get("message")
                error_type = error.get("type")
                error_code = error.get("code")
            else:
                # Some servers send `"error": "<text>"`; keep the text
                # instead of failing with AttributeError on `.get`.
                error_message, error_type, error_code = str(error), None, None
            error_class, explanation = error_type_to_error_class(error_type)
            raise error_class(
                explanation,
                detail=f"chat/completions response has error. Error: {error_message}",
                code=error_code,
            )

        choices = data.get("choices") or []
        if not choices:
            sanitized_data = sanitize_base64_in_data(data)
            dumped = json.dumps(sanitized_data, ensure_ascii=False)
            raise U1HttpBadResponseError(
                detail=f"chat/completions response has no choices. Response: {dumped}",
            )

        contents: list[str] = []
        finish_reason: str | None = None
        for choice in choices:
            msg = choice.get("message", {})
            # Keep the last non-empty finish reason seen across choices.
            finish_reason = choice.get("finish_reason") or finish_reason
            content_val = msg.get("content")
            if isinstance(content_val, str):
                contents.append(content_val)
            elif isinstance(content_val, list):
                parts: list[str] = []
                for block in content_val:
                    if isinstance(block, dict) and block.get("type") == "text":
                        text = block.get("text")
                        if isinstance(text, str):
                            parts.append(text)
                contents.append("".join(parts))

        final_content = "".join(contents)
        if final_content:
            return final_content

        # No text at all: explain why using the finish reason, if any.
        sanitized_data = sanitize_base64_in_data(data)
        dumped = json.dumps(sanitized_data, ensure_ascii=False)
        detail_msg = ""
        if finish_reason:
            detail_msg += f"\n^ Finish reason: {finish_reason}"
        detail_msg += f"\n^ Response: {dumped}"
        if finish_reason == "stop":
            raise U1HttpBadResponseError(
                "chat/completions response with empty content.",
                detail=detail_msg,
            )
        if finish_reason:
            error_class, explanation = finish_reason_to_error_class(finish_reason)
            raise error_class(explanation, detail=detail_msg)
        raise U1HttpBadResponseError(
            "chat/completions response has no content. No finish reason provided.",
            detail=detail_msg,
        )

    async def _post_payload(self, payload: dict[str, Any], model: str) -> str:
        """POST the payload and return the parsed reply text.

        Raises:
            U1HttpNotFoundError: Re-raised with the model name appended to
                the detail.
            U1HttpResponseParseError: If the response body is not valid JSON.
        """
        resp = await self._get_client().post(self._url, json=payload, headers=self._headers)
        try:
            httpx_response_raise_for_status_code(resp)
            data = resp.json()
        except U1HttpNotFoundError as exc:
            raise U1HttpNotFoundError(
                detail=f"{exc.detail} model={model!r}",
                code=resp.status_code,
            ) from exc
        except ValueError as exc:
            raise U1HttpResponseParseError(
                detail=f"Failed to parse HTTP response. {resp.request.url}. Response content: {resp.content}",
                code=resp.status_code,
            ) from exc
        return self._parse_response(data)

    async def text_completion(
        self,
        user_prompt: str,
        system_prompt: str = "",
        model: str | None = None,
    ) -> str:
        """Send a text-only prompt and return the model's reply."""
        resolved_model = model or self._default_model
        payload = self._build_payload(user_prompt, system_prompt, resolved_model)
        return await self._post_payload(payload, resolved_model)

    async def vision_completion(
        self,
        user_prompt: str,
        images: list[str | bytes],
        system_prompt: str = "",
        model: str | None = None,
    ) -> str:
        """Send a prompt with attached images and return the model's reply."""
        resolved_model = model or self._default_model
        payload = self._build_payload(
            user_prompt,
            system_prompt,
            resolved_model,
            images=images,
        )
        return await self._post_payload(payload, resolved_model)

    async def aclose(self) -> None:
        """Close the HTTP client if this adapter created it; otherwise do nothing."""
        if self._external_client is None and self._client is not None:
            await self._client.aclose()
            self._client = None
if __name__ == "__main__":
    # Command-line smoke test: send one prompt (optionally with an image)
    # to the configured chat endpoint and print the reply.
    import argparse
    import asyncio
    from sn_image_base.configs import global_configs
    parser = argparse.ArgumentParser(description="Async OpenAI-compatible chat adapter.")
    parser.add_argument("--prompt", default=None, help="Prompt to use for the model")
    parser.add_argument("--system-prompt", default=None, help="System prompt to use")
    # The image path may also come from the IMAGE_PATH environment variable.
    parser.add_argument("--image", default=os.environ.get("IMAGE_PATH"), help="Optional image path")
    args = parser.parse_args()

    async def main() -> None:
        """Validate configuration, run one completion, and print the result."""
        prompt = args.prompt or "Write a poem about the topic: 'Hello world'"
        base_url = global_configs.SN_CHAT_BASE_URL
        if not base_url:
            raise InvalidBaseUrlError(
                f"No base URL provided for chat runtime. {global_configs.get_env_var_help('SN_CHAT_BASE_URL')}"
            )
        if not is_valid_base_url(base_url):
            raise InvalidBaseUrlError(
                f"Invalid base URL for chat runtime: {base_url}. {global_configs.get_env_var_help('SN_CHAT_BASE_URL')}"
            )
        endpoint_url = f"{base_url.rstrip('/')}/chat/completions"
        api_key = global_configs.SN_CHAT_API_KEY
        if not api_key:
            raise MissingApiKeyError(
                f"No API key provided for chat runtime. {global_configs.get_env_var_help('SN_CHAT_API_KEY')}"
            )
        # NOTE(review): the text model is used for the vision path as well —
        # confirm that is intended.
        model = global_configs.SN_TEXT_MODEL
        adapter = OpenAIChatAdapter(
            endpoint_url=endpoint_url,
            api_key=api_key,
            model=model,
        )
        try:
            # An image argument switches to the vision call.
            if args.image:
                result = await adapter.vision_completion(
                    user_prompt=prompt,
                    images=[args.image],
                    system_prompt=args.system_prompt or "",
                )
            else:
                result = await adapter.text_completion(
                    user_prompt=prompt,
                    system_prompt=args.system_prompt or "",
                )
            print(result)
        finally:
            # Always release the adapter-owned HTTP client.
            await adapter.aclose()

    asyncio.run(main())

View File

@@ -0,0 +1,51 @@
"""Abstract base class for LLM (Language Model) adapters."""
from __future__ import annotations
from abc import ABC, abstractmethod
class LlmAdapter(ABC):
    """Uniform async interface for a single Language Model backend.

    Each concrete adapter wraps one LLM endpoint + model combination and
    exposes a single :meth:`text_completion` coroutine.  Synchronous
    calling is intentionally **not** supported; callers must run inside an
    asyncio event loop.

    **Client ownership contract** — when a shared
    :class:`httpx.AsyncClient` is supplied at construction time the adapter
    *reuses* it and must **not** close it; the caller retains full ownership
    of the client's lifecycle.  When no external client is provided the
    adapter creates and owns an internal client and must close it in
    :meth:`aclose`.
    """

    @abstractmethod
    async def text_completion(
        self,
        user_prompt: str,
        system_prompt: str = "",
        model: str | None = None,
    ) -> str:
        """Send a text-only prompt to the model and return the reply.

        Args:
            user_prompt: User-facing text instruction.
            system_prompt: System-level instruction for the conversation.
                Defaults to ''.
            model: Model name to use.  If None, uses the default set at
                initialization.

        Returns:
            str: Raw text response from the model.
        """

    @abstractmethod
    async def aclose(self) -> None:
        """Release async resources owned by this adapter.

        Must be called when the adapter is no longer needed.  Adapters that
        were given an external shared client must leave that client open;
        adapters that created their own internal client must close it here.
        """

View File

@@ -0,0 +1,231 @@
from __future__ import annotations
import base64
import contextlib
import json
from collections.abc import Iterable, Mapping
from typing import Any
class U1BaseError(Exception):
    """Root exception carrying a message, an optional detail and an optional code.

    Subclasses override ``MESSAGE`` to supply the default message used when
    none is passed to the constructor.
    """

    MESSAGE = "Base error"

    def __init__(
        self,
        message: str | None = None,
        detail: str | None = None,
        code: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the error fields; extra keyword arguments are accepted and ignored."""
        resolved_message = self.MESSAGE if message is None else message
        super().__init__(resolved_message)
        self.message = resolved_message
        self.code = code
        self.detail = detail

    def __str__(self) -> str:
        """Render as ``Name[code](message='...') <detail>...</detail>``, omitting empty parts."""
        pieces = [f"{self.__class__.__name__}"]
        if self.code:
            pieces.append(f"[{self.code}]")
        if self.message:
            pieces.append(f"(message={self.message!r})")
        if self.detail:
            pieces.append(f" <detail>{self.detail}</detail>")
        return "".join(pieces)
# ----------------------
# HTTP Errors
# ----------------------
# Common base for every HTTP-related error below; each subclass only
# overrides the default MESSAGE.
class U1HttpErrorBase(U1BaseError):
    MESSAGE = "Base HTTP Error"


class U1HttpAuthError(U1HttpErrorBase):
    MESSAGE = "Authentication or Authorization Failed"


class U1HttpNotFoundError(U1HttpErrorBase):
    MESSAGE = "Resource Not Found"


class U1HttpTooManyRequestsError(U1HttpErrorBase):
    MESSAGE = "Too Many Requests"


class U1HttpServerError(U1HttpErrorBase):
    MESSAGE = "Server Error"


class U1HttpBadRequestError(U1HttpErrorBase):
    MESSAGE = "Bad Request"


class U1HttpPermissionError(U1HttpErrorBase):
    MESSAGE = "Permission Error"


class U1HttpResponseParseError(U1HttpErrorBase):
    MESSAGE = "Failed to parse HTTP response"


class U1HttpTimeoutError(U1HttpErrorBase):
    MESSAGE = "Timeout Error"


class U1HttpNetworkError(U1HttpErrorBase):
    MESSAGE = "Network Error"


class U1HttpUnknownError(U1HttpErrorBase):
    MESSAGE = "Unknown Error"


# Returned by `finish_reason_to_error_class` for finish reason "content_filter".
class U1HttpForbiddenContentError(U1HttpErrorBase):
    MESSAGE = "Forbidden Content Filtered"


# Returned by `finish_reason_to_error_class` for finish reason "length".
class U1HttpTruncatedResponseError(U1HttpErrorBase):
    MESSAGE = "Truncated Response"


class U1HttpBadResponseError(U1HttpErrorBase):
    MESSAGE = "Bad Response"
def finish_reason_to_error_class(finish_reason: str) -> tuple[type[U1HttpErrorBase], str]:
    """Map a completion ``finish_reason`` to an error class and an explanation.

    Args:
        finish_reason: Finish reason reported for a choice.

    Returns:
        tuple: ``(error_class, explanation)``. Unrecognized reasons map to
        ``U1HttpBadRequestError`` with a message quoting the reason.
    """
    halted_by_tools = (
        U1HttpBadRequestError,
        "Response was halted due to tool calls or function calls.",
    )
    known_reasons: dict[str, tuple[type[U1HttpErrorBase], str]] = {
        "length": (U1HttpTruncatedResponseError, "Response was truncated due to length limit."),
        "content_filter": (
            U1HttpForbiddenContentError,
            "Response was filtered due to content policy.",
        ),
        "tool_calls": halted_by_tools,
        "function_call": halted_by_tools,
        "stop": (U1HttpBadResponseError, "Response was completed normally."),
    }
    # Only strings can match a known reason; anything else is "unknown".
    match = known_reasons.get(finish_reason) if isinstance(finish_reason, str) else None
    if match is not None:
        return match
    return U1HttpBadRequestError, f"Unknown finish reason: {finish_reason!r}."
def error_type_to_error_class(error_type: str) -> tuple[type[U1HttpErrorBase], str]:
    """Map an API error ``type`` string to an error class and an explanation.

    Args:
        error_type: Value of the ``type`` field of an error payload.

    Returns:
        tuple: ``(error_class, explanation)``. Unrecognized types map to
        ``U1HttpBadRequestError`` with a message quoting the type.
    """
    known_types: dict[str, tuple[type[U1HttpErrorBase], str]] = {
        "invalid_request_error": (U1HttpBadRequestError, "Invalid request error."),
        "rate_limit_error": (U1HttpTooManyRequestsError, "Rate limit exceeded."),
        "authentication_error": (U1HttpAuthError, "Authentication error."),
        "api_error": (U1HttpServerError, "API service internal error."),
        "permission_error": (
            U1HttpPermissionError,
            "You are not authorized to access this resource.",
        ),
    }
    # Only strings can match a known type; anything else is "unknown".
    match = known_types.get(error_type) if isinstance(error_type, str) else None
    if match is not None:
        return match
    return U1HttpBadRequestError, f"Unknown error type: {error_type!r}."
def sanitize_base64_in_data(data: Any, *, truncate_length: int = 200) -> Any:
    """Recursively shorten base64-looking strings inside a data structure.

    Intended for logging: large encoded payloads are replaced by a
    placeholder that keeps only the head and tail of the string.

    Args:
        data: Data to sanitize (dict, list, str, bytes, or other). A string
            containing a JSON object or array is parsed first so its
            contents can be sanitized too.
        truncate_length: Number of characters kept at each end of a
            truncated string. Strings no longer than twice this length are
            left untouched, since shortening them would not save anything.

    Returns:
        Sanitized copy of the data. Mappings become dicts, other iterables
        become lists, undecodable binary values become
        ``<binary-data len="Nbytes"/>`` placeholders, and self-referencing
        containers become ``<circular-reference:.../>`` placeholders.

    Example:
        >>> sanitize_base64_in_data({"image": "A" * 1000}, truncate_length=4)
        {'image': '<base64-data len="1,000bytes">AAAA...<<<///TRUNCATED///>>>...AAAA</base64-data>'}
    """

    def _binary_placeholder(blob: Any) -> str:
        return f'<binary-data len="{len(blob)}bytes"/>'

    # Handle binary data first (bytes, bytearray, memoryview)
    if isinstance(data, (bytes, bytearray)):
        # Try: bytes -> str
        with contextlib.suppress(UnicodeDecodeError):
            data = data.decode("utf-8")
    if isinstance(data, (bytes, bytearray, memoryview)):
        return _binary_placeholder(data)
    if isinstance(data, str):
        # Try: str -> dict | list. Scalars ("123", "true") are kept as the
        # original string so the value's type does not silently change.
        with contextlib.suppress(ValueError, RecursionError):
            parsed = json.loads(data)
            if isinstance(parsed, (dict, list)):
                data = parsed

    seen_ids: set[int] = set()  # Prevent circular references

    def _sanitize(value: Any) -> Any:
        # Binary values nested in containers get the same treatment as at
        # the top level; without this they would fall into the generic
        # iterable branch and be expanded into a list of integers.
        if isinstance(value, (bytes, bytearray)):
            try:
                value = value.decode("utf-8")
            except UnicodeDecodeError:
                return _binary_placeholder(value)
        elif isinstance(value, memoryview):
            return _binary_placeholder(value)

        if isinstance(value, str):
            if _is_base64_string(value) and len(value) > 2 * truncate_length:
                # Truncate base64-encoded string, replace it with placeholder
                len_str = f"{len(value):,d}bytes"
                head = value[:truncate_length]
                # `value[-0:]` would be the whole string, so guard the zero case.
                tail = value[-truncate_length:] if truncate_length > 0 else ""
                return f'<base64-data len="{len_str}">{head}...{TRUNCATED_MARKER}...{tail}</base64-data>'
            return value
        if isinstance(value, Mapping):
            obj_id = id(value)
            if obj_id in seen_ids:
                return "<circular-reference:mapping/>"
            seen_ids.add(obj_id)
            result_map = {key: _sanitize(item) for key, item in value.items()}
            seen_ids.remove(obj_id)
            return result_map
        if isinstance(value, Iterable):
            obj_id = id(value)
            if obj_id in seen_ids:
                return "<circular-reference:iterable/>"
            seen_ids.add(obj_id)
            result_list = [_sanitize(item) for item in value]
            seen_ids.remove(obj_id)
            return result_list
        return value

    return _sanitize(data)
TRUNCATED_MARKER = "<<<///TRUNCATED///>>>"
BASE64_DETECTION_MIN_LENGTH = 200 # Minimum length to consider as potential base64
BASE64_CHARS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")
def _is_base64_string(value: str) -> bool:
"""Check if a string looks like base64-encoded data.
Args:
value: String to check
Returns:
True if the string appears to be base64-encoded data
Heuristics:
- Length >= BASE64_DETECTION_MIN_LENGTH (200 chars)
- At least 80% of characters are valid base64 chars (A-Za-z0-9+/=)
- No whitespace or newlines (valid base64 is continuous)
"""
if not isinstance(value, str) or len(value) < BASE64_DETECTION_MIN_LENGTH:
return False
# Check if mostly base64 characters (allow some tolerance)
if value.startswith("data:"):
# Remove the prefix like "data:image/jpeg;base64,"
index = value.find(";base64,")
if index != -1:
value = value[index + len(";base64,") :]
valid_count = sum(1 for c in value if c in BASE64_CHARS)
ratio = valid_count / len(value)
if ratio >= 0.98:
with contextlib.suppress(Exception):
base64.b64decode(value)
return True
return False

View File

@@ -0,0 +1,180 @@
"""Shared httpx async client factory for vigeneval evaluators.
Centralizes connection pool limits, pool timeout, and optional file descriptor
limit check to avoid PoolTimeout and 'Too many open files' under high concurrency.
"""
import contextlib
import json
import resource
from typing import Any
import httpx
from .error_utils import (
U1HttpAuthError,
U1HttpBadRequestError,
U1HttpNotFoundError,
U1HttpServerError,
U1HttpTooManyRequestsError,
)
def check_file_descriptor_limit(max_connections: int, margin: int = 200) -> None:
    """Raise if the process file descriptor limit is too low for max_connections.

    Intended to fail early instead of hitting 'Too many open files' mid-run
    when a large httpx connection pool is used. Returns silently when the
    limit cannot be queried or when the soft limit is unlimited.

    NOTE(review): ``resource`` is imported at module level in this file, so
    the ImportError arm below cannot fire from this function; confirm how
    platforms without the ``resource`` module are meant to be handled.

    Args:
        max_connections: Intended httpx pool max_connections.
        margin: Extra FDs to reserve for the app (logs, other files). Default 200.

    Raises:
        RuntimeError: If soft limit < max_connections + margin.
    """
    try:
        soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    except (ImportError, AttributeError, OSError):
        return

    # An unlimited soft limit is reported as RLIM_INFINITY, whose numeric value
    # is platform-dependent (it can be -1); never compare it as a number.
    if soft == resource.RLIM_INFINITY:
        return

    required = max_connections + margin
    if soft < required:
        raise RuntimeError(
            f"File descriptor limit too low for max_connections={max_connections}. "
            f"Current soft limit: {soft}, need at least {required}. "
            "Raise the limit before running, e.g.: ulimit -n 2048  # or higher, then re-run."
        )
def create_async_httpx_client(
    headers: dict[str, str],
    *,
    timeout: float = 600.0,
    max_connections: int = 500,
    pool_timeout: float = 60.0,
    check_fd_limit: bool = False,
    verify: bool = True,
    **client_kwargs: Any,
) -> httpx.AsyncClient:
    """Build an httpx.AsyncClient preconfigured with shared pool and timeout settings.

    The client is backed by an explicit ``httpx.AsyncHTTPTransport`` whose pool
    allows ``max_connections`` connections, keeps at most
    ``min(400, max_connections)`` of them alive, and expires idle keep-alive
    connections after 30 seconds. The transport is bound to the local address
    "0.0.0.0". Both the transport and the client are created with
    ``trust_env=True``.

    NOTE(review): earlier documentation of this function stated that proxies
    from HTTP_PROXY / HTTPS_PROXY (including ``http://user:pass@host:port``
    credentials) are picked up automatically because ``trust_env=True``.
    httpx appears to skip environment proxy discovery when an explicit
    ``transport`` is passed to the client -- TODO confirm against the
    installed httpx version before relying on proxy support here.

    Args:
        headers: Default request headers (e.g. Content-Type, Authorization).
        timeout: Request timeout in seconds. Default 600.
        max_connections: Connection pool size. Default 500.
        pool_timeout: Max seconds to wait for a free pooled connection.
            Default 60.
        check_fd_limit: If True, call check_file_descriptor_limit(max_connections)
            first, which raises when the process FD limit is too low.
        verify: Forwarded to both the transport and the client; False disables
            SSL certificate verification and should be limited to dev/testing
            or trusted networks.
        **client_kwargs: Passed through to httpx.AsyncClient (e.g. base_url).

    Returns:
        A new httpx.AsyncClient. The caller must aclose() it when done.

    Example:
        client = create_async_httpx_client(
            headers={"Authorization": "Bearer token"},
            max_connections=100,
        )
    """
    if check_fd_limit:
        check_file_descriptor_limit(max_connections)

    # Keep-alive slots are capped at 400 regardless of the overall pool size.
    keepalive_cap = min(400, max_connections)
    pool_limits = httpx.Limits(
        max_connections=max_connections,
        max_keepalive_connections=keepalive_cap,
        keepalive_expiry=30.0,
    )

    # presumably local_address="0.0.0.0" is meant to force IPv4 sockets -- confirm
    shared_transport = httpx.AsyncHTTPTransport(
        verify=verify,
        trust_env=True,
        local_address="0.0.0.0",
        limits=pool_limits,
    )

    # One overall timeout, with a separate budget for waiting on the pool.
    request_timeout = httpx.Timeout(timeout, pool=pool_timeout)

    return httpx.AsyncClient(
        transport=shared_transport,
        headers=headers,
        timeout=request_timeout,
        verify=verify,
        trust_env=True,
        **client_kwargs,
    )
def httpx_response_raise_for_status_code(response: httpx.Response) -> None:
    """Raise the matching U1Http* error for a 4xx/5xx httpx response.

    Status codes are checked in this order:
        404          -> U1HttpNotFoundError
        401, 403     -> U1HttpAuthError
        429, 503     -> U1HttpTooManyRequestsError
        other 5xx    -> U1HttpServerError
        other 4xx    -> U1HttpBadRequestError
    Any other status code returns normally.

    The request method, request URL and response content embedded in the
    error detail are gathered on a best-effort basis; each falls back to
    "[N/A]" when it cannot be read.

    Args:
        response: The httpx response object.

    Raises:
        U1HttpNotFoundError: If response status is 404.
        U1HttpAuthError: If response status is 401 or 403.
        U1HttpTooManyRequestsError: If response status is 429 or 503.
        U1HttpServerError: If response status is any other 5xx.
        U1HttpBadRequestError: If response status is any other 4xx.
    """
    response_content = "[N/A]"  # Not available
    request_url = "[N/A]"
    request_method = "[N/A]"

    # Each step overwrites the previous value, so a failure part-way leaves the
    # most-processed form that succeeded: parsed JSON, else decoded text, else
    # raw bytes, else "[N/A]".
    with contextlib.suppress(Exception):
        response_content = response.content
        response_content = response_content.decode("utf-8")
        response_content = json.loads(response_content)

    with contextlib.suppress(Exception):
        request_method = response.request.method
        request_method = request_method.upper()
        request_url = str(response.request.url)

    status = response.status_code
    if status == 404:
        raise U1HttpNotFoundError(
            detail=f"{request_method} {request_url!r} not found. Please check the URL and the model name.",
            code=status,
        )
    elif status in (401, 403):
        raise U1HttpAuthError(
            detail=f"Authentication or authorization failed. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    elif status in (429, 503):
        raise U1HttpTooManyRequestsError(
            detail=f"Service temporarily unavailable. Please try again later. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    elif 500 <= status <= 599:
        raise U1HttpServerError(
            detail=f"Request failed. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    elif 400 <= status <= 499:
        raise U1HttpBadRequestError(
            detail=f"Bad request. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )

View File

@@ -0,0 +1,5 @@
# vlm module - Vision Language Model
from .vlm_adapter import VlmAdapter
__all__ = ["VlmAdapter"]

View File

@@ -0,0 +1,120 @@
"""Image encoding / decoding utilities for VLM."""
from __future__ import annotations
import base64
import io
from pathlib import Path
from PIL import Image
def read_image_bytes(image: str | bytes) -> bytes:
"""Read raw image bytes from a path or return bytes unchanged.
Args:
image: File path to an image, or raw image bytes.
Returns:
bytes: Raw image bytes.
Raises:
FileNotFoundError: If image is a path and the file does not exist.
"""
if isinstance(image, bytes):
return image
path = Path(image)
if not path.is_file():
raise FileNotFoundError(f"Image file not found: {image}")
return path.read_bytes()
def detect_mime(data: bytes) -> str:
    """Guess the MIME type of an image from its leading magic bytes.

    Args:
        data: Raw image bytes; the first 8 bytes are enough for the checks.

    Returns:
        str: 'image/png' for a PNG signature, 'image/jpeg' for a JPEG
        signature, and 'image/png' again as the fallback for anything else.
    """
    known_signatures = (
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"\xff\xd8\xff", "image/jpeg"),
    )
    for magic, mime in known_signatures:
        if data[: len(magic)] == magic:
            return mime
    # Unrecognized data is deliberately reported as PNG.
    return "image/png"
def detect_suffix(data: bytes) -> str:
    """Guess a file suffix for an image from its leading magic bytes.

    Args:
        data: Raw image bytes.

    Returns:
        str: '.png' for a PNG signature, '.jpg' for a JPEG signature, and
        '.bin' when neither signature matches.
    """
    known_signatures = (
        (b"\x89PNG\r\n\x1a\n", ".png"),
        (b"\xff\xd8\xff", ".jpg"),
    )
    for magic, suffix in known_signatures:
        if data[: len(magic)] == magic:
            return suffix
    return ".bin"
def image_to_mime_and_bytes(image: str | bytes) -> tuple[str, bytes]:
    """Get MIME type and raw bytes; convert to PNG if format is not PNG/JPEG.

    PNG and JPEG data (recognized by magic bytes) is returned unchanged. Any
    other data is decoded with Pillow and re-encoded as PNG. If Pillow cannot
    decode the data (it raises OSError), the bytes are returned unchanged and
    labelled 'image/png', matching what this function returned before the
    conversion path was reachable.

    Args:
        image: File path or raw image bytes.

    Returns:
        tuple[str, bytes]: (mime_type, raw_bytes). Unknown formats become PNG.

    Raises:
        FileNotFoundError: If image is a path and the file does not exist
            (raised by read_image_bytes).
    """
    raw = read_image_bytes(image)

    # detect_mime() reports 'image/png' for unrecognized data as well, so it
    # cannot tell whether a conversion is needed; check the signatures here.
    if raw[:8] == b"\x89PNG\r\n\x1a\n":
        return "image/png", raw
    if raw[:3] == b"\xff\xd8\xff":
        return "image/jpeg", raw

    try:
        # RGBA keeps any transparency the source format carries.
        img = Image.open(io.BytesIO(raw)).convert("RGBA")
        buf = io.BytesIO()
        img.save(buf, format="PNG")
    except OSError:
        # Includes Pillow's UnidentifiedImageError: not a decodable image, so
        # keep the legacy pass-through instead of failing the caller.
        return "image/png", raw
    return "image/png", buf.getvalue()
def image_to_base64(image: str | bytes) -> tuple[str, str]:
    """Return the image's MIME type together with its base64 text.

    Args:
        image: File path or raw image bytes.

    Returns:
        tuple[str, str]: (mime_type, base64_encoded_string), where the MIME
        type and bytes come from image_to_mime_and_bytes.
    """
    mime_type, payload = image_to_mime_and_bytes(image)
    encoded = base64.b64encode(payload).decode("utf-8")
    return mime_type, encoded
def image_to_data_url(image: str | bytes) -> str:
    """Render the image as a ``data:<mime>;base64,<payload>`` URL.

    Args:
        image: File path or raw image bytes.

    Returns:
        str: Data URL string built from image_to_base64's result.
    """
    mime_type, payload = image_to_base64(image)
    data_url = f"data:{mime_type};base64,{payload}"
    return data_url
def mask_secret(secret: str) -> str:
    """Mask a secret for logging (show first 6 and last 4 chars).

    Secrets of 10 characters or fewer are masked completely: showing a
    6-character prefix plus a 4-character suffix of such a value would
    reveal every character of it.

    Args:
        secret: Raw secret string.

    Returns:
        str: Masked string, e.g. 'abcdef...mnop', or all '*' (one per
        character) if length <= 10.
    """
    visible_prefix = 6
    visible_suffix = 4
    if len(secret) <= visible_prefix + visible_suffix:
        return "*" * len(secret)
    return f"{secret[:visible_prefix]}...{secret[-visible_suffix:]}"

View File

@@ -0,0 +1,55 @@
"""Abstract base class for VLM (Vision Language Model) adapters."""
from __future__ import annotations
from abc import ABC, abstractmethod
class VlmAdapter(ABC):
"""Uniform async interface for a single Vision Language Model backend.
Each concrete adapter wraps one LLM endpoint + model combination and
exposes a single :meth:`vision_completion` coroutine. Synchronous
calling is intentionally **not** supported; callers must run inside an
asyncio event loop.
**Client ownership contract** — when a shared
:class:`httpx.AsyncClient` is supplied at construction time the adapter
*reuses* it and must **not** close it; the caller retains full ownership
of the client's lifecycle. When no external client is provided the
adapter creates and owns an internal client and must close it in
:meth:`aclose`.
"""
@abstractmethod
async def vision_completion(
self,
user_prompt: str,
images: list[str | bytes],
system_prompt: str = "",
model: str | None = None,
) -> str:
"""Send image(s) and a text prompt to the model; return the reply.
Args:
user_prompt: User-facing text instruction.
images: One or more images to pass to the model. Each element
is either a file-path string or raw image bytes.
system_prompt: System-level instruction prepended to the
conversation. Defaults to ''.
model: Model name to use. If None, uses the default set at
initialization.
Returns:
str: Raw text response from the model (may contain JSON or
markdown-wrapped JSON depending on the model and prompt).
"""
@abstractmethod
async def aclose(self) -> None:
"""Release async resources owned by this adapter.
Must be called when the adapter is no longer needed. Adapters that
were given an external shared client must implement this as a no-op;
adapters that created their own internal client must close it here.
"""