241 lines
7.0 KiB
Python
241 lines
7.0 KiB
Python
"""
|
||
沙箱执行器
|
||
在受限环境中执行生成的 Python 代码
|
||
"""
|
||
|
||
import os
import subprocess
import sys
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
|
||
|
||
|
||
@dataclass
class ExecutionResult:
    """Outcome of running one task's code through the sandbox runner."""

    # True when the child process exited with return code 0.
    success: bool
    # Task identifier (also embedded in the task's code and log file names).
    task_id: str
    # Captured standard output of the child process.
    stdout: str
    # Captured standard error, or an error message produced by the runner.
    stderr: str
    # Child exit code; the runner uses -1 for timeouts and launch failures.
    return_code: int
    # Filesystem path of the task's log file, as a string.
    log_path: str
    # Elapsed run time in milliseconds.
    duration_ms: int
|
||
|
||
|
||
class SandboxRunner:
    """
    Sandbox runner: executes generated Python code in a child process.

    Features:
    1. Starts a separate Python process (``sys.executable``) via ``subprocess``.
    2. The child's working directory is the workspace directory.
    3. Captures the child's stdout and stderr.
    4. Writes a per-task log file to ``<workspace>/logs``.

    NOTE: this provides process separation only, not a security boundary.
    The child inherits the parent's environment (minus proxy variables), and
    keeps its filesystem and network access.
    """

    def __init__(self, workspace_path: Optional[str] = None):
        """
        Create the runner and make sure the workspace directories exist.

        Args:
            workspace_path: Workspace root. When omitted or empty, defaults to
                ``Path(__file__).parent.parent / "workspace"``.
        """
        if workspace_path:
            self.workspace = Path(workspace_path)
        else:
            # "workspace" inside the parent of this module's directory
            # (presumably the project root).
            self.workspace = Path(__file__).parent.parent / "workspace"

        self.input_dir = self.workspace / "input"
        self.output_dir = self.workspace / "output"
        self.logs_dir = self.workspace / "logs"

        # Create the directories up front; parents=True also creates the
        # workspace root itself.
        for directory in (self.input_dir, self.output_dir, self.logs_dir):
            directory.mkdir(parents=True, exist_ok=True)

    def save_task_code(self, code: str, task_id: Optional[str] = None) -> tuple[str, Path]:
        """
        Write task code to ``<workspace>/task_<task_id>.py``.

        Args:
            code: Python source code.
            task_id: Task ID; a new one is generated when omitted or empty.

        Returns:
            ``(task_id, code_path)``.
        """
        if not task_id:
            task_id = self._generate_task_id()

        # NOTE(review): task_id is interpolated into the file name without
        # validation; IDs containing path separators are not rejected.
        code_path = self.workspace / f"task_{task_id}.py"
        code_path.write_text(code, encoding='utf-8')
        return task_id, code_path

    def execute(self, code: str, task_id: Optional[str] = None, timeout: int = 60) -> ExecutionResult:
        """
        Save ``code`` to the workspace and run it in a child Python process.

        Args:
            code: Python source code to run.
            task_id: Task ID; a new one is generated when omitted or empty.
            timeout: Timeout in seconds. When exceeded, ``subprocess.run``
                kills the child and the run is reported as a failure.

        Returns:
            ExecutionResult: ``success`` is True only when the child exits
            with return code 0. On timeout or launch failure, ``return_code``
            is -1 and ``stderr`` starts with an error message. On timeout,
            any partial stdout/stderr captured before the kill is kept. A log
            file is written in every case.
        """
        # Save the code first; failures here (e.g. unwritable workspace) propagate.
        task_id, code_path = self.save_task_code(code, task_id)
        log_path = self.logs_dir / f"task_{task_id}.log"

        # Drop the parent's proxy settings. Also force the child to emit UTF-8,
        # so its output always matches the decoding below regardless of locale.
        env = self._get_safe_env()
        env["PYTHONIOENCODING"] = "utf-8"

        started_at = datetime.now()
        # Monotonic clock for the duration: unaffected by system clock changes.
        start = time.monotonic()

        try:
            completed = subprocess.run(
                [sys.executable, str(code_path)],
                cwd=str(self.workspace),
                capture_output=True,
                encoding="utf-8",
                errors="replace",
                timeout=timeout,
                env=env,
            )
        except subprocess.TimeoutExpired as exc:
            error_msg = f"执行超时(超过 {timeout} 秒)"
            # Keep whatever the child wrote before it was killed. Per the
            # subprocess docs this is bytes (or None) even in text mode.
            partial_stderr = self._to_text(exc.stderr)
            stdout = self._to_text(exc.stdout)
            stderr = f"{error_msg}\n{partial_stderr}" if partial_stderr else error_msg
            return_code = -1
        except Exception as exc:
            # Boundary: report launch failures (e.g. OSError) as a failed
            # result instead of raising to the caller.
            stdout = ""
            stderr = f"执行异常: {str(exc)}"
            return_code = -1
        else:
            stdout = completed.stdout
            stderr = completed.stderr
            return_code = completed.returncode

        duration_ms = int((time.monotonic() - start) * 1000)

        self._write_log(
            log_path=log_path,
            task_id=task_id,
            code_path=code_path,
            stdout=stdout,
            stderr=stderr,
            return_code=return_code,
            duration_ms=duration_ms,
            started_at=started_at,
        )

        return ExecutionResult(
            success=return_code == 0,
            task_id=task_id,
            stdout=stdout,
            stderr=stderr,
            return_code=return_code,
            log_path=str(log_path),
            duration_ms=duration_ms,
        )

    @staticmethod
    def _to_text(value: object) -> str:
        """Normalize captured process output (None, bytes or str) to str."""
        if value is None:
            return ""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    def _generate_task_id(self) -> str:
        """Return a task ID of the form ``YYYYmmdd_HHMMSS_<6 hex chars>``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        short_uuid = uuid.uuid4().hex[:6]
        return f"{timestamp}_{short_uuid}"

    def _get_safe_env(self) -> dict:
        """
        Return a copy of the current environment with proxy variables removed.

        Only the variables listed below are dropped; everything else is
        inherited by the child. This does not block network access. It only
        keeps tools that honor these variables from using the parent's proxy
        configuration.
        """
        safe_env = os.environ.copy()
        for var in (
            'HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
            'ALL_PROXY', 'all_proxy', 'NO_PROXY', 'no_proxy',
        ):
            safe_env.pop(var, None)
        return safe_env

    def _write_log(
        self,
        log_path: Path,
        task_id: str,
        code_path: Path,
        stdout: str,
        stderr: str,
        return_code: int,
        duration_ms: int,
        started_at: Optional[datetime] = None,
    ) -> None:
        """
        Write a human-readable execution log (UTF-8) to ``log_path``.

        Args:
            log_path: Destination file; overwritten if it exists.
            task_id: Task ID.
            code_path: Path of the executed code file.
            stdout: Captured standard output.
            stderr: Captured standard error or runner error message.
            return_code: Child exit code (0 means success).
            duration_ms: Elapsed time in milliseconds.
            started_at: When the run started; defaults to "now" when omitted.
        """
        timestamp = (started_at or datetime.now()).strftime("%Y-%m-%d %H:%M:%S")
        status = "成功" if return_code == 0 else "失败"
        separator = "========================================"
        lines = [
            separator,
            "任务执行日志",
            separator,
            f"任务 ID: {task_id}",
            f"代码文件: {code_path}",
            f"执行时间: {timestamp}",
            f"耗时: {duration_ms} ms",
            f"返回码: {return_code}",
            f"状态: {status}",
            "",
            separator,
            "标准输出 (stdout)",
            separator,
            stdout if stdout else "(无输出)",
            "",
            separator,
            "标准错误 (stderr)",
            separator,
            stderr if stderr else "(无错误)",
        ]
        log_path.write_text("\n".join(lines) + "\n", encoding='utf-8')
|
||
|
||
|
||
def run_task(code: str, task_id: Optional[str] = None, timeout: int = 60) -> ExecutionResult:
    """
    Convenience wrapper: run ``code`` with a ``SandboxRunner`` on the default workspace.

    Args:
        code: Python source to execute.
        task_id: Optional task ID; generated when omitted.
        timeout: Timeout in seconds, forwarded to ``SandboxRunner.execute``
            (default 60, the same default ``execute`` uses).

    Returns:
        The ``ExecutionResult`` returned by ``SandboxRunner.execute``.
    """
    runner = SandboxRunner()
    return runner.execute(code, task_id, timeout=timeout)
|
||
|