first commit
This commit is contained in:
1
sn-image-base/scripts/sn_image_base/__init__.py
Normal file
1
sn-image-base/scripts/sn_image_base/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# sn-image-base scripts
|
||||
313
sn-image-base/scripts/sn_image_base/configs.py
Normal file
313
sn-image-base/scripts/sn_image_base/configs.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, get_args, get_origin, get_type_hints
|
||||
from urllib.parse import urlparse
|
||||
|
||||
SCRIPT_DIR = Path(__file__).absolute().parent
|
||||
# "skills" directory that contains "sn-*" skills (e.g. "sn-image-base", "sn-infographic", etc.)
|
||||
SKILLS_DIR = SCRIPT_DIR.parents[1]
|
||||
|
||||
|
||||
def prepare_env() -> None:
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
warnings.warn("python-dotenv is not installed, `.env` files will be ignored", stacklevel=2)
|
||||
return
|
||||
# Priorities:
|
||||
# 1. ".env" in the agent's config directory:
|
||||
# - openclaw: ~/.openclaw/.env
|
||||
# - hermes: ~/.openclaw/.env
|
||||
# 2. ".env" in current working directory. (depends on how the agent runs the skill)
|
||||
# 3. Environment variables
|
||||
# ------------------------------------------------------------
|
||||
# In reverse order of priority, the latter overrides the former:
|
||||
# 3 -- do nothing; overridden by other env files
|
||||
# 2 --
|
||||
load_dotenv(override=True)
|
||||
# 1 --
|
||||
if "OPENCLAW_SHELL" in os.environ:
|
||||
agent_config_dir = Path("~/.openclaw").expanduser()
|
||||
else:
|
||||
agent_config_dir = Path("~/.hermes").expanduser()
|
||||
if (dotenv_path := agent_config_dir / ".env").exists():
|
||||
load_dotenv(dotenv_path, override=True)
|
||||
|
||||
|
||||
prepare_env()
|
||||
|
||||
|
||||
class Field:
|
||||
"""Metadata marker that pairs a field with one or more env var names.
|
||||
|
||||
Env vars are tried in order; the first env var that is set is returned.
|
||||
"""
|
||||
|
||||
__slots__ = ("env_names", "required", "secret")
|
||||
|
||||
def __init__(self, *env_names: str, required: bool = False, secret: bool = False) -> None:
|
||||
self.env_names: tuple[str, ...] | None = tuple(env_names) if env_names else None
|
||||
self.required = required
|
||||
self.secret = secret
|
||||
|
||||
def resolve(self, target_type: type | None = None) -> str | int | float | None:
|
||||
"""Return the first env var value that is set, converted to target_type.
|
||||
|
||||
Args:
|
||||
target_type: The type to convert to (str, int, float, etc.) or None.
|
||||
If not int or float, returns the raw string.
|
||||
|
||||
Returns:
|
||||
The converted value, or None if none of the env vars exist.
|
||||
"""
|
||||
if not self.env_names:
|
||||
return None
|
||||
for n in self.env_names:
|
||||
if n in os.environ:
|
||||
raw = os.environ[n]
|
||||
if target_type is int:
|
||||
return int(raw)
|
||||
if target_type is float:
|
||||
return float(raw)
|
||||
# For other types (Literal, etc.), return raw string
|
||||
return raw
|
||||
return None
|
||||
|
||||
|
||||
class Configs:
|
||||
"""Central registry of env var names and built-in defaults.
|
||||
|
||||
Fields annotated with ``Annotated[str, EnvVar(...)]`` are resolved in
|
||||
``__init__``: env vars are tried in order; if none is set, the class-level
|
||||
default is kept.
|
||||
"""
|
||||
|
||||
# global defaults shared by all SN capabilities.
|
||||
SN_API_KEY: Annotated[str, Field("SN_API_KEY", secret=True)] = ""
|
||||
SN_BASE_URL: Annotated[str, Field("SN_BASE_URL")] = ""
|
||||
|
||||
# image-generate
|
||||
SN_IMAGE_GEN_API_KEY: Annotated[
|
||||
str, Field("SN_IMAGE_GEN_API_KEY", "SN_API_KEY", required=True, secret=True)
|
||||
] = ""
|
||||
SN_IMAGE_GEN_BASE_URL: Annotated[
|
||||
str, Field("SN_IMAGE_GEN_BASE_URL", "SN_BASE_URL", required=True)
|
||||
] = "https://token.sensenova.cn/v1"
|
||||
SN_IMAGE_GEN_MODEL_TYPE: Annotated[
|
||||
Literal["sensenova", "nano-banana", "openai-image"], Field("SN_IMAGE_GEN_MODEL_TYPE")
|
||||
] = "sensenova"
|
||||
SN_IMAGE_GEN_MODEL: Annotated[str, Field("SN_IMAGE_GEN_MODEL")] = "sensenova-u1-fast"
|
||||
|
||||
# chat runtime shared by text and vision commands; command-specific
|
||||
# SN_TEXT_* / SN_VISION_* values override these defaults.
|
||||
SN_CHAT_API_KEY: Annotated[str, Field("SN_CHAT_API_KEY", "SN_API_KEY", secret=True)] = ""
|
||||
SN_CHAT_BASE_URL: Annotated[str, Field("SN_CHAT_BASE_URL", "SN_BASE_URL")] = (
|
||||
"https://token.sensenova.cn/v1"
|
||||
)
|
||||
SN_CHAT_TYPE: Annotated[
|
||||
Literal["anthropic-messages", "openai-completions"], Field("SN_CHAT_TYPE")
|
||||
] = "openai-completions"
|
||||
SN_CHAT_MODEL: Annotated[str, Field("SN_CHAT_MODEL")] = "sensenova-6.7-flash-lite"
|
||||
SN_TEXT_API_KEY: Annotated[
|
||||
str, Field("SN_TEXT_API_KEY", "SN_CHAT_API_KEY", "SN_API_KEY", secret=True)
|
||||
] = ""
|
||||
SN_TEXT_BASE_URL: Annotated[
|
||||
str, Field("SN_TEXT_BASE_URL", "SN_CHAT_BASE_URL", "SN_BASE_URL")
|
||||
] = ""
|
||||
SN_TEXT_TYPE: Annotated[
|
||||
Literal["anthropic-messages", "openai-completions"],
|
||||
Field("SN_TEXT_TYPE", "SN_CHAT_TYPE"),
|
||||
] = ""
|
||||
SN_TEXT_MODEL: Annotated[str, Field("SN_TEXT_MODEL", "SN_CHAT_MODEL")] = (
|
||||
"sensenova-6.7-flash-lite"
|
||||
)
|
||||
SN_VISION_API_KEY: Annotated[
|
||||
str, Field("SN_VISION_API_KEY", "SN_CHAT_API_KEY", "SN_API_KEY", secret=True)
|
||||
] = ""
|
||||
SN_VISION_BASE_URL: Annotated[
|
||||
str, Field("SN_VISION_BASE_URL", "SN_CHAT_BASE_URL", "SN_BASE_URL")
|
||||
] = ""
|
||||
SN_VISION_TYPE: Annotated[
|
||||
Literal["anthropic-messages", "openai-completions"],
|
||||
Field("SN_VISION_TYPE", "SN_CHAT_TYPE"),
|
||||
] = ""
|
||||
SN_VISION_MODEL: Annotated[str, Field("SN_VISION_MODEL", "SN_CHAT_MODEL")] = (
|
||||
"sensenova-6.7-flash-lite"
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
for field, hint in get_type_hints(type(self), include_extras=True).items():
|
||||
env_var = next((a for a in get_args(hint) if isinstance(a, Field)), None)
|
||||
if env_var is None:
|
||||
continue
|
||||
# Extract the actual type (unwrap Annotated, handle Literal)
|
||||
origin = get_origin(hint)
|
||||
actual_type = get_args(hint)[0] if origin is Annotated else hint
|
||||
if (val := env_var.resolve(actual_type)) is not None:
|
||||
setattr(self, field, val)
|
||||
|
||||
def to_string(self, mask_secrets: bool = True) -> str:
|
||||
rows = []
|
||||
for field_name, hint in get_type_hints(type(self), include_extras=True).items():
|
||||
field = next((a for a in get_args(hint) if isinstance(a, Field)), None)
|
||||
value = getattr(self, field_name, None)
|
||||
v = str(value)
|
||||
if mask_secrets and v and field and field.secret:
|
||||
if len(v) > 10:
|
||||
v = f"{v[:6]}{'*' * (len(v) - 10)}{v[-4:]}"
|
||||
elif len(v) > 4:
|
||||
v = f"{v[:4]}{'*' * (len(v) - 4)}"
|
||||
else:
|
||||
v = "*" * len(v)
|
||||
rows.append(f"{field_name}: {v}")
|
||||
return "\n".join(rows)
|
||||
|
||||
def validate_configs(self) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
|
||||
field_env_names: dict[str, tuple[str, ...] | str] = {}
|
||||
errors: list[tuple[str, str]] = []
|
||||
for field_name, hint in get_type_hints(type(self), include_extras=True).items():
|
||||
field = next((a for a in get_args(hint) if isinstance(a, Field)), None)
|
||||
if field is None:
|
||||
continue
|
||||
if env_names := field.env_names:
|
||||
if len(env_names) > 1:
|
||||
field_env_names[field_name] = env_names
|
||||
elif len(env_names) == 1:
|
||||
field_env_names[field_name] = env_names[0]
|
||||
value = getattr(self, field_name, None)
|
||||
if not value:
|
||||
if field.required:
|
||||
if field_name == "SN_IMAGE_GEN_API_KEY":
|
||||
msg = (
|
||||
"Image generation API key is not set; configure SN_API_KEY, "
|
||||
"or configure SN_IMAGE_GEN_API_KEY only for an image-generation-specific override"
|
||||
)
|
||||
else:
|
||||
msg = f"Field '{field_name}' is required but not set; try setting the environment variable(s) {field.env_names}"
|
||||
errors.append((field_name, msg))
|
||||
continue
|
||||
|
||||
# Check fields combination rules:
|
||||
if not self.SN_IMAGE_GEN_MODEL:
|
||||
errors.append((
|
||||
"SN_IMAGE_GEN_MODEL",
|
||||
f"SN_IMAGE_GEN_MODEL is required when SN_IMAGE_GEN_MODEL_TYPE is {self.SN_IMAGE_GEN_MODEL_TYPE!r}",
|
||||
))
|
||||
|
||||
warnings: list[tuple[str, str]] = []
|
||||
runtime_checks = {
|
||||
"text": {
|
||||
"api_key": ("SN_TEXT_API_KEY",),
|
||||
"base_url": ("SN_TEXT_BASE_URL", "SN_CHAT_BASE_URL"),
|
||||
"model": ("SN_TEXT_MODEL",),
|
||||
"type": ("SN_TEXT_TYPE", "SN_CHAT_TYPE"),
|
||||
},
|
||||
"vision": {
|
||||
"api_key": ("SN_VISION_API_KEY",),
|
||||
"base_url": ("SN_VISION_BASE_URL", "SN_CHAT_BASE_URL"),
|
||||
"model": ("SN_VISION_MODEL",),
|
||||
"type": ("SN_VISION_TYPE", "SN_CHAT_TYPE"),
|
||||
},
|
||||
}
|
||||
for runtime, checks in runtime_checks.items():
|
||||
for field_kind, keys in checks.items():
|
||||
if any(getattr(self, key) for key in keys):
|
||||
continue
|
||||
env_help = " / ".join(
|
||||
", ".join(field_env_names[key])
|
||||
if isinstance(field_env_names.get(key), tuple)
|
||||
else str(field_env_names.get(key, key))
|
||||
for key in keys
|
||||
)
|
||||
warnings.append((
|
||||
keys[0],
|
||||
f"{keys[0]} is not set; {runtime} {field_kind} may be unavailable. Try setting: {env_help}",
|
||||
))
|
||||
|
||||
# check urls
|
||||
errors.extend(
|
||||
(
|
||||
key,
|
||||
f"{key} is not a valid base URL: {getattr(self, key)}",
|
||||
)
|
||||
for key in ("SN_CHAT_BASE_URL", "SN_TEXT_BASE_URL", "SN_VISION_BASE_URL")
|
||||
if getattr(self, key) and not is_valid_base_url(getattr(self, key))
|
||||
)
|
||||
errors.extend(
|
||||
(
|
||||
key,
|
||||
f"{key} is not a valid base URL: {getattr(self, key)}",
|
||||
)
|
||||
for key in (
|
||||
"SN_BASE_URL",
|
||||
"SN_IMAGE_GEN_BASE_URL",
|
||||
)
|
||||
if getattr(self, key) and not is_valid_base_url(getattr(self, key))
|
||||
)
|
||||
return errors, warnings
|
||||
|
||||
def get_annotated_field(self, field_name: str) -> Field | None:
|
||||
hints = get_type_hints(type(self), include_extras=True)
|
||||
if field_name not in hints:
|
||||
return None
|
||||
hint = hints[field_name]
|
||||
field_inst = next((a for a in get_args(hint) if isinstance(a, Field)), None)
|
||||
return field_inst
|
||||
|
||||
def get_env_var_help(self, field_name: str) -> str:
|
||||
"""Return a help string describing which environment variables can be used
|
||||
to set the specified configuration field.
|
||||
|
||||
Args:
|
||||
field_name: The name of the configuration field (e.g., "SN_CHAT_API_KEY").
|
||||
|
||||
Returns:
|
||||
A string describing the environment variable(s) that control this field.
|
||||
Returns an error message if the field does not exist or has no EnvVar annotation.
|
||||
"""
|
||||
if not hasattr(type(self), field_name):
|
||||
return f"Field '{field_name}' does not exist in Configs."
|
||||
|
||||
field_inst = self.get_annotated_field(field_name)
|
||||
if field_inst is None:
|
||||
return f"Field '{field_name}' is not configurable via environment variables."
|
||||
|
||||
current_value = getattr(self, field_name)
|
||||
env_names = list(field_inst.env_names) if field_inst.env_names else []
|
||||
if len(env_names) == 1:
|
||||
return (
|
||||
f"To set '{field_name}', configure the environment variable: {env_names[0]}\n"
|
||||
f"Current value: {current_value!r}"
|
||||
)
|
||||
else:
|
||||
env_list = ", ".join(env_names)
|
||||
return (
|
||||
f"To set '{field_name}', configure one of these environment variables: {env_list}\n"
|
||||
f"They are tried in order; the first set value is used.\n"
|
||||
f"Current value: {current_value!r}"
|
||||
)
|
||||
|
||||
|
||||
def is_valid_base_url(url: str) -> bool:
|
||||
with contextlib.suppress(ValueError):
|
||||
parsed = urlparse(url)
|
||||
return bool(parsed.scheme and parsed.netloc)
|
||||
return False
|
||||
|
||||
|
||||
def reload_env() -> None:
|
||||
global global_configs
|
||||
|
||||
prepare_env()
|
||||
try:
|
||||
global_configs = Configs()
|
||||
print("✅ Reloaded global_configs")
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to reload global_configs: {e}", stacklevel=2)
|
||||
|
||||
|
||||
global_configs = Configs()
|
||||
39
sn-image-base/scripts/sn_image_base/exceptions.py
Normal file
39
sn-image-base/scripts/sn_image_base/exceptions.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Shared exceptions for sn-image-base."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class U1BaseError(Exception):
|
||||
"""Base exception for sn-image-base."""
|
||||
|
||||
DEFAULT_MESSAGE = "An error occurred in the sn-image-base skill."
|
||||
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
if message is None:
|
||||
message = self.DEFAULT_MESSAGE
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BadConfigurationError(U1BaseError):
|
||||
"""Raised when the configuration is invalid."""
|
||||
|
||||
DEFAULT_MESSAGE = "The configuration is invalid."
|
||||
|
||||
|
||||
class MissingApiKeyError(BadConfigurationError):
|
||||
"""Raised when API key is not provided via CLI argument or environment variable."""
|
||||
|
||||
DEFAULT_MESSAGE = (
|
||||
"API key is required but was not provided. "
|
||||
"Set SN_API_KEY, or set SN_IMAGE_GEN_API_KEY only for an image-generation-specific "
|
||||
"override, or pass --api-key explicitly."
|
||||
)
|
||||
|
||||
|
||||
class InvalidBaseUrlError(BadConfigurationError):
|
||||
"""Raised when base URL is not provided via CLI argument or environment variable."""
|
||||
|
||||
DEFAULT_MESSAGE = (
|
||||
"Base URL is required but was not provided. "
|
||||
"Set SN_IMAGE_GEN_BASE_URL or SN_BASE_URL, or pass --base-url explicitly."
|
||||
)
|
||||
@@ -0,0 +1,9 @@
|
||||
from .nano_banana import NanoBananaText2ImageClient
|
||||
from .openai_image import OpenAIImageGenerationClient
|
||||
from .sensenova import SensenovaText2ImageClient
|
||||
|
||||
__all__ = [
|
||||
"NanoBananaText2ImageClient",
|
||||
"OpenAIImageGenerationClient",
|
||||
"SensenovaText2ImageClient",
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ensure_output_path(path: Path) -> Path:
|
||||
"""Ensure the parent directory of the given path exists.
|
||||
|
||||
Args:
|
||||
path (Path):
|
||||
The file path whose parent directory should be created.
|
||||
|
||||
Returns:
|
||||
Path:
|
||||
The original path unchanged.
|
||||
"""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from sn_image_base.utils.error_utils import U1HttpResponseParseError
|
||||
from sn_image_base.utils.httpx_client import (
|
||||
create_async_httpx_client,
|
||||
httpx_response_raise_for_status_code,
|
||||
)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT = 300.0
|
||||
DEFAULT_MAX_CONNECTIONS = 100
|
||||
|
||||
|
||||
class T2IBaseClient(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self.model = model
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._max_connections = max_connections
|
||||
self._timeout = timeout
|
||||
self._ssl_verify = ssl_verify
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = create_async_httpx_client(
|
||||
self.headers,
|
||||
timeout=self._timeout,
|
||||
max_connections=self._max_connections,
|
||||
verify=self._ssl_verify,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self._api_key
|
||||
|
||||
@property
|
||||
def base_url(self) -> str | None:
|
||||
return self._base_url
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_api_url(self, *args: Any, **kwargs: Any) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def build_payload(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def headers(self) -> dict[str, str]: ...
|
||||
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
httpx_response_raise_for_status_code(response)
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except ValueError as exc:
|
||||
raise U1HttpResponseParseError(
|
||||
detail=f"Failed to parse HTTP response. {response.request.url}. Response content: {response.content}",
|
||||
code=response.status_code,
|
||||
) from exc
|
||||
306
sn-image-base/scripts/sn_image_base/generation/nano_banana.py
Normal file
306
sn-image-base/scripts/sn_image_base/generation/nano_banana.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
from .core import ensure_output_path
|
||||
from .core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
|
||||
DEFAULT_MODEL_SIZE: Literal["1K", "2K", "4K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
|
||||
class NanoBananaText2ImageClient(T2IBaseClient):
|
||||
"""Async client for Google Nano Banana API."""
|
||||
|
||||
# requires `{model}` placeholder for format string
|
||||
DEFAULT_API_PATH = "/v1beta/models/{model}:generateContent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the NanoBananaText2ImageClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
negative_prompt (str, optional):
|
||||
Negative prompt. Defaults to "".
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K", "4K"). Defaults to DEFAULT_MODEL_SIZE.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model
|
||||
# Normalize image_size to uppercase for NanoBanana API
|
||||
image_size = image_size.upper() # type: ignore[assignment]
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image_size=image_size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images = data["images"]
|
||||
if not images:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
image, mime_type = images[-1]
|
||||
image_bytes = base64.b64decode(image)
|
||||
suffix = mime_type_to_suffix(mime_type)
|
||||
saved_path = output_path.with_suffix(suffix)
|
||||
saved_path.write_bytes(image_bytes)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, model: str | None = None) -> str:
|
||||
model = model or self.model
|
||||
path = self.DEFAULT_API_PATH.format(model=model).lstrip("/")
|
||||
api_url = f"{self.base_url.rstrip('/')}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
max_output_tokens: int = 8192,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
parts: list[dict] = [{"text": prompt}]
|
||||
if (image_b64 := kwargs.get("image_b64")) and (
|
||||
image_mime_type := kwargs.get("image_mime_type")
|
||||
):
|
||||
if image_mime_type not in ["image/jpeg", "image/png"]:
|
||||
msg = (
|
||||
f"Unsupported image MIME type: {image_mime_type}. "
|
||||
"Supported types: image/jpeg, image/png"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
parts.append({"inline_data": {"mime_type": image_mime_type, "data": image_b64}})
|
||||
return {
|
||||
"contents": [{"role": "USER", "parts": parts}],
|
||||
"generationConfig": {
|
||||
"imageConfig": {"aspectRatio": aspect_ratio, "imageSize": image_size},
|
||||
"maxOutputTokens": max_output_tokens,
|
||||
},
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
],
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"x-goog-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images: list[tuple[str, str]] = []
|
||||
finish_reasons: list[str] = []
|
||||
candidates: list[dict] = raw_data.get("candidates") or []
|
||||
for c in candidates:
|
||||
content: dict[str, Any] = c.get("content") or {}
|
||||
parts: list[dict[str, Any]] = content.get("parts") or []
|
||||
if f_reason := content.get("finishReason"):
|
||||
finish_reasons.append(f_reason)
|
||||
for p in parts:
|
||||
inline_data: dict[str, Any] = p.get("inlineData", {})
|
||||
if inline_data:
|
||||
mime_type: str = inline_data.get("mimeType") # pyright: ignore[reportAssignmentType]
|
||||
data: str = inline_data.get("data") # pyright: ignore[reportAssignmentType]
|
||||
images.append((data, mime_type))
|
||||
return {
|
||||
"images": images,
|
||||
"finish_reasons": finish_reasons,
|
||||
}
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
366
sn-image-base/scripts/sn_image_base/generation/openai_image.py
Normal file
366
sn-image-base/scripts/sn_image_base/generation/openai_image.py
Normal file
@@ -0,0 +1,366 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.exceptions import BadConfigurationError
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
from .core import ensure_output_path
|
||||
from .core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
|
||||
DEFAULT_RESOLUTION: Literal["1K", "2K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
B64_PARSE_PATTERN = re.compile(r"^data:([a-zA-Z0-9/]+?);base64,([+-/_A-Za-z0-9]+=*)$")
|
||||
|
||||
|
||||
class OpenAIImageGenerationClient(T2IBaseClient):
|
||||
"""Async client for OpenAI Image Generation API."""
|
||||
|
||||
DEFAULT_API_PATH = "/images/generations"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the OpenAIImageGenerationClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "1k", "2k"] | None = None,
|
||||
aspect_ratio: str | None = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K"). Defaults to DEFAULT_RESOLUTION.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
|
||||
if not model:
|
||||
raise BadConfigurationError(
|
||||
f"Model is not set. {global_configs.get_env_var_help('SN_IMAGE_GEN_MODEL')}"
|
||||
)
|
||||
image_size = image_size or DEFAULT_RESOLUTION
|
||||
if aspect_ratio is None:
|
||||
size = None
|
||||
else:
|
||||
rw, _, rh = aspect_ratio.partition(":")
|
||||
try:
|
||||
aspect_ratio_val: float = float(int(rw) / int(rh))
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
raise ValueError(f"Invalid aspect ratio: {aspect_ratio}") from e
|
||||
size = self._resolve_size(
|
||||
resolution=image_size,
|
||||
aspect_ratio_val=aspect_ratio_val,
|
||||
)
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images = data["images"]
|
||||
if not images:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
image_bytes, mime_type = images[-1]
|
||||
suffix = mime_type_to_suffix(mime_type)
|
||||
saved_path = output_path.with_suffix(suffix)
|
||||
saved_path.write_bytes(image_bytes)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, model: str | None = None) -> str:
|
||||
model = model or self.model
|
||||
path = self.DEFAULT_API_PATH.format(model=model).lstrip("/")
|
||||
api_url = f"{self.base_url.rstrip('/')}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
*,
|
||||
n: int = 1,
|
||||
size: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Example:
|
||||
{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "一只戴着墨镜的猫在赛博朋克城市的街道上喝咖啡, 赛璐璐画风",
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"response_format": "b64_json",
|
||||
}
|
||||
"""
|
||||
size = size or "auto"
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"n": n,
|
||||
"size": size,
|
||||
"response_format": "b64_json",
|
||||
**kwargs,
|
||||
}
|
||||
return payload
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
"""
|
||||
Example:
|
||||
{
|
||||
"data": [{
|
||||
"b64_json": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOYA3Q..."
|
||||
}],
|
||||
"created": 1776789055
|
||||
"usage": {
|
||||
"input_tokens":773,
|
||||
"output_tokens":765,
|
||||
"total_tokens":1538,
|
||||
"input_tokens_details": {
|
||||
"text_tokens":8,
|
||||
"image_tokens":765
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images: list[tuple[bytes, str]] = []
|
||||
data_items: list[dict] = raw_data.get("data") or []
|
||||
for item in data_items:
|
||||
encoded: str = item.get("b64_json") or ""
|
||||
if not encoded:
|
||||
continue
|
||||
|
||||
if encoded.startswith("data:"):
|
||||
match = B64_PARSE_PATTERN.match(encoded)
|
||||
if match:
|
||||
mime_type = match.group(1)
|
||||
b64_data = match.group(2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid base64 data in response: {encoded[:100]}... (truncated)"
|
||||
)
|
||||
else:
|
||||
mime_type = "image/png" # fallback to png
|
||||
b64_data = encoded
|
||||
try:
|
||||
decoded = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to decode base64 data in response: {e}. b64_json: {encoded[:100]}... (truncated)"
|
||||
) from e
|
||||
images.append((decoded, mime_type))
|
||||
return {
|
||||
"images": images,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_size(
|
||||
cls,
|
||||
resolution: str,
|
||||
aspect_ratio_val: float | None,
|
||||
) -> str:
|
||||
"""Convert (resolution, aspect_ratio) to a pixel size string."""
|
||||
resolution = resolution.upper()
|
||||
if resolution == "1K":
|
||||
max_pixel = 1024**2
|
||||
elif resolution == "2K":
|
||||
max_pixel = 2048**2
|
||||
else:
|
||||
raise ValueError(f"Unsupported resolution: {resolution}")
|
||||
aspect_ratio_val = aspect_ratio_val or 1
|
||||
if aspect_ratio_val < 1 / 3 or aspect_ratio_val > 3:
|
||||
raise ValueError(f"Aspect ratio value must be between [1/3, 3], got {aspect_ratio_val}")
|
||||
|
||||
width: int = round(math.sqrt(max_pixel * aspect_ratio_val))
|
||||
height: int = round(math.sqrt(max_pixel / aspect_ratio_val))
|
||||
return f"{width}x{height}"
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
508
sn-image-base/scripts/sn_image_base/generation/sensenova.py
Normal file
508
sn-image-base/scripts/sn_image_base/generation/sensenova.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
|
||||
from sn_image_base.generation.core import ensure_output_path
|
||||
from sn_image_base.generation.core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
DEFAULT_RESOLUTION: Literal["1K", "2K", "4K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
|
||||
IMAGE_GEN_ENDPOINT = "/images/generations"
|
||||
|
||||
|
||||
class SensenovaText2ImageClient(T2IBaseClient):
|
||||
"""Async client for Sensenova text-to-image API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the SensenovaText2ImageClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
api_key = api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise MissingApiKeyError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
base_url = base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise InvalidBaseUrlError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise InvalidBaseUrlError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_RESOLUTION,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
negative_prompt (str, optional):
|
||||
Negative prompt. Defaults to "".
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K", "4K"). Defaults to DEFAULT_RESOLUTION.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
|
||||
# Normalize image_size to uppercase for NanoBanana API
|
||||
image_size = image_size.upper() # type: ignore[assignment]
|
||||
output_format = "png"
|
||||
size = self._resolve_size(image_size, aspect_ratio)
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
output_format=output_format,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
# elif exc.code == 400:
|
||||
# warnings.warn(f"Bad request: {exc.message}; body: {payload}", stacklevel=2)
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images_urls: list[str] = data["images_urls"]
|
||||
if not images_urls:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
url = images_urls[-1]
|
||||
suffix = f".{output_format}"
|
||||
save_path = output_path.with_suffix(suffix)
|
||||
saved_path = await download_image(url, save_path)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, _model: str | None = None) -> str:
|
||||
base_url = self.base_url.rstrip("/")
|
||||
path = IMAGE_GEN_ENDPOINT.lstrip("/")
|
||||
api_url = f"{base_url}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
*,
|
||||
size: str | None = None,
|
||||
modalities: Sequence[str] = ("text", "image"),
|
||||
output_format: Literal["png"] = "png",
|
||||
response_format: Literal["url"] = "url",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the payload for the SenseNova image-generation endpoint.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate an image for.
|
||||
model (str): The model to use for generation.
|
||||
size (str | None): Pixel size string (for example, "1920x1920").
|
||||
modalities (Sequence[str]): Reserved for compatibility; currently not sent.
|
||||
output_format (Literal["png"]): The output format of the image. Defaults to "png".
|
||||
response_format (Literal["url"]): The response format of the image. Defaults to "url".
|
||||
**kwargs (Any, optional): Additional parameters to pass to the API.
|
||||
|
||||
Example:
|
||||
{
|
||||
"model": "sensenova-u1-fast",
|
||||
"prompt": "A cat wearing a hat",
|
||||
"size": "1024x1024",
|
||||
"response_format": "url",
|
||||
"output_format": "png",
|
||||
}
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
# "modalities": modalities,
|
||||
"size": size,
|
||||
# "n": 1,
|
||||
"response_format": response_format,
|
||||
"output_format": output_format,
|
||||
**kwargs,
|
||||
}
|
||||
return payload
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
if not self.api_key:
|
||||
raise MissingApiKeyError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_size(
|
||||
cls,
|
||||
resolution: Literal["1K", "2K"] | str | None = None,
|
||||
aspect_ratio: ASPECT_RATIO_LITERALS | str | None = None,
|
||||
) -> str | None:
|
||||
"""Convert (resolution, aspect_ratio) to a pixel size string.
|
||||
|
||||
If aspect_ratio is None, returns the resolution as-is (e.g. "1K").
|
||||
"""
|
||||
if not resolution and not aspect_ratio:
|
||||
return None
|
||||
resolution = resolution or "2K"
|
||||
aspect_ratio = aspect_ratio or "1:1"
|
||||
if resolution == "1K":
|
||||
buckets = BUCKETS_1K
|
||||
elif resolution == "2K":
|
||||
buckets = BUCKETS_2K
|
||||
else:
|
||||
raise ValueError(f"Unsupported resolution: {resolution!r}. Must be '1K' or '2K'.")
|
||||
try:
|
||||
ws, _, hs = aspect_ratio.strip().partition(":")
|
||||
width = int(ws)
|
||||
height = int(hs)
|
||||
ratio = width / height
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}") from e
|
||||
if ratio > 16 / 9:
|
||||
raise ValueError(f"Aspect ratio {aspect_ratio!r} is too wide. Maximum is 16:9")
|
||||
if ratio < 9 / 21:
|
||||
raise ValueError(f"Aspect ratio {aspect_ratio!r} is too high. Maximum is 9:21")
|
||||
w, h = _find_nearest_aspect_ratio(ratio, buckets)
|
||||
return f"{w}x{h}"
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
"""Parse the response from the SenseNova image-generation endpoint.
|
||||
|
||||
Example response data:
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [{
|
||||
"url": "https://cdn.sensenova.dev/gen/..."
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
Args:
|
||||
response: The HTTP response from the SenseNova image-generation endpoint.
|
||||
|
||||
Returns:
|
||||
dict: Parsed data with key ``images_urls``.
|
||||
"""
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images_urls: list[str] = []
|
||||
for item in raw_data.get("data", []):
|
||||
url = item.get("url")
|
||||
if url:
|
||||
images_urls.append(url)
|
||||
return {"images_urls": images_urls}
|
||||
|
||||
|
||||
async def download_image(
|
||||
url: str,
|
||||
save_path: Path,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
) -> Path:
|
||||
"""Download an image from a URL.
|
||||
|
||||
Args:
|
||||
url: The URL of the image to download.
|
||||
timeout: The timeout for the request.
|
||||
|
||||
Returns:
|
||||
Path: The path to the downloaded image file.
|
||||
"""
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path: Path | None = None
|
||||
bytes_written = 0
|
||||
expected_length: int | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=save_path.parent,
|
||||
prefix=f".{save_path.name}.",
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
) as temp_file:
|
||||
temp_path = Path(temp_file.name)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
response.raise_for_status()
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
expected_length = int(content_length)
|
||||
async for chunk in response.aiter_bytes():
|
||||
bytes_written += len(chunk)
|
||||
temp_file.write(chunk)
|
||||
temp_file.flush()
|
||||
os.fsync(temp_file.fileno())
|
||||
|
||||
if expected_length is not None and bytes_written != expected_length:
|
||||
raise OSError(
|
||||
f"Downloaded image is incomplete: got {bytes_written} bytes, "
|
||||
f"expected {expected_length} bytes"
|
||||
)
|
||||
|
||||
assert temp_path is not None
|
||||
_validate_image_file(temp_path)
|
||||
temp_path.replace(save_path)
|
||||
return save_path
|
||||
except Exception:
|
||||
if temp_path is not None:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
def _validate_image_file(image_path: Path) -> None:
|
||||
"""Verify that the downloaded image can be decoded completely."""
|
||||
with Image.open(image_path) as image:
|
||||
image.verify()
|
||||
with Image.open(image_path) as image:
|
||||
image.load()
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
|
||||
|
||||
ASPECT_RATIO_LITERALS = Literal[
|
||||
"2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1", "16:9", "9:16", "9:21"
|
||||
]
|
||||
BUCKETS_1K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
|
||||
"2:3": (1088, 1632),
|
||||
"3:2": (1632, 1088),
|
||||
"3:4": (1152, 1536),
|
||||
"4:3": (1536, 1152),
|
||||
"4:5": (1184, 1472),
|
||||
"5:4": (1472, 1184),
|
||||
"1:1": (1344, 1344),
|
||||
"16:9": (1792, 992),
|
||||
"9:16": (992, 1792),
|
||||
"9:21": (864, 2048),
|
||||
}
|
||||
BUCKETS_2K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
|
||||
"2:3": (1664, 2496),
|
||||
"3:2": (2496, 1664),
|
||||
"3:4": (1760, 2368),
|
||||
"4:3": (2368, 1760),
|
||||
"4:5": (1824, 2272),
|
||||
"5:4": (2272, 1824),
|
||||
"1:1": (2048, 2048),
|
||||
"16:9": (2752, 1536),
|
||||
"9:16": (1536, 2752),
|
||||
"9:21": (1344, 3136),
|
||||
}
|
||||
|
||||
|
||||
def _find_nearest_aspect_ratio(
|
||||
ratio: float,
|
||||
buckets: dict[ASPECT_RATIO_LITERALS, tuple[int, int]],
|
||||
) -> tuple[int, int]:
|
||||
wh_pairs = sorted(
|
||||
buckets.values(),
|
||||
key=lambda wh: abs(wh[0] / wh[1] - ratio),
|
||||
)
|
||||
return wh_pairs[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def main_async():
|
||||
client = SensenovaText2ImageClient(
|
||||
api_key=global_configs.SN_IMAGE_GEN_API_KEY,
|
||||
base_url=global_configs.SN_IMAGE_GEN_BASE_URL,
|
||||
)
|
||||
|
||||
result = await client.generate(
|
||||
prompt="A cat wearing a hat",
|
||||
image_size="1K",
|
||||
aspect_ratio="16:9",
|
||||
)
|
||||
print(result)
|
||||
|
||||
asyncio.run(main_async())
|
||||
5
sn-image-base/scripts/sn_image_base/llm/__init__.py
Normal file
5
sn-image-base/scripts/sn_image_base/llm/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# llm module - Language Model (text only)
|
||||
from .anthropic_adapter import AnthropicMessagesAdapter
|
||||
from .chat_completions_adapter import OpenAIChatAdapter
|
||||
|
||||
__all__ = ["AnthropicMessagesAdapter", "OpenAIChatAdapter"]
|
||||
161
sn-image-base/scripts/sn_image_base/llm/anthropic_adapter.py
Normal file
161
sn-image-base/scripts/sn_image_base/llm/anthropic_adapter.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Anthropic Messages API adapter for text and vision."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from sn_image_base.utils.error_utils import U1HttpResponseParseError
|
||||
from sn_image_base.utils.httpx_client import httpx_response_raise_for_status_code
|
||||
from sn_image_base.vlm.utils import image_to_base64
|
||||
from sn_image_base.vlm.vlm_adapter import VlmAdapter
|
||||
|
||||
from .llm_adapter import LlmAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_REQUEST_TIMEOUT = 150.0
|
||||
DEFAULT_MAX_TOKENS = 4096
|
||||
|
||||
|
||||
class AnthropicMessagesAdapter(LlmAdapter, VlmAdapter):
|
||||
"""Anthropic Messages API adapter for text-only and vision calls."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_url: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
*,
|
||||
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||
timeout: float = DEFAULT_REQUEST_TIMEOUT,
|
||||
async_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._url = endpoint_url
|
||||
self._api_key = api_key
|
||||
self._default_model = model
|
||||
self._max_tokens = max_tokens
|
||||
self._timeout = timeout
|
||||
self._external_client = async_client
|
||||
self._client: httpx.AsyncClient | None = async_client
|
||||
logger.info(
|
||||
"AnthropicMessagesAdapter: endpoint=%s model=%s max_tokens=%s",
|
||||
self._url,
|
||||
self._default_model,
|
||||
self._max_tokens,
|
||||
)
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self._api_key,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_vision_content(
|
||||
user_prompt: str,
|
||||
images: list[str | bytes],
|
||||
) -> list[dict[str, Any]]:
|
||||
blocks: list[dict[str, Any]] = [{"type": "text", "text": user_prompt}]
|
||||
for image in images:
|
||||
mime, b64 = image_to_base64(image)
|
||||
blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime,
|
||||
"data": b64,
|
||||
},
|
||||
}
|
||||
)
|
||||
return blocks
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
user_prompt: str,
|
||||
system_prompt: str,
|
||||
model: str | None,
|
||||
*,
|
||||
images: list[str | bytes] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "user", "content": system_prompt})
|
||||
|
||||
user_content: str | list[dict[str, Any]]
|
||||
if images:
|
||||
user_content = self._build_vision_content(user_prompt, images)
|
||||
else:
|
||||
user_content = user_prompt
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
return {
|
||||
"model": model or self._default_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self._max_tokens,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _parse_response(data: dict[str, Any]) -> str:
|
||||
content = data.get("content", [])
|
||||
if content:
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
|
||||
thinking = data.get("thinking")
|
||||
if thinking:
|
||||
return f"[Think] {thinking}"
|
||||
|
||||
raise RuntimeError("Anthropic Messages response has no extractable content.")
|
||||
|
||||
async def _post_payload(self, payload: dict[str, Any]) -> str:
|
||||
resp = await self._get_client().post(self._url, json=payload, headers=self._headers)
|
||||
httpx_response_raise_for_status_code(resp)
|
||||
try:
|
||||
data = resp.json()
|
||||
except ValueError as exc:
|
||||
raise U1HttpResponseParseError(
|
||||
detail=f"Failed to parse HTTP response. {resp.request.url}. Response content: {resp.content}",
|
||||
code=resp.status_code,
|
||||
) from exc
|
||||
return self._parse_response(data)
|
||||
|
||||
async def text_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
payload = self._build_payload(user_prompt, system_prompt, model)
|
||||
return await self._post_payload(payload)
|
||||
|
||||
async def vision_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
images: list[str | bytes],
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
payload = self._build_payload(
|
||||
user_prompt,
|
||||
system_prompt,
|
||||
model,
|
||||
images=images,
|
||||
)
|
||||
return await self._post_payload(payload)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._external_client is None and self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
@@ -0,0 +1,276 @@
|
||||
"""OpenAI-compatible chat/completions adapter for text and vision."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from sn_image_base.configs import is_valid_base_url
|
||||
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
|
||||
from sn_image_base.utils.error_utils import (
|
||||
U1HttpBadResponseError,
|
||||
U1HttpNotFoundError,
|
||||
U1HttpResponseParseError,
|
||||
error_type_to_error_class,
|
||||
finish_reason_to_error_class,
|
||||
sanitize_base64_in_data,
|
||||
)
|
||||
from sn_image_base.utils.httpx_client import httpx_response_raise_for_status_code
|
||||
from sn_image_base.vlm.utils import image_to_data_url
|
||||
from sn_image_base.vlm.vlm_adapter import VlmAdapter
|
||||
|
||||
from .llm_adapter import LlmAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_REQUEST_TIMEOUT = 600.0
|
||||
DEFAULT_MAX_COMPLETION_TOKENS = 8192
|
||||
|
||||
|
||||
class OpenAIChatAdapter(LlmAdapter, VlmAdapter):
|
||||
"""OpenAI-compatible ``/chat/completions`` adapter for text and vision."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint_url: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
*,
|
||||
timeout: float = DEFAULT_REQUEST_TIMEOUT,
|
||||
async_client: httpx.AsyncClient | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> None:
|
||||
self._url = endpoint_url
|
||||
self._api_key = api_key
|
||||
self._default_model = model
|
||||
self._timeout = timeout
|
||||
self._reasoning_effort = reasoning_effort or None
|
||||
self._external_client = async_client
|
||||
self._client: httpx.AsyncClient | None = async_client
|
||||
logger.info(
|
||||
"OpenAIChatAdapter: endpoint=%s model=%s reasoning_effort=%s",
|
||||
self._url,
|
||||
self._default_model,
|
||||
self._reasoning_effort,
|
||||
)
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_vision_content(
|
||||
user_prompt: str,
|
||||
images: list[str | bytes],
|
||||
) -> list[dict[str, Any]]:
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": user_prompt}]
|
||||
content.extend(
|
||||
{"type": "image_url", "image_url": {"url": image_to_data_url(img)}} for img in images
|
||||
)
|
||||
return content
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
user_prompt: str,
|
||||
system_prompt: str,
|
||||
model: str,
|
||||
*,
|
||||
images: list[str | bytes] | None = None,
|
||||
max_completion_tokens: int | None = DEFAULT_MAX_COMPLETION_TOKENS,
|
||||
) -> dict[str, Any]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
user_content: str | list[dict[str, Any]]
|
||||
if images:
|
||||
user_content = self._build_vision_content(user_prompt, images)
|
||||
else:
|
||||
user_content = user_prompt
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
if self._reasoning_effort:
|
||||
payload["reasoning_effort"] = self._reasoning_effort
|
||||
if max_completion_tokens:
|
||||
payload["max_completion_tokens"] = max_completion_tokens
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _parse_response(data: dict[str, Any]) -> str:
|
||||
if "error" in data and (error := data["error"]):
|
||||
error_message = error.get("message")
|
||||
error_type = error.get("type")
|
||||
error_code = error.get("code")
|
||||
error_class, explanation = error_type_to_error_class(error_type)
|
||||
raise error_class(
|
||||
explanation,
|
||||
detail=f"chat/completions response has error. Error: {error_message}",
|
||||
code=error_code,
|
||||
)
|
||||
|
||||
choices = data.get("choices") or []
|
||||
if not choices:
|
||||
sanitized_data = sanitize_base64_in_data(data)
|
||||
dumped = json.dumps(sanitized_data, ensure_ascii=False)
|
||||
raise U1HttpBadResponseError(
|
||||
detail=f"chat/completions response has no choices. Response: {dumped}",
|
||||
)
|
||||
|
||||
contents: list[str] = []
|
||||
finish_reason: str | None = None
|
||||
for choice in choices:
|
||||
msg = choice.get("message", {})
|
||||
finish_reason = choice.get("finish_reason") or finish_reason
|
||||
content_val = msg.get("content")
|
||||
if isinstance(content_val, str):
|
||||
contents.append(content_val)
|
||||
elif isinstance(content_val, list):
|
||||
parts: list[str] = []
|
||||
for block in content_val:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
contents.append("".join(parts))
|
||||
|
||||
final_content = "".join(contents)
|
||||
if final_content:
|
||||
return final_content
|
||||
|
||||
sanitized_data = sanitize_base64_in_data(data)
|
||||
dumped = json.dumps(sanitized_data, ensure_ascii=False)
|
||||
detail_msg = ""
|
||||
if finish_reason:
|
||||
detail_msg += f"\n^ Finish reason: {finish_reason}"
|
||||
detail_msg += f"\n^ Response: {dumped}"
|
||||
if finish_reason == "stop":
|
||||
raise U1HttpBadResponseError(
|
||||
"chat/completions response with empty content.",
|
||||
detail=detail_msg,
|
||||
)
|
||||
if finish_reason:
|
||||
error_class, explanation = finish_reason_to_error_class(finish_reason)
|
||||
raise error_class(explanation, detail=detail_msg)
|
||||
raise U1HttpBadResponseError(
|
||||
"chat/completions response has no content. No finish reason provided.",
|
||||
detail=detail_msg,
|
||||
)
|
||||
|
||||
async def _post_payload(self, payload: dict[str, Any], model: str) -> str:
|
||||
resp = await self._get_client().post(self._url, json=payload, headers=self._headers)
|
||||
try:
|
||||
httpx_response_raise_for_status_code(resp)
|
||||
data = resp.json()
|
||||
except U1HttpNotFoundError as exc:
|
||||
raise U1HttpNotFoundError(
|
||||
detail=f"{exc.detail} model={model!r}",
|
||||
code=resp.status_code,
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise U1HttpResponseParseError(
|
||||
detail=f"Failed to parse HTTP response. {resp.request.url}. Response content: {resp.content}",
|
||||
code=resp.status_code,
|
||||
) from exc
|
||||
return self._parse_response(data)
|
||||
|
||||
async def text_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
resolved_model = model or self._default_model
|
||||
payload = self._build_payload(user_prompt, system_prompt, resolved_model)
|
||||
return await self._post_payload(payload, resolved_model)
|
||||
|
||||
async def vision_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
images: list[str | bytes],
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
resolved_model = model or self._default_model
|
||||
payload = self._build_payload(
|
||||
user_prompt,
|
||||
system_prompt,
|
||||
resolved_model,
|
||||
images=images,
|
||||
)
|
||||
return await self._post_payload(payload, resolved_model)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._external_client is None and self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from sn_image_base.configs import global_configs
|
||||
|
||||
parser = argparse.ArgumentParser(description="Async OpenAI-compatible chat adapter.")
|
||||
parser.add_argument("--prompt", default=None, help="Prompt to use for the model")
|
||||
parser.add_argument("--system-prompt", default=None, help="System prompt to use")
|
||||
parser.add_argument("--image", default=os.environ.get("IMAGE_PATH"), help="Optional image path")
|
||||
args = parser.parse_args()
|
||||
|
||||
async def main() -> None:
|
||||
prompt = args.prompt or "Write a poem about the topic: 'Hello world'"
|
||||
base_url = global_configs.SN_CHAT_BASE_URL
|
||||
if not base_url:
|
||||
raise InvalidBaseUrlError(
|
||||
f"No base URL provided for chat runtime. {global_configs.get_env_var_help('SN_CHAT_BASE_URL')}"
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise InvalidBaseUrlError(
|
||||
f"Invalid base URL for chat runtime: {base_url}. {global_configs.get_env_var_help('SN_CHAT_BASE_URL')}"
|
||||
)
|
||||
endpoint_url = f"{base_url.rstrip('/')}/chat/completions"
|
||||
api_key = global_configs.SN_CHAT_API_KEY
|
||||
if not api_key:
|
||||
raise MissingApiKeyError(
|
||||
f"No API key provided for chat runtime. {global_configs.get_env_var_help('SN_CHAT_API_KEY')}"
|
||||
)
|
||||
model = global_configs.SN_TEXT_MODEL
|
||||
|
||||
adapter = OpenAIChatAdapter(
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
)
|
||||
try:
|
||||
if args.image:
|
||||
result = await adapter.vision_completion(
|
||||
user_prompt=prompt,
|
||||
images=[args.image],
|
||||
system_prompt=args.system_prompt or "",
|
||||
)
|
||||
else:
|
||||
result = await adapter.text_completion(
|
||||
user_prompt=prompt,
|
||||
system_prompt=args.system_prompt or "",
|
||||
)
|
||||
print(result)
|
||||
finally:
|
||||
await adapter.aclose()
|
||||
|
||||
asyncio.run(main())
|
||||
51
sn-image-base/scripts/sn_image_base/llm/llm_adapter.py
Normal file
51
sn-image-base/scripts/sn_image_base/llm/llm_adapter.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Abstract base class for LLM (Language Model) adapters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LlmAdapter(ABC):
|
||||
"""Uniform async interface for a single Language Model backend.
|
||||
|
||||
Each concrete adapter wraps one LLM endpoint + model combination and
|
||||
exposes a single :meth:`text_completion` coroutine. Synchronous
|
||||
calling is intentionally **not** supported; callers must run inside an
|
||||
asyncio event loop.
|
||||
|
||||
**Client ownership contract** — when a shared
|
||||
:class:`httpx.AsyncClient` is supplied at construction time the adapter
|
||||
*reuses* it and must **not** close it; the caller retains full ownership
|
||||
of the client's lifecycle. When no external client is provided the
|
||||
adapter creates and owns an internal client and must close it in
|
||||
:meth:`aclose`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def text_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Send a text-only prompt to the model and return the reply.
|
||||
|
||||
Args:
|
||||
user_prompt: User-facing text instruction.
|
||||
system_prompt: System-level instruction prepended to the
|
||||
conversation. Defaults to ''.
|
||||
model: Model name to use. If None, uses the default set at
|
||||
initialization.
|
||||
|
||||
Returns:
|
||||
str: Raw text response from the model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aclose(self) -> None:
|
||||
"""Release async resources owned by this adapter.
|
||||
|
||||
Must be called when the adapter is no longer needed. Adapters that
|
||||
were given an external shared client must implement this as a no-op;
|
||||
adapters that created their own internal client must close it here.
|
||||
"""
|
||||
231
sn-image-base/scripts/sn_image_base/utils/error_utils.py
Normal file
231
sn-image-base/scripts/sn_image_base/utils/error_utils.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
class U1BaseError(Exception):
|
||||
MESSAGE = "Base error"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
detail: str | None = None,
|
||||
code: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if message is None:
|
||||
message = self.MESSAGE
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.detail = detail
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.code:
|
||||
msg = f"{self.__class__.__name__}[{self.code}]"
|
||||
else:
|
||||
msg = f"{self.__class__.__name__}"
|
||||
if self.message:
|
||||
msg += f"(message={self.message!r})"
|
||||
if self.detail:
|
||||
msg += f" <detail>{self.detail}</detail>"
|
||||
return msg
|
||||
|
||||
|
||||
# ----------------------
|
||||
# HTTP Errors
|
||||
# ----------------------
|
||||
|
||||
|
||||
class U1HttpErrorBase(U1BaseError):
|
||||
MESSAGE = "Base HTTP Error"
|
||||
|
||||
|
||||
class U1HttpAuthError(U1HttpErrorBase):
|
||||
MESSAGE = "Authentication or Authorization Failed"
|
||||
|
||||
|
||||
class U1HttpNotFoundError(U1HttpErrorBase):
|
||||
MESSAGE = "Resource Not Found"
|
||||
|
||||
|
||||
class U1HttpTooManyRequestsError(U1HttpErrorBase):
|
||||
MESSAGE = "Too Many Requests"
|
||||
|
||||
|
||||
class U1HttpServerError(U1HttpErrorBase):
|
||||
MESSAGE = "Server Error"
|
||||
|
||||
|
||||
class U1HttpBadRequestError(U1HttpErrorBase):
|
||||
MESSAGE = "Bad Request"
|
||||
|
||||
|
||||
class U1HttpPermissionError(U1HttpErrorBase):
|
||||
MESSAGE = "Permission Error"
|
||||
|
||||
|
||||
class U1HttpResponseParseError(U1HttpErrorBase):
|
||||
MESSAGE = "Failed to parse HTTP response"
|
||||
|
||||
|
||||
class U1HttpTimeoutError(U1HttpErrorBase):
|
||||
MESSAGE = "Timeout Error"
|
||||
|
||||
|
||||
class U1HttpNetworkError(U1HttpErrorBase):
|
||||
MESSAGE = "Network Error"
|
||||
|
||||
|
||||
class U1HttpUnknownError(U1HttpErrorBase):
|
||||
MESSAGE = "Unknown Error"
|
||||
|
||||
|
||||
class U1HttpForbiddenContentError(U1HttpErrorBase):
|
||||
MESSAGE = "Forbidden Content Filtered"
|
||||
|
||||
|
||||
class U1HttpTruncatedResponseError(U1HttpErrorBase):
|
||||
MESSAGE = "Truncated Response"
|
||||
|
||||
|
||||
class U1HttpBadResponseError(U1HttpErrorBase):
|
||||
MESSAGE = "Bad Response"
|
||||
|
||||
|
||||
def finish_reason_to_error_class(finish_reason: str) -> tuple[type[U1HttpErrorBase], str]:
|
||||
if finish_reason == "length":
|
||||
explanation = "Response was truncated due to length limit."
|
||||
return U1HttpTruncatedResponseError, explanation
|
||||
elif finish_reason == "content_filter":
|
||||
explanation = "Response was filtered due to content policy."
|
||||
return U1HttpForbiddenContentError, explanation
|
||||
elif finish_reason in ("tool_calls", "function_call"):
|
||||
explanation = "Response was halted due to tool calls or function calls."
|
||||
return U1HttpBadRequestError, explanation
|
||||
elif finish_reason == "stop":
|
||||
explanation = "Response was completed normally."
|
||||
return U1HttpBadResponseError, explanation
|
||||
return U1HttpBadRequestError, f"Unknown finish reason: {finish_reason!r}."
|
||||
|
||||
|
||||
def error_type_to_error_class(error_type: str) -> tuple[type[U1HttpErrorBase], str]:
|
||||
if error_type == "invalid_request_error":
|
||||
explanation = "Invalid request error."
|
||||
return U1HttpBadRequestError, explanation
|
||||
elif error_type == "rate_limit_error":
|
||||
explanation = "Rate limit exceeded."
|
||||
return U1HttpTooManyRequestsError, explanation
|
||||
elif error_type == "authentication_error":
|
||||
explanation = "Authentication error."
|
||||
return U1HttpAuthError, explanation
|
||||
elif error_type == "api_error":
|
||||
explanation = "API service internal error."
|
||||
return U1HttpServerError, explanation
|
||||
elif error_type == "permission_error":
|
||||
explanation = "You are not authorized to access this resource."
|
||||
return U1HttpPermissionError, explanation
|
||||
return U1HttpBadRequestError, f"Unknown error type: {error_type!r}."
|
||||
|
||||
|
||||
def sanitize_base64_in_data(data: Any, *, truncate_length: int = 200) -> Any:
|
||||
"""Recursively replace base64-encoded strings in data structure.
|
||||
|
||||
Args:
|
||||
data: Data to sanitize (dict, list, str, or other)
|
||||
truncate_length: Maximum length of base64-encoded string to truncate
|
||||
|
||||
Returns:
|
||||
Sanitized data with base64 strings replaced by placeholders
|
||||
|
||||
Example:
|
||||
>>> _sanitize_base64_in_data({"image": "iVBORw0KG..." * 100})
|
||||
{"image": "<base64-data: 1200 bytes>"}
|
||||
"""
|
||||
# Handle binary data first (bytes, bytearray, memoryview)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
# Try: bytes -> str
|
||||
with contextlib.suppress(Exception):
|
||||
data = data.decode("utf-8")
|
||||
if isinstance(data, (bytes, bytearray, memoryview)):
|
||||
return f'<binary-data len="{len(data)}bytes"/>'
|
||||
if isinstance(data, str):
|
||||
# Try: str -> dict | list
|
||||
with contextlib.suppress(Exception):
|
||||
data = json.loads(data)
|
||||
|
||||
seen_ids: set[int] = set() # Prevent circular references
|
||||
|
||||
def __recursive_sanitize_base64_in_data(
|
||||
data: Mapping | Iterable | str | Any,
|
||||
) -> dict | list | str | Any:
|
||||
if isinstance(data, str):
|
||||
if _is_base64_string(data) and len(data) > truncate_length:
|
||||
# Truncate base64-encoded string, replace it with placeholder
|
||||
len_str = f"{len(data):,d}bytes"
|
||||
return f'<base64-data len="{len_str}">{data[:truncate_length]}...{TRUNCATED_MARKER}...{data[-truncate_length:]}</base64-data>'
|
||||
return data
|
||||
elif isinstance(data, Mapping):
|
||||
obj_id = id(data)
|
||||
if obj_id in seen_ids:
|
||||
return "<circular-reference:mapping/>"
|
||||
seen_ids.add(obj_id)
|
||||
result = {
|
||||
key: __recursive_sanitize_base64_in_data(value) for key, value in data.items()
|
||||
}
|
||||
seen_ids.remove(obj_id)
|
||||
return result
|
||||
elif isinstance(data, Iterable):
|
||||
obj_id = id(data)
|
||||
if obj_id in seen_ids:
|
||||
return "<circular-reference:iterable/>"
|
||||
seen_ids.add(obj_id)
|
||||
result = [__recursive_sanitize_base64_in_data(item) for item in data]
|
||||
seen_ids.remove(obj_id)
|
||||
return result
|
||||
return data
|
||||
|
||||
return __recursive_sanitize_base64_in_data(data)
|
||||
|
||||
|
||||
TRUNCATED_MARKER = "<<<///TRUNCATED///>>>"
|
||||
BASE64_DETECTION_MIN_LENGTH = 200 # Minimum length to consider as potential base64
|
||||
BASE64_CHARS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")
|
||||
|
||||
|
||||
def _is_base64_string(value: str) -> bool:
|
||||
"""Check if a string looks like base64-encoded data.
|
||||
|
||||
Args:
|
||||
value: String to check
|
||||
|
||||
Returns:
|
||||
True if the string appears to be base64-encoded data
|
||||
|
||||
Heuristics:
|
||||
- Length >= BASE64_DETECTION_MIN_LENGTH (200 chars)
|
||||
- At least 80% of characters are valid base64 chars (A-Za-z0-9+/=)
|
||||
- No whitespace or newlines (valid base64 is continuous)
|
||||
"""
|
||||
if not isinstance(value, str) or len(value) < BASE64_DETECTION_MIN_LENGTH:
|
||||
return False
|
||||
|
||||
# Check if mostly base64 characters (allow some tolerance)
|
||||
if value.startswith("data:"):
|
||||
# Remove the prefix like "data:image/jpeg;base64,"
|
||||
index = value.find(";base64,")
|
||||
if index != -1:
|
||||
value = value[index + len(";base64,") :]
|
||||
valid_count = sum(1 for c in value if c in BASE64_CHARS)
|
||||
ratio = valid_count / len(value)
|
||||
|
||||
if ratio >= 0.98:
|
||||
with contextlib.suppress(Exception):
|
||||
base64.b64decode(value)
|
||||
return True
|
||||
|
||||
return False
|
||||
180
sn-image-base/scripts/sn_image_base/utils/httpx_client.py
Normal file
180
sn-image-base/scripts/sn_image_base/utils/httpx_client.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""Shared httpx async client factory for vigeneval evaluators.
|
||||
|
||||
Centralizes connection pool limits, pool timeout, and optional file descriptor
|
||||
limit check to avoid PoolTimeout and 'Too many open files' under high concurrency.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import resource
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .error_utils import (
|
||||
U1HttpAuthError,
|
||||
U1HttpBadRequestError,
|
||||
U1HttpNotFoundError,
|
||||
U1HttpServerError,
|
||||
U1HttpTooManyRequestsError,
|
||||
)
|
||||
|
||||
|
||||
def check_file_descriptor_limit(max_connections: int, margin: int = 200) -> None:
|
||||
"""Raise if process file descriptor limit is too low for max_connections.
|
||||
|
||||
Avoids 'Too many open files' mid-run when using a large httpx connection pool.
|
||||
No-op on Windows or when resource module has no RLIMIT_NOFILE.
|
||||
|
||||
Args:
|
||||
max_connections: Intended httpx pool max_connections.
|
||||
margin: Extra FDs to reserve for app (logs, other files). Default 200.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If soft limit < max_connections + margin.
|
||||
"""
|
||||
try:
|
||||
soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
except (ImportError, AttributeError, OSError):
|
||||
return
|
||||
required = max_connections + margin
|
||||
if soft < required:
|
||||
raise RuntimeError(
|
||||
f"File descriptor limit too low for max_connections={max_connections}. "
|
||||
f"Current soft limit: {soft}, need at least {required}. "
|
||||
"Raise the limit before running, e.g.: ulimit -n 2048 # or higher, then re-run."
|
||||
)
|
||||
|
||||
|
||||
def create_async_httpx_client(
|
||||
headers: dict[str, str],
|
||||
*,
|
||||
timeout: float = 600.0,
|
||||
max_connections: int = 500,
|
||||
pool_timeout: float = 60.0,
|
||||
check_fd_limit: bool = False,
|
||||
verify: bool = True,
|
||||
**client_kwargs: Any,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an httpx.AsyncClient with shared defaults for vigeneval evaluators.
|
||||
|
||||
Automatically uses proxy from environment variables (HTTPS_PROXY, HTTP_PROXY, etc.)
|
||||
when trust_env=True (default). Supports proxy authentication via URL format:
|
||||
http://username:password@proxy_host:port
|
||||
|
||||
Connection pool limits and pool timeout help avoid PoolTimeout under high concurrency.
|
||||
Optionally checks process file descriptor limit before creating the client.
|
||||
|
||||
Args:
|
||||
headers: Request headers (e.g. Content-Type, Authorization).
|
||||
timeout: Request timeout in seconds. Default 600.
|
||||
max_connections: Connection pool size. Default 500; use 1000 for batch
|
||||
high parallelism (and check_fd_limit=True).
|
||||
pool_timeout: Max seconds to wait for a connection from the pool. Default 60.
|
||||
check_fd_limit: If True, call check_file_descriptor_limit(max_connections)
|
||||
and raise before creating the client. Use for batch evaluators.
|
||||
verify: If False, disable SSL certificate verification (avoids
|
||||
CERTIFICATE_VERIFY_FAILED). Use only for dev/testing or trusted networks.
|
||||
**client_kwargs: Passed through to httpx.AsyncClient (e.g. base_url).
|
||||
|
||||
Returns:
|
||||
A new httpx.AsyncClient. Caller must aclose() when done.
|
||||
|
||||
Example:
|
||||
# Set proxy with authentication in environment
|
||||
export HTTP_PROXY="http://user:pass@proxy.example.com:3128"
|
||||
export HTTPS_PROXY="http://user:pass@proxy.example.com:3128"
|
||||
|
||||
# Create client - proxy is automatically used
|
||||
client = create_async_httpx_client(
|
||||
headers={"Authorization": "Bearer token"},
|
||||
max_connections=100,
|
||||
)
|
||||
"""
|
||||
if check_fd_limit:
|
||||
check_file_descriptor_limit(max_connections)
|
||||
|
||||
# Note: Proxy configuration is handled automatically by httpx when trust_env=True.
|
||||
# We don't need to explicitly read or pass proxy URLs - httpx will read from
|
||||
# environment variables (HTTPS_PROXY, HTTP_PROXY, etc.) and handle authentication.
|
||||
|
||||
limits = httpx.Limits(
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=min(400, max_connections),
|
||||
keepalive_expiry=30.0,
|
||||
)
|
||||
|
||||
# Create transport without explicit proxy parameter when trust_env=True
|
||||
# This allows httpx to properly handle proxy authentication from environment
|
||||
transport = httpx.AsyncHTTPTransport(
|
||||
verify=verify,
|
||||
trust_env=True,
|
||||
local_address="0.0.0.0",
|
||||
limits=limits,
|
||||
)
|
||||
|
||||
# Create client with trust_env=True to enable proxy from environment
|
||||
return httpx.AsyncClient(
|
||||
transport=transport,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(timeout, pool=pool_timeout),
|
||||
verify=verify,
|
||||
trust_env=True, # Enable reading proxy from environment variables
|
||||
**client_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def httpx_response_raise_for_status_code(response: httpx.Response) -> None:
|
||||
"""Check httpx response status code and raise appropriate exceptions.
|
||||
|
||||
Args:
|
||||
response: The httpx response object.
|
||||
verbose: Whether to log verbose information.
|
||||
|
||||
Raises:
|
||||
AuthError: If response status is 401 or 403.
|
||||
APIError: If response status is 429 or 5xx.
|
||||
InvalidRequestError: If response status is 4xx (except 401, 403, 429).
|
||||
"""
|
||||
# Try best effort to parse response content & headers
|
||||
response_headers = "[N/A]" # Not available
|
||||
response_content = "[N/A]" # Not available
|
||||
request_url = "[N/A]"
|
||||
request_method = "[N/A]"
|
||||
with contextlib.suppress(Exception):
|
||||
response_headers = response.headers
|
||||
response_headers = dict(response_headers)
|
||||
with contextlib.suppress(Exception):
|
||||
response_content = response.content
|
||||
response_content = response_content.decode("utf-8")
|
||||
response_content = json.loads(response_content)
|
||||
with contextlib.suppress(Exception):
|
||||
request_method = response.request.method
|
||||
request_method = request_method.upper()
|
||||
request_url = str(response.request.url)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise U1HttpNotFoundError(
|
||||
detail=f"{request_method} {request_url!r} not found. Please check the URL and the model name.",
|
||||
code=response.status_code,
|
||||
)
|
||||
if response.status_code in (401, 403):
|
||||
raise U1HttpAuthError(
|
||||
detail=f"Authentication or authorization failed. {request_method} {request_url!r}. Response content: {response_content}",
|
||||
code=response.status_code,
|
||||
)
|
||||
elif response.status_code in (429, 503):
|
||||
raise U1HttpTooManyRequestsError(
|
||||
detail=f"Service temporarily unavailable. Please try again later. {request_method} {request_url!r}. Response content: {response_content}",
|
||||
code=response.status_code,
|
||||
)
|
||||
elif 500 <= response.status_code <= 599:
|
||||
raise U1HttpServerError(
|
||||
detail=f"Request failed. {request_method} {request_url!r}. Response content: {response_content}",
|
||||
code=response.status_code,
|
||||
)
|
||||
elif 400 <= response.status_code <= 499:
|
||||
raise U1HttpBadRequestError(
|
||||
detail=f"Bad request. {request_method} {request_url!r}. Response content: {response_content}",
|
||||
code=response.status_code,
|
||||
)
|
||||
5
sn-image-base/scripts/sn_image_base/vlm/__init__.py
Normal file
5
sn-image-base/scripts/sn_image_base/vlm/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# vlm module - Vision Language Model
|
||||
|
||||
from .vlm_adapter import VlmAdapter
|
||||
|
||||
__all__ = ["VlmAdapter"]
|
||||
120
sn-image-base/scripts/sn_image_base/vlm/utils.py
Normal file
120
sn-image-base/scripts/sn_image_base/vlm/utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Image encoding / decoding utilities for VLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def read_image_bytes(image: str | bytes) -> bytes:
|
||||
"""Read raw image bytes from a path or return bytes unchanged.
|
||||
|
||||
Args:
|
||||
image: File path to an image, or raw image bytes.
|
||||
|
||||
Returns:
|
||||
bytes: Raw image bytes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If image is a path and the file does not exist.
|
||||
"""
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
path = Path(image)
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
return path.read_bytes()
|
||||
|
||||
|
||||
def detect_mime(data: bytes) -> str:
|
||||
"""Infer MIME type from image magic bytes.
|
||||
|
||||
Args:
|
||||
data: Raw image bytes (at least 8 bytes for PNG check).
|
||||
|
||||
Returns:
|
||||
str: 'image/png', 'image/jpeg', or 'image/png' as fallback.
|
||||
"""
|
||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return "image/png"
|
||||
if data[:3] == b"\xff\xd8\xff":
|
||||
return "image/jpeg"
|
||||
return "image/png"
|
||||
|
||||
|
||||
def detect_suffix(data: bytes) -> str:
|
||||
"""Infer file suffix from image magic bytes.
|
||||
|
||||
Args:
|
||||
data: Raw image bytes.
|
||||
|
||||
Returns:
|
||||
str: '.png', '.jpg', or '.bin' as fallback.
|
||||
"""
|
||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return ".png"
|
||||
if data[:3] == b"\xff\xd8\xff":
|
||||
return ".jpg"
|
||||
return ".bin"
|
||||
|
||||
|
||||
def image_to_mime_and_bytes(image: str | bytes) -> tuple[str, bytes]:
|
||||
"""Get MIME type and raw bytes; convert to PNG if format is not PNG/JPEG.
|
||||
|
||||
Args:
|
||||
image: File path or raw image bytes.
|
||||
|
||||
Returns:
|
||||
tuple[str, bytes]: (mime_type, raw_bytes). Unknown formats become PNG.
|
||||
"""
|
||||
raw = read_image_bytes(image)
|
||||
mime = detect_mime(raw)
|
||||
if mime in ("image/png", "image/jpeg"):
|
||||
return mime, raw
|
||||
img = Image.open(io.BytesIO(raw)).convert("RGBA")
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
return "image/png", buf.getvalue()
|
||||
|
||||
|
||||
def image_to_base64(image: str | bytes) -> tuple[str, str]:
|
||||
"""Encode image to MIME type and base64 string.
|
||||
|
||||
Args:
|
||||
image: File path or raw image bytes.
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (mime_type, base64_encoded_string).
|
||||
"""
|
||||
mime, raw = image_to_mime_and_bytes(image)
|
||||
return mime, base64.b64encode(raw).decode("utf-8")
|
||||
|
||||
|
||||
def image_to_data_url(image: str | bytes) -> str:
|
||||
"""Build a data URL (data:mime;base64,...) for the image.
|
||||
|
||||
Args:
|
||||
image: File path or raw image bytes.
|
||||
|
||||
Returns:
|
||||
str: Data URL string.
|
||||
"""
|
||||
mime, b64 = image_to_base64(image)
|
||||
return f"data:{mime};base64,{b64}"
|
||||
|
||||
|
||||
def mask_secret(secret: str) -> str:
|
||||
"""Mask a secret for logging (e.g. show first 6 and last 4 chars).
|
||||
|
||||
Args:
|
||||
secret: Raw secret string.
|
||||
|
||||
Returns:
|
||||
str: Masked string (e.g. 'abcdef...ghij' or all '*' if length <= 8).
|
||||
"""
|
||||
if len(secret) <= 8:
|
||||
return "*" * len(secret)
|
||||
return f"{secret[:6]}...{secret[-4:]}"
|
||||
55
sn-image-base/scripts/sn_image_base/vlm/vlm_adapter.py
Normal file
55
sn-image-base/scripts/sn_image_base/vlm/vlm_adapter.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Abstract base class for VLM (Vision Language Model) adapters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class VlmAdapter(ABC):
|
||||
"""Uniform async interface for a single Vision Language Model backend.
|
||||
|
||||
Each concrete adapter wraps one LLM endpoint + model combination and
|
||||
exposes a single :meth:`vision_completion` coroutine. Synchronous
|
||||
calling is intentionally **not** supported; callers must run inside an
|
||||
asyncio event loop.
|
||||
|
||||
**Client ownership contract** — when a shared
|
||||
:class:`httpx.AsyncClient` is supplied at construction time the adapter
|
||||
*reuses* it and must **not** close it; the caller retains full ownership
|
||||
of the client's lifecycle. When no external client is provided the
|
||||
adapter creates and owns an internal client and must close it in
|
||||
:meth:`aclose`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def vision_completion(
|
||||
self,
|
||||
user_prompt: str,
|
||||
images: list[str | bytes],
|
||||
system_prompt: str = "",
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Send image(s) and a text prompt to the model; return the reply.
|
||||
|
||||
Args:
|
||||
user_prompt: User-facing text instruction.
|
||||
images: One or more images to pass to the model. Each element
|
||||
is either a file-path string or raw image bytes.
|
||||
system_prompt: System-level instruction prepended to the
|
||||
conversation. Defaults to ''.
|
||||
model: Model name to use. If None, uses the default set at
|
||||
initialization.
|
||||
|
||||
Returns:
|
||||
str: Raw text response from the model (may contain JSON or
|
||||
markdown-wrapped JSON depending on the model and prompt).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aclose(self) -> None:
|
||||
"""Release async resources owned by this adapter.
|
||||
|
||||
Must be called when the adapter is no longer needed. Adapters that
|
||||
were given an external shared client must implement this as a no-op;
|
||||
adapters that created their own internal client must close it here.
|
||||
"""
|
||||
Reference in New Issue
Block a user