""" * @file run_pipeline.py * @brief AI海报生成系统主服务入口和API服务器 * 集成多个AI模型提供统一的海报生成接口 * * @author 王秀强 (2310460@mail.nankai.edu.cn) * @date 2025.6.9 * @version v2.0.0 * * @details * 本文件主要实现: * - FastAPI服务器和RESTful API接口 * - 用户输入分析和海报生成流程编排 * - Vue组件代码生成和PSD文件合成 * - 会话管理和文件下载服务 * - 集成DeepSeek、Kimi、ComfyUI等AI服务 * * @note * - 依赖外部ComfyUI服务(101.201.50.90:8188)进行图片生成 * - 需要配置DEEPSEEK_API_KEY和MOONSHOT_API_KEY环境变量 * - PSD生成优先使用手动创建的模板文件 * - 支持CORS跨域访问,生产环境需调整安全配置 * * @usage * # API服务器模式 * python run_pipeline.py * # 选择: 2 (API服务器模式) * * # 本地测试模式 * python run_pipeline.py * # 选择: 1 (本地测试模式) * * @copyright * (c) 2025 砚生项目组 */ """ import os import sys sys.path.append(os.path.dirname(os.path.abspath(__file__))) from dotenv import load_dotenv import json import shutil import uuid from datetime import datetime from typing import Optional # 导入工具函数 from utils import ( print_step, print_result, get_session_folder, llm_user_analysis, save_json_file, save_vue_file, create_temp_config, CONFIG_PATHS ) # 导入核心模块 from generate_layout import generate_vue_code_enhanced, save_code from generate_text import load_config_from_file, get_poster_content_suggestions from flux_con import comfyui_img_info from export_psd_from_json import create_psd_from_images as create_psd_impl # FastAPI相关 from fastapi import FastAPI from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from colorama import init, Fore, Style # 初始化colorama init(autoreset=True) # 加载环境变量 load_dotenv() # FastAPI应用 app = FastAPI(title="AI海报生成系统API", version="2.0.0") # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], # 在生产环境中应该设置具体的域名 allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 请求模型 class PosterRequest(BaseModel): user_input: str session_id: str = None class ApiResponse(BaseModel): status: str message: str data: dict = None session_id: str = None # 全局变量存储会话数据 sessions = {} def create_psd_from_fixed_images(output_path: str) -> Optional[str]: """ 使用固定的图片文件创建PSD文件 简化的PSD创建函数,直接使用预定义的四个图片 """ print(f"{Fore.CYAN}🎨 开始合成PSD文件...{Style.RESET_ALL}") try: # 使用固定的图片文件列表 fixed_image_files = ["background.png", "lotus.jpg", "nku.png", "stamp.jpg"] image_paths = [] for img_file in fixed_image_files: img_path = os.path.join(CONFIG_PATHS["output_folder"], img_file) if os.path.exists(img_path): image_paths.append(img_path) print(f"{Fore.GREEN}✓ 找到图片: {img_file}{Style.RESET_ALL}") else: print(f"{Fore.YELLOW}⚠️ 图片不存在: {img_file}{Style.RESET_ALL}") if not image_paths: print(f"{Fore.RED}❌ 没有找到任何指定的图片文件{Style.RESET_ALL}") return None print(f"{Fore.CYAN}📋 将合并以下图片 (共{len(image_paths)}张):") for i, path in enumerate(image_paths): print(f" {i + 1}. {os.path.basename(path)}") # 确保输出目录存在 os.makedirs(os.path.dirname(output_path), exist_ok=True) # 调用PSD创建函数 create_psd_impl( image_paths=image_paths, output_path=output_path, canvas_size=(1080, 1920), mode='RGB' ) print(f"{Fore.GREEN}✅ PSD文件创建成功: {output_path}{Style.RESET_ALL}") # 验证PSD文件 if os.path.exists(output_path): file_size = os.path.getsize(output_path) / (1024 * 1024) print(f"{Fore.CYAN}📁 PSD文件大小: {file_size:.2f} MB{Style.RESET_ALL}") return output_path except Exception as e: print(f"{Fore.RED}❌ PSD文件创建失败: {str(e)}{Style.RESET_ALL}") import traceback traceback.print_exc() return None def run_pipeline(user_input: str = None) -> Optional[str]: """ 简化的海报生成流程 """ try: print(f"{Fore.MAGENTA}{'=' * 50}") print(f"{Fore.MAGENTA}🎨 海报生成流程启动 🎨") print(f"{'=' * 50}{Style.RESET_ALL}") print_step(1, "生成临时配置") temp_config_path = create_temp_config(user_input) load_config_from_file(temp_config_path) print_step(1, "生成临时配置", "完成") print_step(2, "分析用户输入") user_input_analysis_result = llm_user_analysis(user_input) print_result("分析结果", user_input_analysis_result.get('main_theme', '未知')) print_step(2, "分析用户输入", "完成") print_step(3, "生成图片信息") system_prompt = user_input_analysis_result["analyzed_prompt"] parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt) print_result("生成图片数量", len(parse_imglist)) print_step(3, "生成图片信息", "完成") print_step(4, "生成文案建议") suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"]) # 保存文案到文件 suggestions_path = os.path.join(CONFIG_PATHS["output_folder"], "poster_content.json") save_json_file(suggestions, suggestions_path) print_step(4, "生成文案建议", "完成") print_step(5, "生成Vue组件") # 使用增强的Vue代码生成(已移至generate_layout模块) vue_code = generate_vue_code_enhanced( user_input_analysis_result, parse_imglist, suggestions ) vue_path = os.path.join(CONFIG_PATHS["output_folder"], "generated_code.vue") save_vue_file(vue_code, vue_path) print_step(5, "生成Vue组件", "完成") print_step(6, "合成PSD文件") psd_path = os.path.join(CONFIG_PATHS["output_folder"], "final_poster.psd") result_path = create_psd_from_fixed_images(psd_path) if result_path: print_step(6, "合成PSD文件", "完成") else: print_step(6, "合成PSD文件", "错误") print(f"\n{Fore.GREEN}{'=' * 50}") print(f"✅ 流程执行完成!") print(f"{'=' * 50}{Style.RESET_ALL}") return result_path except Exception as e: print_step("X", f"Pipeline执行", "错误") print(f"{Fore.RED}错误详情: {str(e)}{Style.RESET_ALL}") import traceback traceback.print_exc() return None def run_local_pipeline(user_input: str = None): """ 本地运行整个管道流程 """ print(f"{Fore.CYAN}🎬 启动本地流程,输入: {Style.BRIGHT}{user_input}{Style.RESET_ALL}") output_path = run_pipeline(user_input) if output_path: print(f"{Fore.GREEN}🎊 流程执行成功!") print(f"{Fore.YELLOW}📁 生成文件:") print(f" - Vue组件: {os.path.join(CONFIG_PATHS['output_folder'], 'generated_code.vue')}") print(f" - PSD文件: {os.path.join(CONFIG_PATHS['output_folder'], 'final_poster.psd')}") print(f" - 文案JSON: {os.path.join(CONFIG_PATHS['output_folder'], 'poster_content.json')}") print(f"{Fore.CYAN}💡 查看 outputs/ 目录获取生成的文件{Style.RESET_ALL}") else: print(f"{Fore.RED}❌ 流程执行失败{Style.RESET_ALL}") # === API路由 === @app.get("/") def read_root(): return { "message": "AI海报生成系统API v2.0", "version": "2.0.0", "features": [ "预定义Vue模板支持", "简化的代码结构", "优化的流程管理" ], "endpoints": { "generate_poster": "/api/generate-poster", "download": "/api/download/{file_type}", "health": "/health", "status": "/api/status/{session_id}" } } @app.get("/health") def health_check(): return {"status": "healthy", "timestamp": datetime.now().isoformat()} @app.post("/api/generate-poster", response_model=ApiResponse) async def generate_poster_api(request: PosterRequest): """ 简化的海报生成API接口 """ try: session_folder, session_id = get_session_folder(request.session_id) print(f"{Fore.BLUE}🎨 开始生成海报...{Style.RESET_ALL}") print(f"{Fore.CYAN}用户输入: {request.user_input}{Style.RESET_ALL}") # === 步骤1: 生成配置文件 === temp_config_path = create_temp_config(request.user_input) load_config_from_file(temp_config_path) # === 步骤2: 分析用户输入 === user_input_analysis_result = llm_user_analysis(request.user_input) # === 步骤3: 生成图片信息 === system_prompt = user_input_analysis_result["analyzed_prompt"] parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt) # === 步骤4: 生成文案建议 === suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"]) # 保存文案到会话文件夹 suggestions_path = os.path.join(session_folder, "poster_content.json") save_json_file(suggestions, suggestions_path) # === 步骤5: 生成Vue组件 === vue_code = generate_vue_code_enhanced( user_input_analysis_result, parse_imglist, suggestions ) vue_path = os.path.join(session_folder, "generated_code.vue") save_vue_file(vue_code, vue_path) # === 步骤6: 合成PSD文件 === psd_path = os.path.join(session_folder, "final_poster.psd") psd_created = create_psd_from_fixed_images(psd_path) # 返回API响应 return ApiResponse( status="success", message="海报生成完成", data={ "vue_file": vue_path if os.path.exists(vue_path) else None, "psd_file": psd_path if psd_created else None, "content_file": suggestions_path, "analysis_result": user_input_analysis_result, "images_info": parse_imglist, "suggestions": suggestions, "vue_code": vue_code, "file_size_mb": round(os.path.getsize(psd_path) / (1024 * 1024), 2) if psd_created else 0, "generated_images": len(parse_imglist) }, session_id=session_id ) except Exception as e: print(f"{Fore.RED}❌ API错误: {str(e)}{Style.RESET_ALL}") import traceback traceback.print_exc() return ApiResponse( status="error", message=f"生成失败: {str(e)}", data=None, session_id=session_id if 'session_id' in locals() else None ) @app.get("/api/download/{file_type}") async def download_file(file_type: str, session_id: str = None): """下载生成的文件""" try: if session_id: session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}") else: session_folder = CONFIG_PATHS["output_folder"] file_mapping = { "vue": ("generated_code.vue", "text/plain"), "psd": ("final_poster.psd", "application/octet-stream"), "json": ("poster_content.json", "application/json") } if file_type not in file_mapping: return JSONResponse(status_code=400, content={"error": "不支持的文件类型"}) filename, media_type = file_mapping[file_type] file_path = os.path.join(session_folder, filename) if os.path.exists(file_path): return FileResponse( path=file_path, media_type=media_type, filename=filename ) else: return JSONResponse(status_code=404, content={"error": "文件不存在"}) except Exception as e: return JSONResponse(status_code=500, content={"error": f"下载失败: {str(e)}"}) @app.get("/api/status/{session_id}") async def get_session_status(session_id: str): """获取会话状态""" try: session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}") if not os.path.exists(session_folder): return JSONResponse(status_code=404, content={"error": "会话不存在"}) files_status = { "vue_file": os.path.exists(os.path.join(session_folder, "generated_code.vue")), "psd_file": os.path.exists(os.path.join(session_folder, "final_poster.psd")), "content_file": os.path.exists(os.path.join(session_folder, "poster_content.json")) } return { "session_id": session_id, "files": files_status, "folder": session_folder } except Exception as e: return JSONResponse(status_code=500, content={"error": f"状态查询失败: {str(e)}"}) if __name__ == "__main__": import uvicorn print(f"{Fore.BLUE}🔧 运行模式选择:{Style.RESET_ALL}") print(f"{Fore.YELLOW}1. 本地测试模式") print(f"2. API服务器模式{Style.RESET_ALL}") choice = input("请选择运行模式 (1/2): ").strip() if choice == "1": # 本地测试 test_input = input("请输入海报需求 (留空使用默认): ").strip() if not test_input: test_input = "端午节海报,传统风格" run_local_pipeline(test_input) else: # 启动API服务器 print(f"{Fore.GREEN}🚀 启动API服务器 v2.0...{Style.RESET_ALL}") uvicorn.run(app, host="0.0.0.0", port=8000)