Files
agent-skills/sn-image-base/scripts/sn_image_base/generation/sensenova.py
Hermes Agent ccc63d1e70 first commit
2026-05-10 13:52:46 +08:00

509 lines
17 KiB
Python

from __future__ import annotations
import os
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import httpx
from PIL import Image
from typing_extensions import override
from sn_image_base.configs import global_configs, is_valid_base_url
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
from sn_image_base.generation.core import ensure_output_path
from sn_image_base.generation.core.client_base import (
DEFAULT_HTTP_REQUEST_TIMEOUT,
DEFAULT_MAX_CONNECTIONS,
T2IBaseClient,
)
from sn_image_base.utils.error_utils import U1HttpErrorBase
if TYPE_CHECKING:
from collections.abc import Sequence
# Default image-size preset for ``SensenovaText2ImageClient.generate``.
# NOTE: the annotation allows "4K", but ``_resolve_size`` only has size tables
# for "1K" and "2K" and raises ValueError for anything else.
DEFAULT_RESOLUTION: Literal["1K", "2K", "4K"] = "2K"
# Default aspect ratio for ``SensenovaText2ImageClient.generate``.
DEFAULT_ASPECT_RATIO = "16:9"
# NOTE(review): not referenced by the code visible in this module; presumably
# a polling interval in seconds kept for other callers -- confirm before removing.
DEFAULT_POLL_INTERVAL = 5.0
# Directory where ``generate`` stores images when no ``output_path`` is given.
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
# Path joined onto the configured base URL by ``get_api_url``.
IMAGE_GEN_ENDPOINT = "/images/generations"
class SensenovaText2ImageClient(T2IBaseClient):
    """Async client for the SenseNova text-to-image API.

    Explicitly passed credentials take precedence; empty values fall back to
    the corresponding ``global_configs`` settings.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str | None = None,
        *,
        model: str | None = None,
        max_connections: int = DEFAULT_MAX_CONNECTIONS,
        timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
        ssl_verify: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the SensenovaText2ImageClient.

        Args:
            api_key (str):
                API key for authentication. If empty, falls back to
                ``global_configs.SN_IMAGE_GEN_API_KEY``.
            base_url (str | None, optional):
                API base URL. If empty, falls back to
                ``global_configs.SN_IMAGE_GEN_BASE_URL``.
            model (str | None, optional):
                Model name. Passed to the base class unchanged; when it is
                None, ``generate`` falls back to
                ``global_configs.SN_IMAGE_GEN_MODEL``.
            max_connections (int, optional):
                Maximum number of connections.
                Defaults to DEFAULT_MAX_CONNECTIONS.
            timeout (float, optional):
                Total timeout in seconds for HTTP requests.
                Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
            ssl_verify (bool, optional):
                If True, enable TLS verification. Defaults to True.
            **kwargs:
                Extra keyword arguments forwarded to the base class.

        Raises:
            MissingApiKeyError: If no API key is passed or configured.
            InvalidBaseUrlError: If the base URL is missing or fails
                ``is_valid_base_url``.
        """
        api_key = api_key or global_configs.SN_IMAGE_GEN_API_KEY
        if not api_key:
            raise MissingApiKeyError(
                f"API key is missing: {global_configs.get_env_var_help('SN_IMAGE_GEN_API_KEY')}"
            )

        base_url = base_url or global_configs.SN_IMAGE_GEN_BASE_URL
        if not base_url:
            raise InvalidBaseUrlError(
                f"Base URL is missing: {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )
        if not is_valid_base_url(base_url):
            raise InvalidBaseUrlError(
                f"Base URL is not a valid base URL: {base_url}. "
                f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )

        super().__init__(
            api_key=api_key,
            base_url=base_url,
            model=model,
            max_connections=max_connections,
            timeout=timeout,
            ssl_verify=ssl_verify,
            **kwargs,
        )

    @override
    async def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        *,
        model: str | None = None,
        image_size: Literal["1K", "2K", "4K"] = DEFAULT_RESOLUTION,
        aspect_ratio: str = DEFAULT_ASPECT_RATIO,
        output_path: Path | None = None,
        **kwargs: Any,
    ) -> dict:
        """Generate an image from a text prompt and save it as a PNG file.

        Args:
            prompt (str):
                Text prompt for image generation.
            negative_prompt (str, optional):
                Accepted for interface compatibility; it is currently not
                sent to the API. Defaults to "".
            model (str | None, optional):
                Model name override. Falls back to the client's model, then
                to ``global_configs.SN_IMAGE_GEN_MODEL``. Defaults to None.
            image_size (str, optional):
                Image size preset; matched case-insensitively. Only "1K" and
                "2K" are supported by ``_resolve_size``.
                Defaults to DEFAULT_RESOLUTION.
            aspect_ratio (str, optional):
                Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
            output_path (Path | None, optional):
                Output path for the generated image. When None, a timestamped
                file under OUTPUT_DIR is used. The suffix is always replaced
                with ".png". Defaults to None.
            **kwargs:
                Additional arguments reserved for backend compatibility;
                they are ignored.

        Returns:
            dict:
                On success: ``status`` ("ok"), ``output`` (saved path) and
                ``message``. On failure: ``status`` ("failed"), ``error``
                and ``message``.

        Raises:
            ValueError: If ``image_size`` or ``aspect_ratio`` is rejected by
                ``_resolve_size``.
        """
        model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
        # The size tables are keyed by upper-case presets ("1K", "2K").
        image_size = image_size.upper()  # type: ignore[assignment]
        output_format = "png"
        size = self._resolve_size(image_size, aspect_ratio)
        payload = self.build_payload(
            prompt=prompt,
            model=model,
            size=size,
            aspect_ratio=aspect_ratio,
            output_format=output_format,
        )
        headers = self.headers
        api_url = self.get_api_url(model)

        if output_path is None:
            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
        output_path = ensure_output_path(output_path)

        client = await self._get_client()
        try:
            create_response = await client.post(
                api_url,
                json=payload,
                headers=headers,
            )
            data = self.parse_response(create_response)
        except U1HttpErrorBase as exc:
            details = exc.detail or ""
            hint = self._misconfiguration_hint(exc.code)
            if hint:
                details += hint
            return {
                "status": "failed",
                "error": f"HTTP {exc.code}: {exc.message}",
                "message": details,
            }

        try:
            images_urls: list[str] = data["images_urls"]
            if not images_urls:
                # Carry a "message" key like every other result so callers can
                # read it unconditionally.
                return {
                    "status": "failed",
                    "error": "No image generated from the model",
                    "message": "The API response did not contain any image URL",
                }
            url = images_urls[-1]
            save_path = output_path.with_suffix(f".{output_format}")
            saved_path = await download_image(url, save_path)
            return {
                "status": "ok",
                "output": str(saved_path),
                "message": "Image generated successfully",
            }
        except httpx.HTTPStatusError as exc:
            return {
                "status": "failed",
                "error": f"HTTP {exc.response.status_code}",
                "message": f"http error: {exc.response.status_code} {exc.response.text}",
            }
        except (httpx.HTTPError, OSError, ValueError) as exc:
            return {
                "status": "failed",
                "error": type(exc).__name__,
                "message": f"request error: {exc}",
            }

    @staticmethod
    def _misconfiguration_hint(status_code: Any) -> str:
        """Return a hint naming the setting most likely wrong for an HTTP status.

        A 404 points at the base URL and a 401 at the API key. Returns an
        empty string for any other status, or when ``global_configs`` has no
        environment-variable names for the field.
        """
        if status_code == 404:
            field_name = "SN_IMAGE_GEN_BASE_URL"
        elif status_code == 401:
            field_name = "SN_IMAGE_GEN_API_KEY"
        else:
            return ""

        field_hint = global_configs.get_annotated_field(field_name)
        if field_hint is None:
            return ""
        env_names = list(field_hint.env_names) if field_hint.env_names else []
        if not env_names:
            return ""
        if len(env_names) == 1:
            return f"\nIs the environment variable `{env_names[0]}` set correctly?"
        env_names_str = ", ".join(f"`{n}`" for n in env_names)
        return f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"

    @property
    @override
    def api_key(self) -> str:
        """API key in effect: the instance value, else the configured one.

        Raises:
            ValueError: If neither source provides a key.
        """
        # NOTE(review): __init__ raises MissingApiKeyError for the same
        # condition; ValueError is kept here so existing callers keep working.
        api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
        if not api_key:
            raise ValueError(
                f"API key is missing: {global_configs.get_env_var_help('SN_IMAGE_GEN_API_KEY')}"
            )
        return api_key

    @property
    @override
    def base_url(self) -> str:
        """Base URL in effect: the instance value, else the configured one.

        Raises:
            ValueError: If no base URL is available or it fails
                ``is_valid_base_url``.
        """
        base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
        if not base_url:
            raise ValueError(
                f"Base URL is missing: {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )
        if not is_valid_base_url(base_url):
            raise ValueError(
                f"Base URL is not a valid base URL: {base_url}. "
                f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
            )
        return base_url

    @override
    def get_api_url(self, _model: str | None = None) -> str:
        """Return the image-generation URL; ``_model`` is accepted but unused."""
        base_url = self.base_url.rstrip("/")
        path = IMAGE_GEN_ENDPOINT.lstrip("/")
        return f"{base_url}/{path}"

    @override
    def build_payload(
        self,
        prompt: str,
        model: str,
        *,
        size: str | None = None,
        modalities: Sequence[str] = ("text", "image"),
        output_format: Literal["png"] = "png",
        response_format: Literal["url"] = "url",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload for the SenseNova image-generation endpoint.

        Args:
            prompt (str): The prompt to generate an image for.
            model (str): The model to use for generation.
            size (str | None): Pixel size string (for example, "1920x1920").
            modalities (Sequence[str]): Reserved for compatibility; currently not sent.
            output_format (Literal["png"]): The output format of the image. Defaults to "png".
            response_format (Literal["url"]): The response format of the image. Defaults to "url".
            **kwargs (Any, optional): Additional parameters merged into the
                payload last, so they override the fields above on key clashes.

        Returns:
            dict[str, Any]: The JSON-serialisable request body.

        Example:
            {
                "model": "sensenova-u1-fast",
                "prompt": "A cat wearing a hat",
                "size": "1024x1024",
                "response_format": "url",
                "output_format": "png",
            }
        """
        return {
            "model": model,
            "prompt": prompt,
            "size": size,
            "response_format": response_format,
            "output_format": output_format,
            **kwargs,
        }

    @property
    @override
    def headers(self) -> dict[str, str]:
        """Authorization and content-type headers for API requests.

        Raises:
            MissingApiKeyError: If the resolved API key is empty.
        """
        # Resolve the key once. ``api_key`` already raises when nothing is
        # configured, so the guard only matters if a subclass overrides it.
        api_key = self.api_key
        if not api_key:
            raise MissingApiKeyError(
                f"API key is missing: {global_configs.get_env_var_help('SN_IMAGE_GEN_API_KEY')}"
            )
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @classmethod
    def _resolve_size(
        cls,
        resolution: Literal["1K", "2K"] | str | None = None,
        aspect_ratio: ASPECT_RATIO_LITERALS | str | None = None,
    ) -> str | None:
        """Convert (resolution, aspect_ratio) to a pixel size string.

        The requested ratio is snapped to the closest entry of the size table
        for the resolution. A missing resolution defaults to "2K" and a
        missing aspect ratio to "1:1".

        Args:
            resolution: Size preset, "1K" or "2K".
            aspect_ratio: Ratio written as "W:H" with positive integers,
                between 9:21 and 16:9 inclusive.

        Returns:
            str | None: "<width>x<height>", or None when both arguments are
            empty.

        Raises:
            ValueError: If the resolution is unsupported, or the aspect ratio
                is malformed, non-positive or outside the supported range.
        """
        if not resolution and not aspect_ratio:
            return None
        resolution = resolution or "2K"
        aspect_ratio = aspect_ratio or "1:1"

        if resolution == "1K":
            buckets = BUCKETS_1K
        elif resolution == "2K":
            buckets = BUCKETS_2K
        else:
            raise ValueError(f"Unsupported resolution: {resolution!r}. Must be '1K' or '2K'.")

        try:
            ws, _, hs = aspect_ratio.strip().partition(":")
            width = int(ws)
            height = int(hs)
            ratio = width / height
        except (AttributeError, TypeError, ValueError, ArithmeticError) as e:
            raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}") from e
        # Two negative components would otherwise yield a plausible positive
        # ratio (e.g. "-16:-9") and pass the range checks below.
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}")

        if ratio > 16 / 9:
            raise ValueError(f"Aspect ratio {aspect_ratio!r} is too wide. Maximum is 16:9")
        if ratio < 9 / 21:
            raise ValueError(f"Aspect ratio {aspect_ratio!r} is too tall. Minimum is 9:21")

        w, h = _find_nearest_aspect_ratio(ratio, buckets)
        return f"{w}x{h}"

    @override
    def parse_response(self, response: httpx.Response) -> dict:
        """Parse the response from the SenseNova image-generation endpoint.

        Example response data:
        ```json
        {
            "data": [{
                "url": "https://cdn.sensenova.dev/gen/..."
            }]
        }
        ```

        Args:
            response: The HTTP response from the SenseNova image-generation endpoint.

        Returns:
            dict: Parsed data with key ``images_urls``, the non-empty ``url``
            values in response order.
        """
        raw_data = super().parse_response(response)
        # ``or []`` also covers an explicit ``"data": null`` in the body.
        items = raw_data.get("data") or []
        candidate_urls = (item.get("url") for item in items)
        return {"images_urls": [url for url in candidate_urls if url]}
async def download_image(
    url: str,
    save_path: Path,
    timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
) -> Path:
    """Download an image from a URL and move it atomically to ``save_path``.

    The body is streamed into a temporary file next to ``save_path``, checked
    for completeness and decodability, and only then renamed over the target,
    so a failed download never leaves a partial file at ``save_path``.

    Args:
        url: The URL of the image to download.
        save_path: Destination file; its parent directory is created if needed.
        timeout: The timeout for the request.

    Returns:
        Path: The path to the downloaded image file (``save_path``).

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        httpx.HTTPError: On transport-level failures.
        OSError: If the download is incomplete, the file is not a decodable
            image, or a filesystem operation fails.
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path: Path | None = None
    received_length = 0
    expected_length: int | None = None
    try:
        with tempfile.NamedTemporaryFile(
            dir=save_path.parent,
            prefix=f".{save_path.name}.",
            suffix=".tmp",
            delete=False,
        ) as temp_file:
            temp_path = Path(temp_file.name)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("GET", url) as response:
                    response.raise_for_status()
                    content_length = response.headers.get("content-length")
                    if content_length is not None:
                        expected_length = int(content_length)
                    async for chunk in response.aiter_bytes():
                        temp_file.write(chunk)
                    # Content-Length counts bytes on the wire, while
                    # aiter_bytes() yields the decoded body. Compare against
                    # the raw byte counter so gzip/deflate-encoded responses
                    # are not rejected as incomplete.
                    received_length = response.num_bytes_downloaded
            temp_file.flush()
            os.fsync(temp_file.fileno())

        if expected_length is not None and received_length != expected_length:
            raise OSError(
                f"Downloaded image is incomplete: got {received_length} bytes, "
                f"expected {expected_length} bytes"
            )
        _validate_image_file(temp_path)
        temp_path.replace(save_path)
        return save_path
    except BaseException:
        # BaseException, so the temp file is also removed when the task is
        # cancelled (asyncio.CancelledError is not an Exception subclass).
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
        raise
def _validate_image_file(image_path: Path) -> None:
    """Verify that the downloaded image can be decoded completely.

    The file is opened twice: ``verify()`` runs integrity checks without
    decoding pixel data and leaves the image object unusable, so the file is
    reopened for a full ``load()``.

    Args:
        image_path: Path of the image file to check.

    Raises:
        OSError: If the file cannot be identified, fails verification or
            cannot be fully decoded.
    """
    try:
        with Image.open(image_path) as image:
            image.verify()
        with Image.open(image_path) as image:
            image.load()
    except OSError:
        raise
    except Exception as exc:
        # Pillow decoders may signal corrupt data with non-OSError exceptions
        # (for example SyntaxError from the PNG chunk verifier). Normalise
        # them so callers handling OSError treat every invalid image alike.
        raise OSError(f"Downloaded file is not a valid image: {exc}") from exc
def mime_type_to_suffix(mime_type: str) -> str:
    """Map an image MIME type to the matching file suffix.

    Args:
        mime_type: MIME type, compared exactly (case-sensitive, no parameters).

    Returns:
        str: ".jpg", ".png" or ".webp" for the known types; ".png" for
        anything else.
    """
    known_suffixes = (
        ("image/jpeg", ".jpg"),
        ("image/png", ".png"),
        ("image/webp", ".webp"),
    )
    matches = (suffix for known, suffix in known_suffixes if mime_type == known)
    # Unknown types fall back to PNG.
    return next(matches, ".png")
# Aspect-ratio labels that key the size tables below.
ASPECT_RATIO_LITERALS = Literal[
    "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1", "16:9", "9:16", "9:21"
]
# (width, height) pairs selected by ``_resolve_size`` for the "1K" preset and
# formatted as "<width>x<height>" for the request's ``size`` field.
BUCKETS_1K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
    "2:3": (1088, 1632),
    "3:2": (1632, 1088),
    "3:4": (1152, 1536),
    "4:3": (1536, 1152),
    "4:5": (1184, 1472),
    "5:4": (1472, 1184),
    "1:1": (1344, 1344),
    "16:9": (1792, 992),
    "9:16": (992, 1792),
    "9:21": (864, 2048),
}
# (width, height) pairs selected by ``_resolve_size`` for the "2K" preset.
BUCKETS_2K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
    "2:3": (1664, 2496),
    "3:2": (2496, 1664),
    "3:4": (1760, 2368),
    "4:3": (2368, 1760),
    "4:5": (1824, 2272),
    "5:4": (2272, 1824),
    "1:1": (2048, 2048),
    "16:9": (2752, 1536),
    "9:16": (1536, 2752),
    "9:21": (1344, 3136),
}
def _find_nearest_aspect_ratio(
    ratio: float,
    buckets: dict[ASPECT_RATIO_LITERALS, tuple[int, int]],
) -> tuple[int, int]:
    """Return the (width, height) bucket whose ratio is closest to ``ratio``.

    Ties go to the bucket that appears first in ``buckets``.
    """

    def distance(size: tuple[int, int]) -> float:
        bucket_width, bucket_height = size[0], size[1]
        return abs(bucket_width / bucket_height - ratio)

    candidates = list(buckets.values())
    # list.sort is stable, so equally close buckets keep their dict order.
    candidates.sort(key=distance)
    return candidates[0]
if __name__ == "__main__":
    import asyncio

    async def _run_demo():
        """Generate one sample image with the configured credentials and print the result."""
        demo_client = SensenovaText2ImageClient(
            api_key=global_configs.SN_IMAGE_GEN_API_KEY,
            base_url=global_configs.SN_IMAGE_GEN_BASE_URL,
        )
        request = {
            "prompt": "A cat wearing a hat",
            "image_size": "1K",
            "aspect_ratio": "16:9",
        }
        outcome = await demo_client.generate(**request)
        print(outcome)

    asyncio.run(_run_demo())