ai_service/scripts/run_pipeline.py
2025-07-05 04:24:43 +08:00

415 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
* @file run_pipeline.py
* @brief AI海报生成系统主服务入口和API服务器
* 集成多个AI模型，提供统一的海报生成接口
*
* @author 王秀强 (2310460@mail.nankai.edu.cn)
* @date 2025.6.9
* @version v2.0.0
*
* @details
* 本文件主要实现:
* - FastAPI服务器和RESTful API接口
* - 用户输入分析和海报生成流程编排
* - Vue组件代码生成和PSD文件合成
* - 会话管理和文件下载服务
* - 集成DeepSeek、Kimi、ComfyUI等AI服务
*
* @note
* - 依赖外部ComfyUI服务(101.201.50.90:8188)进行图片生成
* - 需要配置DEEPSEEK_API_KEY和MOONSHOT_API_KEY环境变量
* - PSD生成优先使用手动创建的模板文件
* - 支持CORS跨域访问；生产环境需调整安全配置（可通过CORS_ALLOW_ORIGINS环境变量指定允许的域名）
*
* @usage
* # API服务器模式
* python run_pipeline.py
* # 选择: 2 (API服务器模式)
*
* # 本地测试模式
* python run_pipeline.py
* # 选择: 1 (本地测试模式)
*
* @copyright
* (c) 2025 砚生项目组
*/
"""
import asyncio
import json
import os
import shutil
import sys
import threading
import uuid
from datetime import datetime
from typing import Optional

from dotenv import load_dotenv

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# 导入工具函数
from utils import (
print_step, print_result, get_session_folder,
llm_user_analysis, save_json_file, save_vue_file,
create_temp_config, CONFIG_PATHS
)
# 导入核心模块
from generate_layout import generate_vue_code_enhanced, save_code
from generate_text import load_config_from_file, get_poster_content_suggestions
from flux_con import comfyui_img_info
from export_psd_from_json import create_psd_from_images as create_psd_impl
# FastAPI相关
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from colorama import init, Fore, Style
# Initialize colorama; autoreset=True restores the default style after every print.
init(autoreset=True)
# Load environment variables (e.g. DEEPSEEK_API_KEY / MOONSHOT_API_KEY) from a .env file.
# This must run before _cors_allow_origins() so .env values are visible to it.
load_dotenv()
# FastAPI application
app = FastAPI(title="AI海报生成系统API", version="2.0.0")


def _cors_allow_origins() -> list:
    """Return the list of origins allowed by the CORS middleware.

    Reads the comma-separated ``CORS_ALLOW_ORIGINS`` environment variable
    (which may come from the .env file loaded above), e.g.
    ``"https://a.example.com,https://b.example.com"``. When the variable is
    unset or contains no entries, every origin (``"*"``) is allowed. This
    keeps the previous development-friendly default; production deployments
    should set explicit domains.
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


# Add the CORS middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request models
class PosterRequest(BaseModel):
    """Body of ``POST /api/generate-poster``."""

    # Free-form description of the poster the user wants.
    user_input: str
    # Optional session to reuse. Passed to get_session_folder(), which returns
    # the session id actually used. Declared Optional so an explicit JSON
    # null validates under pydantic v2 as well as v1.
    session_id: Optional[str] = None
class ApiResponse(BaseModel):
    """Uniform response envelope returned by the poster-generation endpoint."""

    # "success" or "error", as set by the generate-poster endpoint.
    status: str
    # Human-readable status message.
    message: str
    # Result payload on success; None on error. Optional so that an explicit
    # ``data=None`` passes pydantic v2 validation (``dict = None`` would not).
    data: Optional[dict] = None
    # Session the request ran in; None when it could not be determined.
    session_id: Optional[str] = None
# Module-level store intended for per-session data. Nothing visible in this
# file reads or writes it: the routes track sessions through per-session
# folders on disk instead.
sessions: dict = dict()
# Default layer images (file names relative to CONFIG_PATHS["output_folder"]).
_FIXED_IMAGE_FILES = ("background.png", "lotus.jpg", "nku.png", "stamp.jpg")


def create_psd_from_fixed_images(
    output_path: str,
    image_files: Optional[list] = None,
    canvas_size: tuple = (1080, 1920),
) -> Optional[str]:
    """Compose a PSD file from a set of pre-generated image files.

    Each name in ``image_files`` is looked up in
    ``CONFIG_PATHS["output_folder"]``. Missing files are reported and skipped,
    and the remaining images are passed to ``create_psd_impl`` in order.

    Args:
        output_path: Destination path of the PSD. Its parent directory is
            created if needed; a bare file name is written to the current
            working directory.
        image_files: File names to use as layers; defaults to the four
            fixed images in ``_FIXED_IMAGE_FILES``.
        canvas_size: (width, height) of the PSD canvas; defaults to 1080x1920.

    Returns:
        ``output_path`` if the PSD exists on disk afterwards. Otherwise None:
        no input image was found, the file was not written, or an exception
        was raised (exceptions are printed with a traceback, not propagated).
    """
    print(f"{Fore.CYAN}🎨 开始合成PSD文件...{Style.RESET_ALL}")
    if image_files is None:
        image_files = _FIXED_IMAGE_FILES
    try:
        image_paths = []
        for img_file in image_files:
            img_path = os.path.join(CONFIG_PATHS["output_folder"], img_file)
            if os.path.exists(img_path):
                image_paths.append(img_path)
                print(f"{Fore.GREEN}✓ 找到图片: {img_file}{Style.RESET_ALL}")
            else:
                print(f"{Fore.YELLOW}⚠️ 图片不存在: {img_file}{Style.RESET_ALL}")
        if not image_paths:
            print(f"{Fore.RED}❌ 没有找到任何指定的图片文件{Style.RESET_ALL}")
            return None
        print(f"{Fore.CYAN}📋 将合并以下图片 (共{len(image_paths)}张):")
        for i, path in enumerate(image_paths, start=1):
            print(f" {i}. {os.path.basename(path)}")
        # os.path.dirname() is "" for a bare file name, and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when there is one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        create_psd_impl(
            image_paths=image_paths,
            output_path=output_path,
            canvas_size=tuple(canvas_size),
            mode='RGB'
        )
        # Only report success once the file is really on disk; callers use a
        # non-None return as the signal that the PSD can be read.
        if not os.path.exists(output_path):
            print(f"{Fore.RED}❌ PSD文件创建失败: 未找到输出文件 {output_path}{Style.RESET_ALL}")
            return None
        print(f"{Fore.GREEN}✅ PSD文件创建成功: {output_path}{Style.RESET_ALL}")
        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"{Fore.CYAN}📁 PSD文件大小: {file_size:.2f} MB{Style.RESET_ALL}")
        return output_path
    except Exception as e:
        print(f"{Fore.RED}❌ PSD文件创建失败: {str(e)}{Style.RESET_ALL}")
        import traceback
        traceback.print_exc()
        return None
def run_pipeline(user_input: Optional[str] = None) -> Optional[str]:
    """Run the whole poster pipeline, writing results to the shared output folder.

    Steps: (1) create and load a temporary config, (2) analyze the user input
    with the LLM, (3) build image info via ComfyUI, (4) generate copy
    suggestions and save ``poster_content.json``, (5) generate and save
    ``generated_code.vue``, (6) compose ``final_poster.psd``.

    Args:
        user_input: Poster request text, forwarded to ``create_temp_config``
            and ``llm_user_analysis``.

    Returns:
        Path of the generated PSD, or None if PSD composition failed or any
        step raised (exceptions are printed with a traceback, not propagated).
    """
    try:
        print(f"{Fore.MAGENTA}{'=' * 50}")
        print(f"{Fore.MAGENTA}🎨 海报生成流程启动 🎨")
        print(f"{'=' * 50}{Style.RESET_ALL}")

        print_step(1, "生成临时配置")
        temp_config_path = create_temp_config(user_input)
        load_config_from_file(temp_config_path)
        print_step(1, "生成临时配置", "完成")

        print_step(2, "分析用户输入")
        user_input_analysis_result = llm_user_analysis(user_input)
        print_result("分析结果", user_input_analysis_result.get('main_theme', '未知'))
        print_step(2, "分析用户输入", "完成")

        print_step(3, "生成图片信息")
        # The analyzed prompt drives both image generation and copywriting.
        system_prompt = user_input_analysis_result["analyzed_prompt"]
        parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)
        print_result("生成图片数量", len(parse_imglist))
        print_step(3, "生成图片信息", "完成")

        print_step(4, "生成文案建议")
        suggestions = get_poster_content_suggestions(system_prompt)
        # Save the copy suggestions to a file.
        suggestions_path = os.path.join(CONFIG_PATHS["output_folder"], "poster_content.json")
        save_json_file(suggestions, suggestions_path)
        print_step(4, "生成文案建议", "完成")

        print_step(5, "生成Vue组件")
        # Enhanced Vue code generation lives in the generate_layout module.
        vue_code = generate_vue_code_enhanced(
            user_input_analysis_result,
            parse_imglist,
            suggestions
        )
        vue_path = os.path.join(CONFIG_PATHS["output_folder"], "generated_code.vue")
        save_vue_file(vue_code, vue_path)
        print_step(5, "生成Vue组件", "完成")

        print_step(6, "合成PSD文件")
        psd_path = os.path.join(CONFIG_PATHS["output_folder"], "final_poster.psd")
        result_path = create_psd_from_fixed_images(psd_path)
        if result_path:
            print_step(6, "合成PSD文件", "完成")
            print(f"\n{Fore.GREEN}{'=' * 50}")
            print("✅ 流程执行完成!")
            print(f"{'=' * 50}{Style.RESET_ALL}")
        else:
            # Don't announce a completed run when the final artifact is
            # missing: callers treat a None return as a failed pipeline.
            print_step(6, "合成PSD文件", "错误")
            print(f"{Fore.YELLOW}⚠️ PSD文件未生成,流程未完全成功{Style.RESET_ALL}")
        return result_path
    except Exception as e:
        print_step("X", "Pipeline执行", "错误")
        print(f"{Fore.RED}错误详情: {str(e)}{Style.RESET_ALL}")
        import traceback
        traceback.print_exc()
        return None
def run_local_pipeline(user_input: Optional[str] = None) -> None:
    """Run the pipeline locally and print where the generated files were written.

    Args:
        user_input: Poster request text, forwarded to ``run_pipeline``.
    """
    print(f"{Fore.CYAN}🎬 启动本地流程,输入: {Style.BRIGHT}{user_input}{Style.RESET_ALL}")
    output_path = run_pipeline(user_input)
    if not output_path:
        print(f"{Fore.RED}❌ 流程执行失败{Style.RESET_ALL}")
        return
    # Report the configured output folder rather than assuming "outputs/".
    output_folder = CONFIG_PATHS['output_folder']
    print(f"{Fore.GREEN}🎊 流程执行成功!")
    print(f"{Fore.YELLOW}📁 生成文件:")
    print(f" - Vue组件: {os.path.join(output_folder, 'generated_code.vue')}")
    print(f" - PSD文件: {output_path}")
    print(f" - 文案JSON: {os.path.join(output_folder, 'poster_content.json')}")
    print(f"{Fore.CYAN}💡 查看 {output_folder} 目录获取生成的文件{Style.RESET_ALL}")
# === API routes ===
@app.get("/")
def read_root():
    """Describe the service: name, version, feature list and main endpoints."""
    endpoints = {
        "generate_poster": "/api/generate-poster",
        "download": "/api/download/{file_type}",
        "health": "/health",
        "status": "/api/status/{session_id}",
    }
    features = [
        "预定义Vue模板支持",
        "简化的代码结构",
        "优化的流程管理",
    ]
    return dict(
        message="AI海报生成系统API v2.0",
        version="2.0.0",
        features=features,
        endpoints=endpoints,
    )
@app.get("/health")
def health_check():
    """Liveness probe: report "healthy" along with the server's current local time."""
    current_time = datetime.now()
    return dict(status="healthy", timestamp=current_time.isoformat())
# The generation work now runs in a worker thread (see generate_poster_api).
# The previous handler ran it directly on the event loop, which incidentally
# served one request at a time. This lock keeps that guarantee, because the
# steps share state across requests: the PSD layers are read from the shared
# CONFIG_PATHS["output_folder"], and load_config_from_file presumably updates
# module-level configuration (TODO confirm in generate_text).
_POSTER_GENERATION_LOCK = threading.Lock()


def _generate_poster_sync(request: PosterRequest) -> ApiResponse:
    """Blocking implementation of ``POST /api/generate-poster``.

    Runs the six pipeline steps for ``request`` and writes
    ``poster_content.json``, ``generated_code.vue`` and ``final_poster.psd``
    into the session folder returned by ``get_session_folder``. Exceptions
    are not propagated: they are printed and returned as an ``ApiResponse``
    with ``status="error"``.
    """
    # Stays None if get_session_folder itself fails.
    session_id = None
    with _POSTER_GENERATION_LOCK:
        try:
            session_folder, session_id = get_session_folder(request.session_id)
            print(f"{Fore.BLUE}🎨 开始生成海报...{Style.RESET_ALL}")
            print(f"{Fore.CYAN}用户输入: {request.user_input}{Style.RESET_ALL}")

            # === Step 1: create and load the temporary config ===
            temp_config_path = create_temp_config(request.user_input)
            load_config_from_file(temp_config_path)

            # === Step 2: analyze the user input ===
            user_input_analysis_result = llm_user_analysis(request.user_input)

            # === Step 3: build image info ===
            system_prompt = user_input_analysis_result["analyzed_prompt"]
            parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)

            # === Step 4: copy suggestions, saved into the session folder ===
            suggestions = get_poster_content_suggestions(system_prompt)
            suggestions_path = os.path.join(session_folder, "poster_content.json")
            save_json_file(suggestions, suggestions_path)

            # === Step 5: Vue component ===
            vue_code = generate_vue_code_enhanced(
                user_input_analysis_result,
                parse_imglist,
                suggestions
            )
            vue_path = os.path.join(session_folder, "generated_code.vue")
            save_vue_file(vue_code, vue_path)

            # === Step 6: PSD composition ===
            psd_path = os.path.join(session_folder, "final_poster.psd")
            psd_created = create_psd_from_fixed_images(psd_path)
            # Only report (and stat) the PSD if it really exists; otherwise
            # os.path.getsize would raise and turn a partial success into an
            # error response.
            psd_ok = bool(psd_created) and os.path.exists(psd_path)

            return ApiResponse(
                status="success",
                message="海报生成完成",
                data={
                    "vue_file": vue_path if os.path.exists(vue_path) else None,
                    "psd_file": psd_path if psd_ok else None,
                    "content_file": suggestions_path,
                    "analysis_result": user_input_analysis_result,
                    "images_info": parse_imglist,
                    "suggestions": suggestions,
                    "vue_code": vue_code,
                    "file_size_mb": round(os.path.getsize(psd_path) / (1024 * 1024), 2) if psd_ok else 0,
                    "generated_images": len(parse_imglist)
                },
                session_id=session_id
            )
        except Exception as e:
            print(f"{Fore.RED}❌ API错误: {str(e)}{Style.RESET_ALL}")
            import traceback
            traceback.print_exc()
            return ApiResponse(
                status="error",
                message=f"生成失败: {str(e)}",
                data=None,
                session_id=session_id
            )


@app.post("/api/generate-poster", response_model=ApiResponse)
async def generate_poster_api(request: PosterRequest):
    """Generate a poster for ``request`` without blocking the event loop.

    Every pipeline step is a synchronous call (LLM analysis, ComfyUI image
    info, PSD composition), presumably hitting external services. Running
    them in the loop's default thread-pool executor keeps other routes such
    as /health and the download endpoints responsive in the meantime.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _generate_poster_sync, request)
@app.get("/api/download/{file_type}")
async def download_file(file_type: str, session_id: Optional[str] = None):
    """Download one of the generated files.

    Args:
        file_type: ``"vue"``, ``"psd"`` or ``"json"``.
        session_id: When given, serve from ``session_<id>`` inside the output
            folder; otherwise serve from the output folder itself.

    Returns:
        A ``FileResponse`` on success. Otherwise a JSON error: 400 for an
        unsupported file type or invalid session id, 404 if the file is
        missing, 500 on any unexpected error.
    """
    try:
        if session_id:
            # session_id comes straight from the query string. A path
            # separator would let it climb out of the output folder
            # (e.g. "/../../other"), so reject it.
            if any(sep in session_id for sep in (os.sep, os.altsep) if sep):
                return JSONResponse(status_code=400, content={"error": "无效的会话ID"})
            session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}")
        else:
            session_folder = CONFIG_PATHS["output_folder"]
        file_mapping = {
            "vue": ("generated_code.vue", "text/plain"),
            "psd": ("final_poster.psd", "application/octet-stream"),
            "json": ("poster_content.json", "application/json")
        }
        if file_type not in file_mapping:
            return JSONResponse(status_code=400, content={"error": "不支持的文件类型"})
        filename, media_type = file_mapping[file_type]
        file_path = os.path.join(session_folder, filename)
        if not os.path.exists(file_path):
            return JSONResponse(status_code=404, content={"error": "文件不存在"})
        return FileResponse(
            path=file_path,
            media_type=media_type,
            filename=filename
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"下载失败: {str(e)}"})
@app.get("/api/status/{session_id}")
async def get_session_status(session_id: str):
    """Report which generated files exist for a session.

    Returns a dict with the session id, a per-file existence map and the
    session folder path. Otherwise returns a JSON error: 400 for an invalid
    session id, 404 if the session folder is missing, 500 on an unexpected
    error.
    """
    try:
        # session_id is client-controlled. A path separator would let it
        # point outside the output folder, so reject it.
        if any(sep in session_id for sep in (os.sep, os.altsep) if sep):
            return JSONResponse(status_code=400, content={"error": "无效的会话ID"})
        session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}")
        if not os.path.exists(session_folder):
            return JSONResponse(status_code=404, content={"error": "会话不存在"})
        files_status = {
            "vue_file": os.path.exists(os.path.join(session_folder, "generated_code.vue")),
            "psd_file": os.path.exists(os.path.join(session_folder, "final_poster.psd")),
            "content_file": os.path.exists(os.path.join(session_folder, "poster_content.json"))
        }
        return {
            "session_id": session_id,
            "files": files_status,
            "folder": session_folder
        }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"状态查询失败: {str(e)}"})
if __name__ == "__main__":
    import uvicorn

    def _ask(prompt: str) -> str:
        """Read one stripped line from stdin; return "" if stdin is closed.

        input() raises EOFError when no interactive stdin is attached
        (e.g. a container started without a TTY). Treating that as an empty
        answer falls through to the defaults below instead of crashing.
        """
        try:
            return input(prompt).strip()
        except EOFError:
            return ""

    print(f"{Fore.BLUE}🔧 运行模式选择:{Style.RESET_ALL}")
    print(f"{Fore.YELLOW}1. 本地测试模式")
    print(f"2. API服务器模式{Style.RESET_ALL}")
    choice = _ask("请选择运行模式 (1/2): ")
    if choice == "1":
        # Local test run; an empty answer uses the default request.
        test_input = _ask("请输入海报需求 (留空使用默认): ") or "端午节海报,传统风格"
        run_local_pipeline(test_input)
    else:
        # Any other answer (including none) starts the API server.
        print(f"{Fore.GREEN}🚀 启动API服务器 v2.0...{Style.RESET_ALL}")
        uvicorn.run(app, host="0.0.0.0", port=8000)