415 lines
14 KiB
Python
415 lines
14 KiB
Python
"""
|
||
* @file run_pipeline.py
|
||
* @brief AI海报生成系统主服务入口和API服务器
|
||
* 集成多个AI模型提供统一的海报生成接口
|
||
*
|
||
* @author 王秀强 (2310460@mail.nankai.edu.cn)
|
||
* @date 2025.6.9
|
||
* @version v2.0.0
|
||
*
|
||
* @details
|
||
* 本文件主要实现:
|
||
* - FastAPI服务器和RESTful API接口
|
||
* - 用户输入分析和海报生成流程编排
|
||
* - Vue组件代码生成和PSD文件合成
|
||
* - 会话管理和文件下载服务
|
||
* - 集成DeepSeek、Kimi、ComfyUI等AI服务
|
||
*
|
||
* @note
|
||
* - 依赖外部ComfyUI服务(101.201.50.90:8188)进行图片生成
|
||
* - 需要配置DEEPSEEK_API_KEY和MOONSHOT_API_KEY环境变量
|
||
* - PSD生成优先使用手动创建的模板文件
|
||
* - 支持CORS跨域访问,生产环境需调整安全配置
|
||
*
|
||
* @usage
|
||
* # API服务器模式
|
||
* python run_pipeline.py
|
||
* # 选择: 2 (API服务器模式)
|
||
*
|
||
* # 本地测试模式
|
||
* python run_pipeline.py
|
||
* # 选择: 1 (本地测试模式)
|
||
*
|
||
* @copyright
|
||
* (c) 2025 砚生项目组
|
||
*/
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from dotenv import load_dotenv
|
||
import json
|
||
import shutil
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
# 导入工具函数
|
||
from utils import (
|
||
print_step, print_result, get_session_folder,
|
||
llm_user_analysis, save_json_file, save_vue_file,
|
||
create_temp_config, CONFIG_PATHS
|
||
)
|
||
|
||
# 导入核心模块
|
||
from generate_layout import generate_vue_code_enhanced, save_code
|
||
from generate_text import load_config_from_file, get_poster_content_suggestions
|
||
from flux_con import comfyui_img_info
|
||
from export_psd_from_json import create_psd_from_images as create_psd_impl
|
||
|
||
# FastAPI相关
|
||
from fastapi import FastAPI
|
||
from fastapi.responses import FileResponse, JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
|
||
from colorama import init, Fore, Style
|
||
|
||
# 初始化colorama
|
||
init(autoreset=True)
|
||
|
||
# 加载环境变量
|
||
load_dotenv()
|
||
|
||
# FastAPI应用
|
||
app = FastAPI(title="AI海报生成系统API", version="2.0.0")
|
||
|
||
# 添加CORS中间件
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"], # 在生产环境中应该设置具体的域名
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
# 请求模型
|
||
class PosterRequest(BaseModel):
|
||
user_input: str
|
||
session_id: str = None
|
||
|
||
class ApiResponse(BaseModel):
|
||
status: str
|
||
message: str
|
||
data: dict = None
|
||
session_id: str = None
|
||
|
||
# 全局变量存储会话数据
|
||
sessions = {}
|
||
|
||
|
||
def create_psd_from_fixed_images(output_path: str) -> Optional[str]:
|
||
"""
|
||
使用固定的图片文件创建PSD文件
|
||
简化的PSD创建函数,直接使用预定义的四个图片
|
||
"""
|
||
print(f"{Fore.CYAN}🎨 开始合成PSD文件...{Style.RESET_ALL}")
|
||
|
||
try:
|
||
# 使用固定的图片文件列表
|
||
fixed_image_files = ["background.png", "lotus.jpg", "nku.png", "stamp.jpg"]
|
||
image_paths = []
|
||
|
||
for img_file in fixed_image_files:
|
||
img_path = os.path.join(CONFIG_PATHS["output_folder"], img_file)
|
||
if os.path.exists(img_path):
|
||
image_paths.append(img_path)
|
||
print(f"{Fore.GREEN}✓ 找到图片: {img_file}{Style.RESET_ALL}")
|
||
else:
|
||
print(f"{Fore.YELLOW}⚠️ 图片不存在: {img_file}{Style.RESET_ALL}")
|
||
|
||
if not image_paths:
|
||
print(f"{Fore.RED}❌ 没有找到任何指定的图片文件{Style.RESET_ALL}")
|
||
return None
|
||
|
||
print(f"{Fore.CYAN}📋 将合并以下图片 (共{len(image_paths)}张):")
|
||
for i, path in enumerate(image_paths):
|
||
print(f" {i + 1}. {os.path.basename(path)}")
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
|
||
# 调用PSD创建函数
|
||
create_psd_impl(
|
||
image_paths=image_paths,
|
||
output_path=output_path,
|
||
canvas_size=(1080, 1920),
|
||
mode='RGB'
|
||
)
|
||
|
||
print(f"{Fore.GREEN}✅ PSD文件创建成功: {output_path}{Style.RESET_ALL}")
|
||
|
||
# 验证PSD文件
|
||
if os.path.exists(output_path):
|
||
file_size = os.path.getsize(output_path) / (1024 * 1024)
|
||
print(f"{Fore.CYAN}📁 PSD文件大小: {file_size:.2f} MB{Style.RESET_ALL}")
|
||
|
||
return output_path
|
||
|
||
except Exception as e:
|
||
print(f"{Fore.RED}❌ PSD文件创建失败: {str(e)}{Style.RESET_ALL}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def run_pipeline(user_input: str = None) -> Optional[str]:
|
||
try:
|
||
print(f"{Fore.MAGENTA}{'=' * 50}")
|
||
print(f"{Fore.MAGENTA}🎨 海报生成流程启动 🎨")
|
||
print(f"{'=' * 50}{Style.RESET_ALL}")
|
||
|
||
print_step(1, "生成临时配置")
|
||
temp_config_path = create_temp_config(user_input)
|
||
load_config_from_file(temp_config_path)
|
||
print_step(1, "生成临时配置", "完成")
|
||
|
||
print_step(2, "分析用户输入")
|
||
user_input_analysis_result = llm_user_analysis(user_input)
|
||
print_result("分析结果", user_input_analysis_result.get('main_theme', '未知'))
|
||
print_step(2, "分析用户输入", "完成")
|
||
|
||
print_step(3, "生成图片信息")
|
||
system_prompt = user_input_analysis_result["analyzed_prompt"]
|
||
parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)
|
||
print_result("生成图片数量", len(parse_imglist))
|
||
print_step(3, "生成图片信息", "完成")
|
||
|
||
print_step(4, "生成文案建议")
|
||
suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"])
|
||
|
||
# 保存文案到文件
|
||
suggestions_path = os.path.join(CONFIG_PATHS["output_folder"], "poster_content.json")
|
||
save_json_file(suggestions, suggestions_path)
|
||
print_step(4, "生成文案建议", "完成")
|
||
|
||
print_step(5, "生成Vue组件")
|
||
# 使用增强的Vue代码生成(已移至generate_layout模块)
|
||
vue_code = generate_vue_code_enhanced(
|
||
user_input_analysis_result,
|
||
parse_imglist,
|
||
suggestions
|
||
)
|
||
|
||
vue_path = os.path.join(CONFIG_PATHS["output_folder"], "generated_code.vue")
|
||
save_vue_file(vue_code, vue_path)
|
||
print_step(5, "生成Vue组件", "完成")
|
||
|
||
print_step(6, "合成PSD文件")
|
||
psd_path = os.path.join(CONFIG_PATHS["output_folder"], "final_poster.psd")
|
||
result_path = create_psd_from_fixed_images(psd_path)
|
||
|
||
if result_path:
|
||
print_step(6, "合成PSD文件", "完成")
|
||
else:
|
||
print_step(6, "合成PSD文件", "错误")
|
||
|
||
print(f"\n{Fore.GREEN}{'=' * 50}")
|
||
print(f"✅ 流程执行完成!")
|
||
print(f"{'=' * 50}{Style.RESET_ALL}")
|
||
|
||
return result_path
|
||
|
||
except Exception as e:
|
||
print_step("X", f"Pipeline执行", "错误")
|
||
print(f"{Fore.RED}错误详情: {str(e)}{Style.RESET_ALL}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def run_local_pipeline(user_input: str = None):
|
||
"""
|
||
本地运行整个管道流程
|
||
"""
|
||
print(f"{Fore.CYAN}🎬 启动本地流程,输入: {Style.BRIGHT}{user_input}{Style.RESET_ALL}")
|
||
output_path = run_pipeline(user_input)
|
||
|
||
if output_path:
|
||
print(f"{Fore.GREEN}🎊 流程执行成功!")
|
||
print(f"{Fore.YELLOW}📁 生成文件:")
|
||
print(f" - Vue组件: {os.path.join(CONFIG_PATHS['output_folder'], 'generated_code.vue')}")
|
||
print(f" - PSD文件: {os.path.join(CONFIG_PATHS['output_folder'], 'final_poster.psd')}")
|
||
print(f" - 文案JSON: {os.path.join(CONFIG_PATHS['output_folder'], 'poster_content.json')}")
|
||
print(f"{Fore.CYAN}💡 查看 outputs/ 目录获取生成的文件{Style.RESET_ALL}")
|
||
else:
|
||
print(f"{Fore.RED}❌ 流程执行失败{Style.RESET_ALL}")
|
||
|
||
|
||
# === API路由 ===
|
||
|
||
@app.get("/")
|
||
def read_root():
|
||
return {
|
||
"message": "AI海报生成系统API v2.0",
|
||
"version": "2.0.0",
|
||
"features": [
|
||
"预定义Vue模板支持",
|
||
"简化的代码结构",
|
||
"优化的流程管理"
|
||
],
|
||
"endpoints": {
|
||
"generate_poster": "/api/generate-poster",
|
||
"download": "/api/download/{file_type}",
|
||
"health": "/health",
|
||
"status": "/api/status/{session_id}"
|
||
}
|
||
}
|
||
|
||
|
||
@app.get("/health")
|
||
def health_check():
|
||
return {"status": "healthy", "timestamp": datetime.now().isoformat()}
|
||
|
||
|
||
@app.post("/api/generate-poster", response_model=ApiResponse)
|
||
async def generate_poster_api(request: PosterRequest):
|
||
try:
|
||
session_folder, session_id = get_session_folder(request.session_id)
|
||
|
||
print(f"{Fore.BLUE}🎨 开始生成海报...{Style.RESET_ALL}")
|
||
print(f"{Fore.CYAN}用户输入: {request.user_input}{Style.RESET_ALL}")
|
||
|
||
# === 步骤1: 生成配置文件 ===
|
||
temp_config_path = create_temp_config(request.user_input)
|
||
load_config_from_file(temp_config_path)
|
||
|
||
# === 步骤2: 分析用户输入 ===
|
||
user_input_analysis_result = llm_user_analysis(request.user_input)
|
||
|
||
# === 步骤3: 生成图片信息 ===
|
||
system_prompt = user_input_analysis_result["analyzed_prompt"]
|
||
parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)
|
||
|
||
# === 步骤4: 生成文案建议 ===
|
||
suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"])
|
||
|
||
# 保存文案到会话文件夹
|
||
suggestions_path = os.path.join(session_folder, "poster_content.json")
|
||
save_json_file(suggestions, suggestions_path)
|
||
|
||
# === 步骤5: 生成Vue组件 ===
|
||
vue_code = generate_vue_code_enhanced(
|
||
user_input_analysis_result,
|
||
parse_imglist,
|
||
suggestions
|
||
)
|
||
vue_path = os.path.join(session_folder, "generated_code.vue")
|
||
save_vue_file(vue_code, vue_path)
|
||
|
||
# === 步骤6: 合成PSD文件 ===
|
||
psd_path = os.path.join(session_folder, "final_poster.psd")
|
||
psd_created = create_psd_from_fixed_images(psd_path)
|
||
|
||
# 返回API响应
|
||
return ApiResponse(
|
||
status="success",
|
||
message="海报生成完成",
|
||
data={
|
||
"vue_file": vue_path if os.path.exists(vue_path) else None,
|
||
"psd_file": psd_path if psd_created else None,
|
||
"content_file": suggestions_path,
|
||
"analysis_result": user_input_analysis_result,
|
||
"images_info": parse_imglist,
|
||
"suggestions": suggestions,
|
||
"vue_code": vue_code,
|
||
"file_size_mb": round(os.path.getsize(psd_path) / (1024 * 1024), 2) if psd_created else 0,
|
||
"generated_images": len(parse_imglist)
|
||
},
|
||
session_id=session_id
|
||
)
|
||
|
||
except Exception as e:
|
||
print(f"{Fore.RED}❌ API错误: {str(e)}{Style.RESET_ALL}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
return ApiResponse(
|
||
status="error",
|
||
message=f"生成失败: {str(e)}",
|
||
data=None,
|
||
session_id=session_id if 'session_id' in locals() else None
|
||
)
|
||
|
||
|
||
@app.get("/api/download/{file_type}")
|
||
async def download_file(file_type: str, session_id: str = None):
|
||
"""下载生成的文件"""
|
||
try:
|
||
if session_id:
|
||
session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}")
|
||
else:
|
||
session_folder = CONFIG_PATHS["output_folder"]
|
||
|
||
file_mapping = {
|
||
"vue": ("generated_code.vue", "text/plain"),
|
||
"psd": ("final_poster.psd", "application/octet-stream"),
|
||
"json": ("poster_content.json", "application/json")
|
||
}
|
||
|
||
if file_type not in file_mapping:
|
||
return JSONResponse(status_code=400, content={"error": "不支持的文件类型"})
|
||
|
||
filename, media_type = file_mapping[file_type]
|
||
file_path = os.path.join(session_folder, filename)
|
||
|
||
if os.path.exists(file_path):
|
||
return FileResponse(
|
||
path=file_path,
|
||
media_type=media_type,
|
||
filename=filename
|
||
)
|
||
else:
|
||
return JSONResponse(status_code=404, content={"error": "文件不存在"})
|
||
|
||
except Exception as e:
|
||
return JSONResponse(status_code=500, content={"error": f"下载失败: {str(e)}"})
|
||
|
||
|
||
@app.get("/api/status/{session_id}")
|
||
async def get_session_status(session_id: str):
|
||
"""获取会话状态"""
|
||
try:
|
||
session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}")
|
||
|
||
if not os.path.exists(session_folder):
|
||
return JSONResponse(status_code=404, content={"error": "会话不存在"})
|
||
|
||
files_status = {
|
||
"vue_file": os.path.exists(os.path.join(session_folder, "generated_code.vue")),
|
||
"psd_file": os.path.exists(os.path.join(session_folder, "final_poster.psd")),
|
||
"content_file": os.path.exists(os.path.join(session_folder, "poster_content.json"))
|
||
}
|
||
|
||
return {
|
||
"session_id": session_id,
|
||
"files": files_status,
|
||
"folder": session_folder
|
||
}
|
||
|
||
except Exception as e:
|
||
return JSONResponse(status_code=500, content={"error": f"状态查询失败: {str(e)}"})
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
|
||
print(f"{Fore.BLUE}🔧 运行模式选择:{Style.RESET_ALL}")
|
||
print(f"{Fore.YELLOW}1. 本地测试模式")
|
||
print(f"2. API服务器模式{Style.RESET_ALL}")
|
||
|
||
choice = input("请选择运行模式 (1/2): ").strip()
|
||
|
||
if choice == "1":
|
||
# 本地测试
|
||
test_input = input("请输入海报需求 (留空使用默认): ").strip()
|
||
if not test_input:
|
||
test_input = "端午节海报,传统风格"
|
||
run_local_pipeline(test_input)
|
||
else:
|
||
# 启动API服务器
|
||
print(f"{Fore.GREEN}🚀 启动API服务器 v2.0...{Style.RESET_ALL}")
|
||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||
|