first commit
This commit is contained in:
231
sn-image-base/scripts/sn_image_base/utils/error_utils.py
Normal file
231
sn-image-base/scripts/sn_image_base/utils/error_utils.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
class U1BaseError(Exception):
|
||||
MESSAGE = "Base error"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
detail: str | None = None,
|
||||
code: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if message is None:
|
||||
message = self.MESSAGE
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.detail = detail
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.code:
|
||||
msg = f"{self.__class__.__name__}[{self.code}]"
|
||||
else:
|
||||
msg = f"{self.__class__.__name__}"
|
||||
if self.message:
|
||||
msg += f"(message={self.message!r})"
|
||||
if self.detail:
|
||||
msg += f" <detail>{self.detail}</detail>"
|
||||
return msg
|
||||
|
||||
|
||||
# ----------------------
|
||||
# HTTP Errors
|
||||
# ----------------------
|
||||
|
||||
|
||||
class U1HttpErrorBase(U1BaseError):
    """Common base class for the HTTP-related errors defined in this module."""

    MESSAGE = "Base HTTP Error"
|
||||
|
||||
|
||||
class U1HttpAuthError(U1HttpErrorBase):
    """Authentication or authorization failure.

    ``error_type_to_error_class`` maps ``"authentication_error"`` to this class.
    """

    MESSAGE = "Authentication or Authorization Failed"
|
||||
|
||||
|
||||
class U1HttpNotFoundError(U1HttpErrorBase):
    """The requested resource could not be found."""

    MESSAGE = "Resource Not Found"
|
||||
|
||||
|
||||
class U1HttpTooManyRequestsError(U1HttpErrorBase):
    """Too many requests were sent.

    ``error_type_to_error_class`` maps ``"rate_limit_error"`` to this class.
    """

    MESSAGE = "Too Many Requests"
|
||||
|
||||
|
||||
class U1HttpServerError(U1HttpErrorBase):
    """Server-side failure.

    ``error_type_to_error_class`` maps ``"api_error"`` to this class.
    """

    MESSAGE = "Server Error"
|
||||
|
||||
|
||||
class U1HttpBadRequestError(U1HttpErrorBase):
    """The request was rejected as invalid.

    Also used as the fallback class by ``finish_reason_to_error_class`` and
    ``error_type_to_error_class`` for values they do not recognise.
    """

    MESSAGE = "Bad Request"
|
||||
|
||||
|
||||
class U1HttpPermissionError(U1HttpErrorBase):
    """The caller lacks permission for the resource.

    ``error_type_to_error_class`` maps ``"permission_error"`` to this class.
    """

    MESSAGE = "Permission Error"
|
||||
|
||||
|
||||
class U1HttpResponseParseError(U1HttpErrorBase):
    """An HTTP response could not be parsed."""

    MESSAGE = "Failed to parse HTTP response"
|
||||
|
||||
|
||||
class U1HttpTimeoutError(U1HttpErrorBase):
    """An HTTP operation timed out."""

    MESSAGE = "Timeout Error"
|
||||
|
||||
|
||||
class U1HttpNetworkError(U1HttpErrorBase):
    """A network-level failure occurred."""

    MESSAGE = "Network Error"
|
||||
|
||||
|
||||
class U1HttpUnknownError(U1HttpErrorBase):
    """An HTTP error that fits none of the more specific classes."""

    MESSAGE = "Unknown Error"
|
||||
|
||||
|
||||
class U1HttpForbiddenContentError(U1HttpErrorBase):
    """The response was filtered for content reasons.

    ``finish_reason_to_error_class`` maps ``"content_filter"`` to this class.
    """

    MESSAGE = "Forbidden Content Filtered"
|
||||
|
||||
|
||||
class U1HttpTruncatedResponseError(U1HttpErrorBase):
    """The response was cut short.

    ``finish_reason_to_error_class`` maps ``"length"`` to this class.
    """

    MESSAGE = "Truncated Response"
|
||||
|
||||
|
||||
class U1HttpBadResponseError(U1HttpErrorBase):
    """The response was unusable.

    ``finish_reason_to_error_class`` maps ``"stop"`` to this class.
    """

    MESSAGE = "Bad Response"
|
||||
|
||||
|
||||
def finish_reason_to_error_class(finish_reason: str) -> tuple[type[U1HttpErrorBase], str]:
    """Translate a completion ``finish_reason`` into an error class.

    Args:
        finish_reason: The finish reason string to classify.

    Returns:
        A ``(error_class, explanation)`` pair. Unrecognised reasons yield
        ``U1HttpBadRequestError`` with an "Unknown finish reason" explanation.
    """
    # Each row: (accepted finish reasons, error class, explanation).
    known_reasons = (
        (
            ("length",),
            U1HttpTruncatedResponseError,
            "Response was truncated due to length limit.",
        ),
        (
            ("content_filter",),
            U1HttpForbiddenContentError,
            "Response was filtered due to content policy.",
        ),
        (
            ("tool_calls", "function_call"),
            U1HttpBadRequestError,
            "Response was halted due to tool calls or function calls.",
        ),
        (
            ("stop",),
            U1HttpBadResponseError,
            "Response was completed normally.",
        ),
    )
    for accepted, error_cls, explanation in known_reasons:
        if finish_reason in accepted:
            return error_cls, explanation
    return U1HttpBadRequestError, f"Unknown finish reason: {finish_reason!r}."
|
||||
|
||||
|
||||
def error_type_to_error_class(error_type: str) -> tuple[type[U1HttpErrorBase], str]:
    """Translate an API ``error_type`` string into an error class.

    Args:
        error_type: The error type string to classify.

    Returns:
        A ``(error_class, explanation)`` pair. Unrecognised types yield
        ``U1HttpBadRequestError`` with an "Unknown error type" explanation.
    """
    # Each row: (error type, error class, explanation).
    known_types = (
        ("invalid_request_error", U1HttpBadRequestError, "Invalid request error."),
        ("rate_limit_error", U1HttpTooManyRequestsError, "Rate limit exceeded."),
        ("authentication_error", U1HttpAuthError, "Authentication error."),
        ("api_error", U1HttpServerError, "API service internal error."),
        (
            "permission_error",
            U1HttpPermissionError,
            "You are not authorized to access this resource.",
        ),
    )
    for candidate, error_cls, explanation in known_types:
        if error_type == candidate:
            return error_cls, explanation
    return U1HttpBadRequestError, f"Unknown error type: {error_type!r}."
|
||||
|
||||
|
||||
def sanitize_base64_in_data(data: Any, *, truncate_length: int = 200) -> Any:
    """Recursively shorten long base64-looking strings inside a data structure.

    Processing steps:

    1. Top-level ``bytes``/``bytearray`` are decoded as UTF-8 when possible;
       binary input that cannot be decoded (and any ``memoryview``) is
       replaced by ``<binary-data len="Nbytes"/>``.
    2. A top-level string that parses as a JSON object or array is replaced by
       the parsed value so its contents can be sanitized too. Strings that
       parse to JSON scalars (e.g. ``"123"``) are left as strings.
    3. The value is walked recursively. Mappings become ``dict`` (keys are kept
       as-is), other iterables become ``list``, and every string longer than
       ``truncate_length`` that ``_is_base64_string`` accepts is replaced by::

           <base64-data len="1,234bytes">HEAD...TRUNCATED_MARKER...TAIL</base64-data>

       where HEAD/TAIL are the first/last ``truncate_length`` characters.
       Binary values found while walking are handled like step 1 instead of
       being expanded into lists of integers.

    Containers that contain themselves are replaced by a
    ``<circular-reference:.../>`` placeholder at the point of recursion.

    Args:
        data: Data to sanitize (dict, list, str, bytes, or anything else).
        truncate_length: Strings must be longer than this to be replaced; it
            is also the number of characters kept at each end.

    Returns:
        The sanitized copy of ``data``; values of other types are returned
        unchanged.
    """

    def _binary_placeholder(blob: Any) -> str:
        # memoryview.nbytes is the real byte size; len() counts items instead.
        size = blob.nbytes if isinstance(blob, memoryview) else len(blob)
        return f'<binary-data len="{size}bytes"/>'

    # Handle binary data first (bytes, bytearray, memoryview).
    if isinstance(data, (bytes, bytearray)):
        with contextlib.suppress(UnicodeDecodeError):
            data = data.decode("utf-8")
    if isinstance(data, (bytes, bytearray, memoryview)):
        return _binary_placeholder(data)

    if isinstance(data, str):
        # Try: str -> dict | list. JSON scalars are ignored on purpose so a
        # plain string such as "123" or "true" keeps its original type.
        with contextlib.suppress(ValueError, RecursionError):
            parsed = json.loads(data)
            if isinstance(parsed, (dict, list)):
                data = parsed

    keep = max(truncate_length, 0)  # characters preserved at each end
    active_ids: set[int] = set()  # containers currently being walked

    def _sanitize(node: Any) -> Any:
        # Binary leaves are iterable; without this they would be expanded
        # into a list of integers by the generic Iterable branch below.
        if isinstance(node, (bytes, bytearray)):
            try:
                node = node.decode("utf-8")
            except UnicodeDecodeError:
                return _binary_placeholder(node)
        elif isinstance(node, memoryview):
            return _binary_placeholder(node)

        if isinstance(node, str):
            # Cheap length test first; the base64 heuristic scans the string.
            if len(node) > truncate_length and _is_base64_string(node):
                len_str = f"{len(node):,d}bytes"
                head = node[:keep]
                # Explicit start index: node[-0:] would be the whole string.
                tail = node[len(node) - keep :]
                return f'<base64-data len="{len_str}">{head}...{TRUNCATED_MARKER}...{tail}</base64-data>'
            return node

        if isinstance(node, Mapping):
            node_id = id(node)
            if node_id in active_ids:
                return "<circular-reference:mapping/>"
            active_ids.add(node_id)
            sanitized_map = {key: _sanitize(value) for key, value in node.items()}
            active_ids.remove(node_id)
            return sanitized_map

        if isinstance(node, Iterable):
            node_id = id(node)
            if node_id in active_ids:
                return "<circular-reference:iterable/>"
            active_ids.add(node_id)
            sanitized_items = [_sanitize(item) for item in node]
            active_ids.remove(node_id)
            return sanitized_items

        return node

    return _sanitize(data)
|
||||
|
||||
|
||||
TRUNCATED_MARKER = "<<<///TRUNCATED///>>>"
|
||||
BASE64_DETECTION_MIN_LENGTH = 200 # Minimum length to consider as potential base64
|
||||
BASE64_CHARS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")
|
||||
|
||||
|
||||
def _is_base64_string(value: str) -> bool:
|
||||
"""Check if a string looks like base64-encoded data.
|
||||
|
||||
Args:
|
||||
value: String to check
|
||||
|
||||
Returns:
|
||||
True if the string appears to be base64-encoded data
|
||||
|
||||
Heuristics:
|
||||
- Length >= BASE64_DETECTION_MIN_LENGTH (200 chars)
|
||||
- At least 80% of characters are valid base64 chars (A-Za-z0-9+/=)
|
||||
- No whitespace or newlines (valid base64 is continuous)
|
||||
"""
|
||||
if not isinstance(value, str) or len(value) < BASE64_DETECTION_MIN_LENGTH:
|
||||
return False
|
||||
|
||||
# Check if mostly base64 characters (allow some tolerance)
|
||||
if value.startswith("data:"):
|
||||
# Remove the prefix like "data:image/jpeg;base64,"
|
||||
index = value.find(";base64,")
|
||||
if index != -1:
|
||||
value = value[index + len(";base64,") :]
|
||||
valid_count = sum(1 for c in value if c in BASE64_CHARS)
|
||||
ratio = valid_count / len(value)
|
||||
|
||||
if ratio >= 0.98:
|
||||
with contextlib.suppress(Exception):
|
||||
base64.b64decode(value)
|
||||
return True
|
||||
|
||||
return False
|
||||
180
sn-image-base/scripts/sn_image_base/utils/httpx_client.py
Normal file
180
sn-image-base/scripts/sn_image_base/utils/httpx_client.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""Shared httpx async client factory for vigeneval evaluators.
|
||||
|
||||
Centralizes connection pool limits, pool timeout, and optional file descriptor
|
||||
limit check to avoid PoolTimeout and 'Too many open files' under high concurrency.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import resource
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .error_utils import (
|
||||
U1HttpAuthError,
|
||||
U1HttpBadRequestError,
|
||||
U1HttpNotFoundError,
|
||||
U1HttpServerError,
|
||||
U1HttpTooManyRequestsError,
|
||||
)
|
||||
|
||||
|
||||
def check_file_descriptor_limit(max_connections: int, margin: int = 200) -> None:
    """Raise if process file descriptor limit is too low for max_connections.

    Avoids 'Too many open files' mid-run when using a large httpx connection pool.
    No-op when the ``resource`` module or ``RLIMIT_NOFILE`` is unavailable, when
    the limit cannot be queried, or when the soft limit is unlimited.

    Args:
        max_connections: Intended httpx pool max_connections.
        margin: Extra FDs to reserve for app (logs, other files). Default 200.

    Raises:
        RuntimeError: If soft limit < max_connections + margin.
    """
    try:
        # Imported lazily so that a platform without the module reaches the
        # ImportError handler below instead of failing before this function.
        # NOTE(review): the module-level ``import resource`` in this file
        # would still fail at import time on such platforms - confirm.
        import resource

        soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        unlimited = getattr(resource, "RLIM_INFINITY", None)
    except (ImportError, AttributeError, OSError):
        return

    # RLIM_INFINITY may be a negative sentinel, so it must be tested
    # explicitly rather than compared numerically.
    if unlimited is not None and soft == unlimited:
        return

    required = max_connections + margin
    if soft >= required:
        return

    raise RuntimeError(
        f"File descriptor limit too low for max_connections={max_connections}. "
        f"Current soft limit: {soft}, need at least {required}. "
        f"Raise the limit before running, e.g.: ulimit -n {required}  # or higher, then re-run."
    )
|
||||
|
||||
|
||||
def create_async_httpx_client(
    headers: dict[str, str],
    *,
    timeout: float = 600.0,
    max_connections: int = 500,
    pool_timeout: float = 60.0,
    check_fd_limit: bool = False,
    verify: bool = True,
    **client_kwargs: Any,
) -> httpx.AsyncClient:
    """Create an httpx.AsyncClient with shared defaults for vigeneval evaluators.

    The client is built on an explicit ``httpx.AsyncHTTPTransport`` that is
    bound to local address ``0.0.0.0`` and configured with a connection pool
    of ``max_connections`` (at most 400 kept alive, 30 s keep-alive expiry).
    ``trust_env=True`` is passed to both the transport and the client.

    NOTE(review): the intent appears to be that proxy settings from
    HTTP_PROXY / HTTPS_PROXY are honoured. Confirm against the httpx
    documentation that environment proxies are still applied when an explicit
    ``transport`` is supplied.

    Args:
        headers: Request headers (e.g. Content-Type, Authorization).
        timeout: Request timeout in seconds. Default 600.
        max_connections: Connection pool size. Default 500.
        pool_timeout: Max seconds to wait for a connection from the pool.
            Default 60.
        check_fd_limit: If True, call check_file_descriptor_limit(max_connections)
            first, which may raise before any client is created.
        verify: Forwarded as ``verify`` to both the transport and the client.
        **client_kwargs: Passed through to httpx.AsyncClient (e.g. base_url).

    Returns:
        A new httpx.AsyncClient. Caller must aclose() when done.

    Example:
        client = create_async_httpx_client(
            headers={"Authorization": "Bearer token"},
            max_connections=100,
        )
    """
    if check_fd_limit:
        check_file_descriptor_limit(max_connections)

    # Pool sizing: cap idle keep-alive connections at 400.
    keepalive_cap = min(400, max_connections)
    pool_limits = httpx.Limits(
        max_connections=max_connections,
        max_keepalive_connections=keepalive_cap,
        keepalive_expiry=30.0,
    )

    pooled_transport = httpx.AsyncHTTPTransport(
        verify=verify,
        trust_env=True,
        local_address="0.0.0.0",
        limits=pool_limits,
    )

    # The overall timeout applies to every phase except waiting for a pooled
    # connection, which uses pool_timeout.
    request_timeout = httpx.Timeout(timeout, pool=pool_timeout)

    return httpx.AsyncClient(
        transport=pooled_transport,
        headers=headers,
        timeout=request_timeout,
        verify=verify,
        trust_env=True,
        **client_kwargs,
    )
|
||||
|
||||
|
||||
def httpx_response_raise_for_status_code(response: httpx.Response) -> None:
    """Check httpx response status code and raise the matching U1 HTTP error.

    Statuses outside the 4xx/5xx ranges return normally. The response body
    (JSON-decoded when possible) and the request method/URL are gathered on a
    best-effort basis and embedded in the error detail; anything that cannot
    be read is reported as ``[N/A]``.

    Args:
        response: The httpx response object.

    Raises:
        U1HttpNotFoundError: If response status is 404.
        U1HttpAuthError: If response status is 401 or 403.
        U1HttpTooManyRequestsError: If response status is 429 or 503.
        U1HttpServerError: If response status is any other 5xx.
        U1HttpBadRequestError: If response status is any other 4xx.
    """
    # Placeholders used when the corresponding piece cannot be extracted.
    response_content = "[N/A]"
    request_url = "[N/A]"
    request_method = "[N/A]"

    # Each step overwrites the previous value, so a failure part-way keeps
    # the most-decoded form obtained so far (bytes -> str -> parsed JSON).
    with contextlib.suppress(Exception):
        response_content = response.content
        response_content = response_content.decode("utf-8")
        response_content = json.loads(response_content)

    with contextlib.suppress(Exception):
        request_method = response.request.method
        request_method = request_method.upper()
        request_url = str(response.request.url)

    status = response.status_code

    if status == 404:
        raise U1HttpNotFoundError(
            detail=f"{request_method} {request_url!r} not found. Please check the URL and the model name.",
            code=status,
        )
    if status in (401, 403):
        raise U1HttpAuthError(
            detail=f"Authentication or authorization failed. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    # 503 is checked before the generic 5xx range so it is reported as a
    # retryable "temporarily unavailable" condition.
    if status in (429, 503):
        raise U1HttpTooManyRequestsError(
            detail=f"Service temporarily unavailable. Please try again later. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    if 500 <= status <= 599:
        raise U1HttpServerError(
            detail=f"Request failed. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
    if 400 <= status <= 499:
        raise U1HttpBadRequestError(
            detail=f"Bad request. {request_method} {request_url!r}. Response content: {response_content}",
            code=status,
        )
|
||||
Reference in New Issue
Block a user