first commit
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from .nano_banana import NanoBananaText2ImageClient
|
||||
from .openai_image import OpenAIImageGenerationClient
|
||||
from .sensenova import SensenovaText2ImageClient
|
||||
|
||||
__all__ = [
|
||||
"NanoBananaText2ImageClient",
|
||||
"OpenAIImageGenerationClient",
|
||||
"SensenovaText2ImageClient",
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def ensure_output_path(path: Path) -> Path:
|
||||
"""Ensure the parent directory of the given path exists.
|
||||
|
||||
Args:
|
||||
path (Path):
|
||||
The file path whose parent directory should be created.
|
||||
|
||||
Returns:
|
||||
Path:
|
||||
The original path unchanged.
|
||||
"""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from sn_image_base.utils.error_utils import U1HttpResponseParseError
|
||||
from sn_image_base.utils.httpx_client import (
|
||||
create_async_httpx_client,
|
||||
httpx_response_raise_for_status_code,
|
||||
)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT = 300.0
|
||||
DEFAULT_MAX_CONNECTIONS = 100
|
||||
|
||||
|
||||
class T2IBaseClient(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self.model = model
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._max_connections = max_connections
|
||||
self._timeout = timeout
|
||||
self._ssl_verify = ssl_verify
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = create_async_httpx_client(
|
||||
self.headers,
|
||||
timeout=self._timeout,
|
||||
max_connections=self._max_connections,
|
||||
verify=self._ssl_verify,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self._api_key
|
||||
|
||||
@property
|
||||
def base_url(self) -> str | None:
|
||||
return self._base_url
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_api_url(self, *args: Any, **kwargs: Any) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def build_payload(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def headers(self) -> dict[str, str]: ...
|
||||
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
httpx_response_raise_for_status_code(response)
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except ValueError as exc:
|
||||
raise U1HttpResponseParseError(
|
||||
detail=f"Failed to parse HTTP response. {response.request.url}. Response content: {response.content}",
|
||||
code=response.status_code,
|
||||
) from exc
|
||||
306
sn-image-base/scripts/sn_image_base/generation/nano_banana.py
Normal file
306
sn-image-base/scripts/sn_image_base/generation/nano_banana.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
from .core import ensure_output_path
|
||||
from .core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
|
||||
DEFAULT_MODEL_SIZE: Literal["1K", "2K", "4K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
|
||||
class NanoBananaText2ImageClient(T2IBaseClient):
|
||||
"""Async client for Google Nano Banana API."""
|
||||
|
||||
# requires `{model}` placeholder for format string
|
||||
DEFAULT_API_PATH = "/v1beta/models/{model}:generateContent"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the NanoBananaText2ImageClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
negative_prompt (str, optional):
|
||||
Negative prompt. Defaults to "".
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K", "4K"). Defaults to DEFAULT_MODEL_SIZE.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model
|
||||
# Normalize image_size to uppercase for NanoBanana API
|
||||
image_size = image_size.upper() # type: ignore[assignment]
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image_size=image_size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images = data["images"]
|
||||
if not images:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
image, mime_type = images[-1]
|
||||
image_bytes = base64.b64decode(image)
|
||||
suffix = mime_type_to_suffix(mime_type)
|
||||
saved_path = output_path.with_suffix(suffix)
|
||||
saved_path.write_bytes(image_bytes)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, model: str | None = None) -> str:
|
||||
model = model or self.model
|
||||
path = self.DEFAULT_API_PATH.format(model=model).lstrip("/")
|
||||
api_url = f"{self.base_url.rstrip('/')}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_MODEL_SIZE,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
max_output_tokens: int = 8192,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
parts: list[dict] = [{"text": prompt}]
|
||||
if (image_b64 := kwargs.get("image_b64")) and (
|
||||
image_mime_type := kwargs.get("image_mime_type")
|
||||
):
|
||||
if image_mime_type not in ["image/jpeg", "image/png"]:
|
||||
msg = (
|
||||
f"Unsupported image MIME type: {image_mime_type}. "
|
||||
"Supported types: image/jpeg, image/png"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
parts.append({"inline_data": {"mime_type": image_mime_type, "data": image_b64}})
|
||||
return {
|
||||
"contents": [{"role": "USER", "parts": parts}],
|
||||
"generationConfig": {
|
||||
"imageConfig": {"aspectRatio": aspect_ratio, "imageSize": image_size},
|
||||
"maxOutputTokens": max_output_tokens,
|
||||
},
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
||||
],
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"x-goog-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images: list[tuple[str, str]] = []
|
||||
finish_reasons: list[str] = []
|
||||
candidates: list[dict] = raw_data.get("candidates") or []
|
||||
for c in candidates:
|
||||
content: dict[str, Any] = c.get("content") or {}
|
||||
parts: list[dict[str, Any]] = content.get("parts") or []
|
||||
if f_reason := content.get("finishReason"):
|
||||
finish_reasons.append(f_reason)
|
||||
for p in parts:
|
||||
inline_data: dict[str, Any] = p.get("inlineData", {})
|
||||
if inline_data:
|
||||
mime_type: str = inline_data.get("mimeType") # pyright: ignore[reportAssignmentType]
|
||||
data: str = inline_data.get("data") # pyright: ignore[reportAssignmentType]
|
||||
images.append((data, mime_type))
|
||||
return {
|
||||
"images": images,
|
||||
"finish_reasons": finish_reasons,
|
||||
}
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
366
sn-image-base/scripts/sn_image_base/generation/openai_image.py
Normal file
366
sn-image-base/scripts/sn_image_base/generation/openai_image.py
Normal file
@@ -0,0 +1,366 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.exceptions import BadConfigurationError
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
from .core import ensure_output_path
|
||||
from .core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
|
||||
DEFAULT_RESOLUTION: Literal["1K", "2K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
B64_PARSE_PATTERN = re.compile(r"^data:([a-zA-Z0-9/]+?);base64,([+-/_A-Za-z0-9]+=*)$")
|
||||
|
||||
|
||||
class OpenAIImageGenerationClient(T2IBaseClient):
|
||||
"""Async client for OpenAI Image Generation API."""
|
||||
|
||||
DEFAULT_API_PATH = "/images/generations"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the OpenAIImageGenerationClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "1k", "2k"] | None = None,
|
||||
aspect_ratio: str | None = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K"). Defaults to DEFAULT_RESOLUTION.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
|
||||
if not model:
|
||||
raise BadConfigurationError(
|
||||
f"Model is not set. {global_configs.get_env_var_help('SN_IMAGE_GEN_MODEL')}"
|
||||
)
|
||||
image_size = image_size or DEFAULT_RESOLUTION
|
||||
if aspect_ratio is None:
|
||||
size = None
|
||||
else:
|
||||
rw, _, rh = aspect_ratio.partition(":")
|
||||
try:
|
||||
aspect_ratio_val: float = float(int(rw) / int(rh))
|
||||
except (ValueError, ZeroDivisionError) as e:
|
||||
raise ValueError(f"Invalid aspect ratio: {aspect_ratio}") from e
|
||||
size = self._resolve_size(
|
||||
resolution=image_size,
|
||||
aspect_ratio_val=aspect_ratio_val,
|
||||
)
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images = data["images"]
|
||||
if not images:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
image_bytes, mime_type = images[-1]
|
||||
suffix = mime_type_to_suffix(mime_type)
|
||||
saved_path = output_path.with_suffix(suffix)
|
||||
saved_path.write_bytes(image_bytes)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, model: str | None = None) -> str:
|
||||
model = model or self.model
|
||||
path = self.DEFAULT_API_PATH.format(model=model).lstrip("/")
|
||||
api_url = f"{self.base_url.rstrip('/')}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
*,
|
||||
n: int = 1,
|
||||
size: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Example:
|
||||
{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "一只戴着墨镜的猫在赛博朋克城市的街道上喝咖啡, 赛璐璐画风",
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"response_format": "b64_json",
|
||||
}
|
||||
"""
|
||||
size = size or "auto"
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"n": n,
|
||||
"size": size,
|
||||
"response_format": "b64_json",
|
||||
**kwargs,
|
||||
}
|
||||
return payload
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
"""
|
||||
Example:
|
||||
{
|
||||
"data": [{
|
||||
"b64_json": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOYA3Q..."
|
||||
}],
|
||||
"created": 1776789055
|
||||
"usage": {
|
||||
"input_tokens":773,
|
||||
"output_tokens":765,
|
||||
"total_tokens":1538,
|
||||
"input_tokens_details": {
|
||||
"text_tokens":8,
|
||||
"image_tokens":765
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images: list[tuple[bytes, str]] = []
|
||||
data_items: list[dict] = raw_data.get("data") or []
|
||||
for item in data_items:
|
||||
encoded: str = item.get("b64_json") or ""
|
||||
if not encoded:
|
||||
continue
|
||||
|
||||
if encoded.startswith("data:"):
|
||||
match = B64_PARSE_PATTERN.match(encoded)
|
||||
if match:
|
||||
mime_type = match.group(1)
|
||||
b64_data = match.group(2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid base64 data in response: {encoded[:100]}... (truncated)"
|
||||
)
|
||||
else:
|
||||
mime_type = "image/png" # fallback to png
|
||||
b64_data = encoded
|
||||
try:
|
||||
decoded = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to decode base64 data in response: {e}. b64_json: {encoded[:100]}... (truncated)"
|
||||
) from e
|
||||
images.append((decoded, mime_type))
|
||||
return {
|
||||
"images": images,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_size(
|
||||
cls,
|
||||
resolution: str,
|
||||
aspect_ratio_val: float | None,
|
||||
) -> str:
|
||||
"""Convert (resolution, aspect_ratio) to a pixel size string."""
|
||||
resolution = resolution.upper()
|
||||
if resolution == "1K":
|
||||
max_pixel = 1024**2
|
||||
elif resolution == "2K":
|
||||
max_pixel = 2048**2
|
||||
else:
|
||||
raise ValueError(f"Unsupported resolution: {resolution}")
|
||||
aspect_ratio_val = aspect_ratio_val or 1
|
||||
if aspect_ratio_val < 1 / 3 or aspect_ratio_val > 3:
|
||||
raise ValueError(f"Aspect ratio value must be between [1/3, 3], got {aspect_ratio_val}")
|
||||
|
||||
width: int = round(math.sqrt(max_pixel * aspect_ratio_val))
|
||||
height: int = round(math.sqrt(max_pixel / aspect_ratio_val))
|
||||
return f"{width}x{height}"
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
508
sn-image-base/scripts/sn_image_base/generation/sensenova.py
Normal file
508
sn-image-base/scripts/sn_image_base/generation/sensenova.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
from sn_image_base.configs import global_configs, is_valid_base_url
|
||||
from sn_image_base.exceptions import InvalidBaseUrlError, MissingApiKeyError
|
||||
from sn_image_base.generation.core import ensure_output_path
|
||||
from sn_image_base.generation.core.client_base import (
|
||||
DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
DEFAULT_MAX_CONNECTIONS,
|
||||
T2IBaseClient,
|
||||
)
|
||||
from sn_image_base.utils.error_utils import U1HttpErrorBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
DEFAULT_RESOLUTION: Literal["1K", "2K", "4K"] = "2K"
|
||||
DEFAULT_ASPECT_RATIO = "16:9"
|
||||
DEFAULT_POLL_INTERVAL = 5.0
|
||||
OUTPUT_DIR = Path("/tmp/openclaw-sn-image")
|
||||
|
||||
|
||||
IMAGE_GEN_ENDPOINT = "/images/generations"
|
||||
|
||||
|
||||
class SensenovaText2ImageClient(T2IBaseClient):
|
||||
"""Async client for Sensenova text-to-image API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str | None = None,
|
||||
*,
|
||||
model: str | None = None,
|
||||
max_connections: int = DEFAULT_MAX_CONNECTIONS,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
ssl_verify: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the SensenovaText2ImageClient.
|
||||
|
||||
Args:
|
||||
api_key (str):
|
||||
API key for authentication.
|
||||
base_url (str | None, optional):
|
||||
API base URL. If None, reads from SN_IMAGE_GEN_BASE_URL env var.
|
||||
model (str | None, optional):
|
||||
Model name. If None, reads from SN_IMAGE_GEN_MODEL env var.
|
||||
max_connections (int, optional):
|
||||
Maximum number of connections. Defaults to 100.
|
||||
timeout (float, optional):
|
||||
Total timeout in seconds for HTTP requests.
|
||||
Defaults to DEFAULT_HTTP_REQUEST_TIMEOUT.
|
||||
ssl_verify (bool, optional):
|
||||
If True, enable TLS verification. Defaults to True.
|
||||
"""
|
||||
api_key = api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise MissingApiKeyError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
base_url = base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise InvalidBaseUrlError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise InvalidBaseUrlError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
max_connections=max_connections,
|
||||
timeout=timeout,
|
||||
ssl_verify=ssl_verify,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
*,
|
||||
model: str | None = None,
|
||||
image_size: Literal["1K", "2K", "4K"] = DEFAULT_RESOLUTION,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
output_path: Path | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Generate an image from text prompt.
|
||||
|
||||
Args:
|
||||
prompt (str):
|
||||
Text prompt for image generation.
|
||||
negative_prompt (str, optional):
|
||||
Negative prompt. Defaults to "".
|
||||
model (str | None, optional):
|
||||
Model name override. Defaults to None.
|
||||
image_size (str, optional):
|
||||
Image size preset ("1K", "2K", "4K"). Defaults to DEFAULT_RESOLUTION.
|
||||
aspect_ratio (str, optional):
|
||||
Aspect ratio (e.g. "16:9", "1:1"). Defaults to DEFAULT_ASPECT_RATIO.
|
||||
output_path (Path | None, optional):
|
||||
Output path for the generated image. Defaults to None.
|
||||
**kwargs:
|
||||
Additional arguments reserved for backend compatibility.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
Dictionary with keys: status, output (path), message.
|
||||
"""
|
||||
model = model or self.model or global_configs.SN_IMAGE_GEN_MODEL
|
||||
# Normalize image_size to uppercase for NanoBanana API
|
||||
image_size = image_size.upper() # type: ignore[assignment]
|
||||
output_format = "png"
|
||||
size = self._resolve_size(image_size, aspect_ratio)
|
||||
payload = self.build_payload(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
aspect_ratio=aspect_ratio,
|
||||
output_format=output_format,
|
||||
)
|
||||
headers = self.headers
|
||||
api_url = self.get_api_url(model)
|
||||
if output_path is None:
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = OUTPUT_DIR / f"t2i_{timestamp}.png"
|
||||
output_path = ensure_output_path(output_path)
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
create_response = await client.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
data = self.parse_response(create_response)
|
||||
except U1HttpErrorBase as exc:
|
||||
details = exc.detail or ""
|
||||
field_name = None
|
||||
if exc.code == 404:
|
||||
field_name = "SN_IMAGE_GEN_BASE_URL"
|
||||
elif exc.code == 401:
|
||||
field_name = "SN_IMAGE_GEN_API_KEY"
|
||||
# elif exc.code == 400:
|
||||
# warnings.warn(f"Bad request: {exc.message}; body: {payload}", stacklevel=2)
|
||||
if field_name is not None:
|
||||
field_hint = global_configs.get_annotated_field(field_name)
|
||||
if field_hint is not None:
|
||||
env_names = list(field_hint.env_names) if field_hint.env_names else []
|
||||
if env_names:
|
||||
if len(env_names) == 1:
|
||||
details += (
|
||||
f"\nIs the environment variable `{env_names[0]}` set correctly?"
|
||||
)
|
||||
else:
|
||||
env_names_str = ", ".join([f"`{n}`" for n in env_names])
|
||||
details += f"\nIs any of the following environment variable(s) set correctly: {env_names_str}?"
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.code}: {exc.message}",
|
||||
"message": details,
|
||||
}
|
||||
try:
|
||||
images_urls: list[str] = data["images_urls"]
|
||||
if not images_urls:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "No image generated from the model",
|
||||
}
|
||||
url = images_urls[-1]
|
||||
suffix = f".{output_format}"
|
||||
save_path = output_path.with_suffix(suffix)
|
||||
saved_path = await download_image(url, save_path)
|
||||
return {
|
||||
"status": "ok",
|
||||
"output": str(saved_path),
|
||||
"message": "Image generated successfully",
|
||||
}
|
||||
except httpx.HTTPStatusError as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": f"HTTP {exc.response.status_code}",
|
||||
"message": f"http error: {exc.response.status_code} {exc.response.text}",
|
||||
}
|
||||
except (httpx.HTTPError, OSError, ValueError) as exc:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": type(exc).__name__,
|
||||
"message": f"request error: {exc}",
|
||||
}
|
||||
|
||||
@property
|
||||
@override
|
||||
def api_key(self) -> str:
|
||||
api_key = self._api_key or global_configs.SN_IMAGE_GEN_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return api_key
|
||||
|
||||
@property
|
||||
@override
|
||||
def base_url(self) -> str:
|
||||
base_url = self._base_url or global_configs.SN_IMAGE_GEN_BASE_URL
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"Base URL is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_BASE_URL")
|
||||
)
|
||||
)
|
||||
if not is_valid_base_url(base_url):
|
||||
raise ValueError(
|
||||
f"Base URL is not a valid base URL: {base_url}. "
|
||||
f"Try setting environment variable(s): {global_configs.get_env_var_help('SN_IMAGE_GEN_BASE_URL')}"
|
||||
)
|
||||
return base_url
|
||||
|
||||
@override
|
||||
def get_api_url(self, _model: str | None = None) -> str:
|
||||
base_url = self.base_url.rstrip("/")
|
||||
path = IMAGE_GEN_ENDPOINT.lstrip("/")
|
||||
api_url = f"{base_url}/{path}"
|
||||
return api_url
|
||||
|
||||
@override
|
||||
def build_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
*,
|
||||
size: str | None = None,
|
||||
modalities: Sequence[str] = ("text", "image"),
|
||||
output_format: Literal["png"] = "png",
|
||||
response_format: Literal["url"] = "url",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the payload for the SenseNova image-generation endpoint.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to generate an image for.
|
||||
model (str): The model to use for generation.
|
||||
size (str | None): Pixel size string (for example, "1920x1920").
|
||||
modalities (Sequence[str]): Reserved for compatibility; currently not sent.
|
||||
output_format (Literal["png"]): The output format of the image. Defaults to "png".
|
||||
response_format (Literal["url"]): The response format of the image. Defaults to "url".
|
||||
**kwargs (Any, optional): Additional parameters to pass to the API.
|
||||
|
||||
Example:
|
||||
{
|
||||
"model": "sensenova-u1-fast",
|
||||
"prompt": "A cat wearing a hat",
|
||||
"size": "1024x1024",
|
||||
"response_format": "url",
|
||||
"output_format": "png",
|
||||
}
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
# "modalities": modalities,
|
||||
"size": size,
|
||||
# "n": 1,
|
||||
"response_format": response_format,
|
||||
"output_format": output_format,
|
||||
**kwargs,
|
||||
}
|
||||
return payload
|
||||
|
||||
@property
|
||||
@override
|
||||
def headers(self) -> dict[str, str]:
|
||||
if not self.api_key:
|
||||
raise MissingApiKeyError(
|
||||
"API key is missing: {}".format(
|
||||
global_configs.get_env_var_help("SN_IMAGE_GEN_API_KEY")
|
||||
)
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_size(
|
||||
cls,
|
||||
resolution: Literal["1K", "2K"] | str | None = None,
|
||||
aspect_ratio: ASPECT_RATIO_LITERALS | str | None = None,
|
||||
) -> str | None:
|
||||
"""Convert (resolution, aspect_ratio) to a pixel size string.
|
||||
|
||||
If aspect_ratio is None, returns the resolution as-is (e.g. "1K").
|
||||
"""
|
||||
if not resolution and not aspect_ratio:
|
||||
return None
|
||||
resolution = resolution or "2K"
|
||||
aspect_ratio = aspect_ratio or "1:1"
|
||||
if resolution == "1K":
|
||||
buckets = BUCKETS_1K
|
||||
elif resolution == "2K":
|
||||
buckets = BUCKETS_2K
|
||||
else:
|
||||
raise ValueError(f"Unsupported resolution: {resolution!r}. Must be '1K' or '2K'.")
|
||||
try:
|
||||
ws, _, hs = aspect_ratio.strip().partition(":")
|
||||
width = int(ws)
|
||||
height = int(hs)
|
||||
ratio = width / height
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}") from e
|
||||
if ratio > 16 / 9:
|
||||
raise ValueError(f"Aspect ratio {aspect_ratio!r} is too wide. Maximum is 16:9")
|
||||
if ratio < 9 / 21:
|
||||
raise ValueError(f"Aspect ratio {aspect_ratio!r} is too high. Maximum is 9:21")
|
||||
w, h = _find_nearest_aspect_ratio(ratio, buckets)
|
||||
return f"{w}x{h}"
|
||||
|
||||
@override
|
||||
def parse_response(self, response: httpx.Response) -> dict:
|
||||
"""Parse the response from the SenseNova image-generation endpoint.
|
||||
|
||||
Example response data:
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [{
|
||||
"url": "https://cdn.sensenova.dev/gen/..."
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
Args:
|
||||
response: The HTTP response from the SenseNova image-generation endpoint.
|
||||
|
||||
Returns:
|
||||
dict: Parsed data with key ``images_urls``.
|
||||
"""
|
||||
raw_data = super().parse_response(response)
|
||||
|
||||
images_urls: list[str] = []
|
||||
for item in raw_data.get("data", []):
|
||||
url = item.get("url")
|
||||
if url:
|
||||
images_urls.append(url)
|
||||
return {"images_urls": images_urls}
|
||||
|
||||
|
||||
async def download_image(
|
||||
url: str,
|
||||
save_path: Path,
|
||||
timeout: float = DEFAULT_HTTP_REQUEST_TIMEOUT,
|
||||
) -> Path:
|
||||
"""Download an image from a URL.
|
||||
|
||||
Args:
|
||||
url: The URL of the image to download.
|
||||
timeout: The timeout for the request.
|
||||
|
||||
Returns:
|
||||
Path: The path to the downloaded image file.
|
||||
"""
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
temp_path: Path | None = None
|
||||
bytes_written = 0
|
||||
expected_length: int | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=save_path.parent,
|
||||
prefix=f".{save_path.name}.",
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
) as temp_file:
|
||||
temp_path = Path(temp_file.name)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
response.raise_for_status()
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length is not None:
|
||||
expected_length = int(content_length)
|
||||
async for chunk in response.aiter_bytes():
|
||||
bytes_written += len(chunk)
|
||||
temp_file.write(chunk)
|
||||
temp_file.flush()
|
||||
os.fsync(temp_file.fileno())
|
||||
|
||||
if expected_length is not None and bytes_written != expected_length:
|
||||
raise OSError(
|
||||
f"Downloaded image is incomplete: got {bytes_written} bytes, "
|
||||
f"expected {expected_length} bytes"
|
||||
)
|
||||
|
||||
assert temp_path is not None
|
||||
_validate_image_file(temp_path)
|
||||
temp_path.replace(save_path)
|
||||
return save_path
|
||||
except Exception:
|
||||
if temp_path is not None:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
def _validate_image_file(image_path: Path) -> None:
|
||||
"""Verify that the downloaded image can be decoded completely."""
|
||||
with Image.open(image_path) as image:
|
||||
image.verify()
|
||||
with Image.open(image_path) as image:
|
||||
image.load()
|
||||
|
||||
|
||||
def mime_type_to_suffix(mime_type: str) -> str:
|
||||
"""Convert MIME type to file suffix.
|
||||
|
||||
Args:
|
||||
mime_type: MIME type.
|
||||
|
||||
Returns:
|
||||
str: File suffix.
|
||||
"""
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
elif mime_type == "image/png":
|
||||
return ".png"
|
||||
elif mime_type == "image/webp":
|
||||
return ".webp"
|
||||
else:
|
||||
return ".png"
|
||||
|
||||
|
||||
ASPECT_RATIO_LITERALS = Literal[
|
||||
"2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1", "16:9", "9:16", "9:21"
|
||||
]
|
||||
BUCKETS_1K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
|
||||
"2:3": (1088, 1632),
|
||||
"3:2": (1632, 1088),
|
||||
"3:4": (1152, 1536),
|
||||
"4:3": (1536, 1152),
|
||||
"4:5": (1184, 1472),
|
||||
"5:4": (1472, 1184),
|
||||
"1:1": (1344, 1344),
|
||||
"16:9": (1792, 992),
|
||||
"9:16": (992, 1792),
|
||||
"9:21": (864, 2048),
|
||||
}
|
||||
BUCKETS_2K: dict[ASPECT_RATIO_LITERALS, tuple[int, int]] = {
|
||||
"2:3": (1664, 2496),
|
||||
"3:2": (2496, 1664),
|
||||
"3:4": (1760, 2368),
|
||||
"4:3": (2368, 1760),
|
||||
"4:5": (1824, 2272),
|
||||
"5:4": (2272, 1824),
|
||||
"1:1": (2048, 2048),
|
||||
"16:9": (2752, 1536),
|
||||
"9:16": (1536, 2752),
|
||||
"9:21": (1344, 3136),
|
||||
}
|
||||
|
||||
|
||||
def _find_nearest_aspect_ratio(
|
||||
ratio: float,
|
||||
buckets: dict[ASPECT_RATIO_LITERALS, tuple[int, int]],
|
||||
) -> tuple[int, int]:
|
||||
wh_pairs = sorted(
|
||||
buckets.values(),
|
||||
key=lambda wh: abs(wh[0] / wh[1] - ratio),
|
||||
)
|
||||
return wh_pairs[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def main_async():
|
||||
client = SensenovaText2ImageClient(
|
||||
api_key=global_configs.SN_IMAGE_GEN_API_KEY,
|
||||
base_url=global_configs.SN_IMAGE_GEN_BASE_URL,
|
||||
)
|
||||
|
||||
result = await client.generate(
|
||||
prompt="A cat wearing a hat",
|
||||
image_size="1K",
|
||||
aspect_ratio="16:9",
|
||||
)
|
||||
print(result)
|
||||
|
||||
asyncio.run(main_async())
|
||||
Reference in New Issue
Block a user