ai_service/scripts/run_pipeline.py

790 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
* @file run_pipeline.py
* @brief AI海报生成系统主服务入口和API服务器
* 集成多个AI模型提供统一的海报生成接口
*
* @author 王秀强 (2310460@mail.nankai.edu.cn)
* @date 2025.6.9
* @version v1.0.0
*
* @details
* 本文件主要实现:
* - FastAPI服务器和RESTful API接口
* - 用户输入分析和海报生成流程编排
* - Vue组件代码生成和PSD文件合成
* - 会话管理和文件下载服务
* - 集成DeepSeek、Kimi、ComfyUI等AI服务
*
* @note
* - 依赖外部ComfyUI服务(101.201.50.90:8188)进行图片生成
* - 需要配置DEEPSEEK_API_KEY和MOONSHOT_API_KEY环境变量
* - PSD生成优先使用手动创建的模板文件
* - 支持CORS跨域访问,生产环境需调整安全配置
*
* @usage
* # API服务器模式
* python run_pipeline.py
* # 选择: 2 (API服务器模式)
*
* # 本地测试模式
* python run_pipeline.py
* # 选择: 1 (本地测试模式)
*
* @copyright
* (c) 2025 砚生项目组
*/
"""
import json
import os
import shutil
import uuid
from datetime import datetime
from typing import Optional

import yaml
from colorama import init, Fore, Style
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel

from export_psd_from_json import create_psd_from_images as create_psd_impl
from flux_con import comfyui_img_info
from generate_layout import call_deepseek, generate_vue_code, save_code
from generate_text import load_config_from_file, get_poster_content_suggestions
from prompt_analysis import llm_user_analysis
# Initialise colorama with autoreset so colours do not bleed into later prints.
init(autoreset=True)
# Configuration paths. Both are relative, so they resolve against the process
# working directory rather than against the location of this file.
config_paths = {
    "font": "../configs/font.yaml",
    "output_folder": "../outputs/",
}
# Load environment variables (python-dotenv) before any of them are read below.
load_dotenv()
app = FastAPI(title="AI海报生成系统API", version="1.0.0")
# Origins allowed to call the API from a browser. Set CORS_ALLOW_ORIGINS to a
# comma-separated list (e.g. "https://a.example,https://b.example") to restrict
# access; when it is unset or empty the previous wildcard behaviour is kept.
# NOTE(review): the FastAPI CORS documentation advises explicit origins whenever
# allow_credentials is True, so production deployments should set the variable.
cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
# CORS middleware so that the front end can call the API cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request models
class PosterRequest(BaseModel):
    """Request body of the one-shot poster generation endpoint."""

    # Free-text description of the poster the user wants.
    user_input: str
    # Optional session ID used to correlate requests from the same user.
    # Declared Optional so that an explicit null is accepted as well as omission.
    session_id: Optional[str] = None
class GenerateVueRequest(BaseModel):
    """Request body carrying the user's text and an optional session ID.

    NOTE(review): no route visible in this file consumes this model.
    """

    user_input: str
    # Declared Optional so that an explicit null is accepted as well as omission.
    session_id: Optional[str] = None
class GeneratePSDRequest(BaseModel):
    """Request body carrying the user's text, an optional session ID and a PSD flag.

    NOTE(review): no route visible in this file consumes this model.
    """

    user_input: str
    # Declared Optional so that an explicit null is accepted as well as omission.
    session_id: Optional[str] = None
    # Whether to use a manually created PSD file.
    use_manual_psd: bool = False
# Response model
class ApiResponse(BaseModel):
    """Envelope returned by the poster generation endpoint."""

    # "success" or "error", as set by generate_poster_api.
    status: str
    # Human-readable outcome message.
    message: str
    # Result payload; the error path passes None explicitly, so the field must
    # be nullable for the error response itself to validate.
    data: Optional[dict] = None
    # Session the response belongs to, when one could be determined.
    session_id: Optional[str] = None
# Module-level dictionary intended to hold per-session data.
sessions = {}
# Load the font configuration; any failure falls back to an empty mapping so
# that start-up never aborts because of a missing or broken font file.
try:
    with open(config_paths["font"], "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; normalise that to {}
        # so fonts_config has the same type as the failure fallback.
        fonts_config = yaml.safe_load(f) or {}
    print(f"{Fore.GREEN}✅ 字体配置加载成功{Style.RESET_ALL}")
except Exception as e:
    print(f"{Fore.YELLOW}⚠️ 字体配置加载失败: {e},使用默认配置{Style.RESET_ALL}")
    fonts_config = {}
# Helper functions
def print_step(step_num, description, status="进行中"):
    """Print a colour-coded progress line for one pipeline step.

    Args:
        step_num: Step label shown after "步骤".
        description: Short description of the step.
        status: One of "进行中", "完成" or "错误"; any other value prints nothing.
    """
    # status -> (colour, icon, text appended after the description)
    styles = {
        "进行中": (Fore.BLUE, "📋", "..."),
        "完成": (Fore.GREEN, "✅", " - 完成"),
        "错误": (Fore.RED, "❌", " - 出错"),
    }
    style = styles.get(status)
    if style is None:
        return
    colour, icon, suffix = style
    print(f"{colour}{icon} 步骤{step_num}: {description}{suffix}{Style.RESET_ALL}")
def print_result(key, value):
    """Print a single ``key: value`` result line in cyan."""
    line = f"{Fore.CYAN}📊 {key}: {value}{Style.RESET_ALL}"
    print(line)
def get_session_folder(session_id):
    """Return the output folder dedicated to a session, creating it if needed.

    Args:
        session_id: Caller-supplied session identifier. When falsy, a new
            UUID4 string is generated.

    Returns:
        Tuple ``(session_folder, session_id)`` where ``session_folder`` is
        ``<output_folder>/session_<session_id>``.

    Raises:
        ValueError: If ``session_id`` contains a path separator or a NUL
            character. The ID comes from API clients and is used as part of a
            filesystem path, so such values could escape the output folder.
    """
    if not session_id:
        session_id = str(uuid.uuid4())
    elif any(ch in session_id for ch in ("/", "\\", "\x00")):
        raise ValueError(f"非法的会话ID: {session_id!r}")
    session_folder = os.path.join(config_paths["output_folder"], f"session_{session_id}")
    os.makedirs(session_folder, exist_ok=True)
    return session_folder, session_id
# Writes the temporary prompts.yaml file
def generate_prompts_yaml(user_input=None):
    """Write the temporary ``temp_prompts.yaml`` config file and return its path.

    The written content is a fixed set of fonts, colours, style rules and logo
    rules.

    Args:
        user_input: Accepted for interface compatibility; it does not influence
            the generated file.

    Returns:
        Path of the written file inside the configured output folder.
    """
    prompts_config = {
        "default_logo_text": "",
        "available_fonts": [
            {
                "name": "Microsoft YaHei",
                "displayName": "微软雅黑",
                "tags": ["现代", "清晰"],
                "roles": ["title", "subtitle", "content"]
            },
            {
                "name": "SimHei",
                "displayName": "黑体",
                "tags": ["通用", "标准"],
                "roles": ["title", "subtitle", "content"]
            }
        ],
        "NAMING_COLORS": {
            "primary": "#1976D2",
            "secondary": "#424242",
            "accent": "#FF5722"
        },
        "STYLE_RULES": {
            "modern": {
                "primary_font": "Microsoft YaHei",
                "secondary_font": "SimHei"
            }
        },
        "LOGO_RULES": {
            "default_position": "bottom",
            "fallback_text": "活动主办方"
        }
    }
    # Save to a temporary file, creating the output folder on first use.
    temp_prompts_path = os.path.join(config_paths["output_folder"], "temp_prompts.yaml")
    os.makedirs(os.path.dirname(temp_prompts_path), exist_ok=True)
    with open(temp_prompts_path, 'w', encoding='utf-8') as f:
        yaml.dump(prompts_config, f, allow_unicode=True, default_flow_style=False)
    print(f"{Fore.GREEN}✅ 临时prompts.yaml已生成: {temp_prompts_path}{Style.RESET_ALL}")
    return temp_prompts_path
# PSD composition entry point - currently uses a fixed, temporary image list
def create_psd_from_images_wrapper(img_list, vue_layout_path, output_path):
    """Compose a PSD from a fixed set of images found in the output folder.

    Temporary interface: the images are always ``background.png``,
    ``lotus.jpg``, ``nankai.png`` and ``stamp.jpg`` from the configured output
    folder; files that do not exist are skipped.

    Args:
        img_list: Currently unused; kept for interface compatibility.
        vue_layout_path: Currently unused; kept for interface compatibility.
        output_path: Where the PSD file is written.

    Returns:
        ``output_path`` when the PSD file exists after creation, otherwise
        ``None`` (no source image found, creation raised, or no file produced).
    """
    print(f"{Fore.CYAN}🎨 开始合成PSD文件(临时接口)...{Style.RESET_ALL}")
    try:
        # Fixed list of image files, looked up in the output folder.
        fixed_image_files = ["background.png", "lotus.jpg", "nankai.png", "stamp.jpg"]
        image_paths = []
        for img_file in fixed_image_files:
            img_path = os.path.join(config_paths["output_folder"], img_file)
            if os.path.exists(img_path):
                image_paths.append(img_path)
                print(f"{Fore.GREEN}✓ 找到图片: {img_file}{Style.RESET_ALL}")
            else:
                print(f"{Fore.YELLOW}⚠️ 图片不存在: {img_file} (路径: {img_path}){Style.RESET_ALL}")
        if not image_paths:
            print(f"{Fore.RED}❌ 没有找到任何指定的图片文件{Style.RESET_ALL}")
            return None
        print(f"{Fore.CYAN}📋 将合并以下图片 (共{len(image_paths)}张):")
        for position, path in enumerate(image_paths, start=1):
            print(f" {position}. {os.path.basename(path)}")
        # Make sure the output directory exists. A bare file name has an empty
        # dirname, and os.makedirs("") would raise, so skip it in that case.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        # Delegate the actual PSD creation to export_psd_from_json.
        create_psd_impl(
            image_paths=image_paths,
            output_path=output_path,
            canvas_size=(1080, 1920),  # standard poster size
            mode='RGB'
        )
        # Only report success once the file is confirmed to be on disk.
        if not os.path.exists(output_path):
            print(f"{Fore.RED}❌ PSD文件创建失败: 未生成输出文件 {output_path}{Style.RESET_ALL}")
            return None
        print(f"{Fore.GREEN}✅ PSD文件创建成功: {output_path}{Style.RESET_ALL}")
        file_size = os.path.getsize(output_path) / (1024 * 1024)  # bytes -> MB
        print(f"{Fore.CYAN}📁 PSD文件大小: {file_size:.2f} MB{Style.RESET_ALL}")
        return output_path
    except Exception as e:
        # Best-effort step: report the failure and let the caller carry on.
        print(f"{Fore.RED}❌ PSD文件创建失败: {str(e)}{Style.RESET_ALL}")
        import traceback
        traceback.print_exc()
        return None
# Vue layout generation that always carries the poster copy
def generate_layout_prompt(user_input_analysis_result, parse_imglist, suggestions=None):
    """Ask DeepSeek for a Vue poster layout that embeds images and copy.

    Args:
        user_input_analysis_result: Mapping with required ``"width"`` and
            ``"height"`` keys and an optional ``"main_theme"``.
        parse_imglist: Iterable of mappings, each with ``"picture_name"`` and
            ``"picture_description"``. ``None`` is treated as no images.
        suggestions: Optional mapping of copy suggestions. The layers read are
            ``layer6_title_content`` and ``layer7_subtitle_content`` (field
            ``content``) and ``layer5_logo_content`` (field ``text``).

    Returns:
        The first element of the ``call_deepseek`` result, or the output of
        ``generate_fallback_vue_code`` when that call raises.

    Raises:
        KeyError: If ``width``/``height`` or a picture field is missing.
    """
    width = user_input_analysis_result["width"]
    height = user_input_analysis_result["height"]
    theme = user_input_analysis_result.get("main_theme", "活动海报")
    # One bullet line per image.
    images_info = "\n".join(
        f"- {img['picture_name']} ({img['picture_description']})"
        for img in (parse_imglist or [])
    )
    # Copy used when the suggestions are unusable or provide nothing, so the
    # prompt never asks for text to be shown while listing no text at all.
    default_content = f"- 主标题: {theme}\n- 副标题: 精彩活动,敬请参与\n"
    # (suggestion layer, field inside the layer, label in the prompt, default)
    copy_fields = (
        ("layer6_title_content", "content", "主标题", theme),
        ("layer7_subtitle_content", "content", "副标题", "精彩活动,敬请参与"),
        ("layer5_logo_content", "text", "Logo文字", "主办方"),
    )
    content_info = ""
    if suggestions:
        try:
            for layer_key, field, label, default in copy_fields:
                if layer_key in suggestions:
                    value = suggestions[layer_key].get(field, default)
                    content_info += f"- {label}: {value}\n"
        except Exception as e:
            # Malformed suggestions (e.g. a layer that is not a mapping).
            print(f"{Fore.YELLOW}⚠️ 文案信息解析错误: {e}{Style.RESET_ALL}")
            content_info = default_content
    if not content_info:
        content_info = default_content
    # Ask DeepSeek for the layout.
    system_prompt = "你是一个擅长前端开发的AI专注于生成Vue.js代码。请根据提供的信息生成完整的Vue组件包含所有必要的HTML结构和基础定位样式。"
    prompt = f"""
请生成一个Vue.js组件代码用于{theme}海报,要求如下:
组件尺寸: {width}x{height}px
图片资源:
{images_info}
文案内容:
{content_info}
布局要求:
1. 背景图层: 使用第一张图片作为背景,占据整个组件区域
2. 主标题: 位于画布上方1/3处居中显示
3. 副标题: 位于主标题下方,居中显示
4. 内容区域: 使用剩余图片,合理布局在中间区域
5. Logo区域: 位于底部,居中显示
技术要求:
- 使用Vue 3 Composition API
- 使用absolute定位进行精确布局
- 包含完整的template、script和style部分
- 确保所有文本内容都正确显示
- 图片使用相对路径引用
请生成完整可用的Vue组件代码不要包含任何说明文字。
"""
    try:
        result, _ = call_deepseek(prompt=prompt, system_prompt=system_prompt, temperature=0.4)
        return result
    except Exception as e:
        print(f"{Fore.RED}❌ 布局提示生成失败: {e}{Style.RESET_ALL}")
        return generate_fallback_vue_code(theme, width, height)
def generate_fallback_vue_code(theme, width=1080, height=1920):
    """Return a static fallback Vue single-file component for the poster.

    Args:
        theme: Text shown as the main title. It is HTML-escaped before being
            placed in the template, because it originates from user input and
            markup characters in it would otherwise corrupt the component.
        width: Component width in pixels.
        height: Component height in pixels.

    Returns:
        The component source (template, script setup and scoped style) as a
        string.
    """
    import html

    safe_theme = html.escape(str(theme))
    return f"""<template>
<div class="poster-container" :style="containerStyle">
<div class="background-layer">
<img src="../outputs/background.png" alt="背景" class="background-image" />
</div>
<div class="content-layer">
<div class="title-section">
<h1 class="main-title">{safe_theme}</h1>
<h2 class="subtitle">精彩活动,敬请参与</h2>
</div>
<div class="main-content">
<div class="image-gallery">
<img src="../outputs/image1.png" alt="活动图片" class="content-image" />
</div>
</div>
<div class="footer-section">
<div class="logo-area">
<span class="logo-text">主办方</span>
</div>
</div>
</div>
</div>
</template>
<script setup>
import {{ computed }} from 'vue'
const containerStyle = computed(() => ({{
width: '{width}px',
height: '{height}px',
position: 'relative',
overflow: 'hidden'
}}))
</script>
<style scoped>
.poster-container {{
position: relative;
background: #ffffff;
}}
.background-layer {{
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: 1;
}}
.background-image {{
width: 100%;
height: 100%;
object-fit: cover;
}}
.content-layer {{
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: 2;
}}
.title-section {{
position: absolute;
top: 20%;
left: 50%;
transform: translateX(-50%);
text-align: center;
}}
.main-title {{
font-size: 48px;
font-weight: bold;
margin-bottom: 20px;
color: #333;
}}
.subtitle {{
font-size: 24px;
color: #666;
margin-bottom: 40px;
}}
.main-content {{
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}}
.content-image {{
max-width: 400px;
max-height: 300px;
object-fit: cover;
}}
.footer-section {{
position: absolute;
bottom: 10%;
left: 50%;
transform: translateX(-50%);
}}
.logo-text {{
font-size: 18px;
color: #666;
}}
</style>"""
# One-shot pipeline execution
def run_pipeline(user_input=None):
    """Run the whole poster generation flow and write artifacts to the output folder.

    Steps: write the temporary prompts config, analyse the user input, obtain
    image information, generate copy suggestions (saved as
    ``poster_content.json``), generate the Vue layout (saved as
    ``generated_code.vue``) and compose ``final_poster.psd``.

    Args:
        user_input: The user's poster request text.

    Returns:
        Path of the generated PSD when it was created, otherwise the path the
        Vue layout was saved to. ``None`` when any step raised.
    """
    try:
        print(f"{Fore.MAGENTA}{'=' * 50}")
        print(f"{Fore.MAGENTA}🎨 海报生成流程启动 🎨")
        print(f"{'=' * 50}{Style.RESET_ALL}")
        print_step(1, "加载配置文件")
        prompts_yaml_path = generate_prompts_yaml(user_input)
        load_config_from_file(prompts_yaml_path)
        print_step(1, "加载配置文件", "完成")
        print_step(2, "分析用户输入")
        user_input_analysis_result = llm_user_analysis(user_input)
        print_result("分析结果", user_input_analysis_result.get('main_theme', '未知'))
        print_step(2, "分析用户输入", "完成")
        print_step(3, "生成图片信息")
        system_prompt = user_input_analysis_result["analyzed_prompt"]
        parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)
        print_result("生成图片数量", len(parse_imglist))
        print_step(3, "生成图片信息", "完成")
        print_step(4, "生成文案建议")
        suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"])
        print(f"{Fore.CYAN}文案生成结果:")
        print(json.dumps(suggestions, indent=2, ensure_ascii=False))
        # Persist the copy suggestions.
        suggestions_path = os.path.join(config_paths["output_folder"], "poster_content.json")
        with open(suggestions_path, "w", encoding="utf-8") as f:
            json.dump(suggestions, f, indent=2, ensure_ascii=False)
        print_step(4, "生成文案建议", "完成")
        print_step(5, "生成Vue排版")
        dynamic_prompt = generate_layout_prompt(user_input_analysis_result, parse_imglist, suggestions)
        vue_code = generate_vue_code(dynamic_prompt)
        vue_path = os.path.join(config_paths["output_folder"], "generated_code.vue")
        save_code(vue_code, file_path=vue_path)
        # Verify that the Vue file really was written, and report the step
        # status accordingly instead of always claiming completion.
        if os.path.exists(vue_path):
            print(f"{Fore.GREEN}✅ Vue文件已生成: {vue_path}{Style.RESET_ALL}")
            # Show the beginning of the file as a quick sanity check.
            with open(vue_path, 'r', encoding='utf-8') as f:
                preview = f.read()[:500]
            print(f"{Fore.CYAN}Vue代码预览:\n{preview}...{Style.RESET_ALL}")
            print_step(5, "生成Vue排版", "完成")
        else:
            print(f"{Fore.RED}❌ Vue文件生成失败{Style.RESET_ALL}")
            print_step(5, "生成Vue排版", "错误")
        print_step(6, "合成PSD文件")
        img_list = [(pic["picture_name"], pic["picture_type"]) for pic in parse_imglist]
        psd_path = os.path.join(config_paths["output_folder"], "final_poster.psd")
        result_path = create_psd_from_images_wrapper(
            img_list=img_list,
            vue_layout_path=vue_path,
            output_path=psd_path
        )
        if result_path:
            print_step(6, "合成PSD文件", "完成")
        else:
            print_step(6, "合成PSD文件", "错误")
        print(f"\n{Fore.GREEN}{'=' * 50}")
        print("✅ 流程执行完成!")
        print(f"{'=' * 50}{Style.RESET_ALL}")
        # Return an artifact this pipeline actually produces. The value stays
        # truthy whenever no step raised, so callers that only test the result
        # for success behave as before.
        return result_path or vue_path
    except Exception as e:
        print_step("X", "Pipeline执行", "错误")
        print(f"{Fore.RED}错误详情: {str(e)}{Style.RESET_ALL}")
        import traceback
        traceback.print_exc()
        return None
# Local (console) runner
def run_local_pipeline(user_input=None):
    """Run the pipeline locally and print where the results were written.

    Args:
        user_input: The user's poster request text, forwarded to run_pipeline.
    """
    print(f"{Fore.CYAN}🎬 Starting local pipeline with input: {Style.BRIGHT}{user_input}{Style.RESET_ALL}")
    if not run_pipeline(user_input):
        print(f"{Fore.RED}❌ Pipeline执行失败{Style.RESET_ALL}")
        return
    print(f"{Fore.GREEN}🎊 Pipeline completed successfully!")
    print(f"{Fore.YELLOW}📁 Results saved to:")
    # (label, file name) of every artifact listed in the summary.
    artifacts = (
        ("Vue layout", "generated_code.vue"),
        ("PSD file", "final_poster.psd"),
        ("Content JSON", "poster_content.json"),
    )
    for label, file_name in artifacts:
        print(f" - {label}: {os.path.join(config_paths['output_folder'], file_name)}")
    print(f"{Fore.CYAN}💡 Check the outputs/ directory for generated files.{Style.RESET_ALL}")
# API routes
@app.get("/")
def read_root():
    """Describe the service: its name, version and the main endpoint paths."""
    endpoints = {
        "generate_poster": "/api/generate-poster",  # main endpoint
        "download": "/api/download/{file_type}",
        "health": "/health",
        "status": "/api/status/{session_id}",
    }
    service_info = {
        "message": "AI海报生成系统API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
    return service_info
@app.get("/health")
def health_check():
    """Health check: report a healthy status with the current timestamp."""
    timestamp = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": timestamp}
@app.post("/api/generate-poster", response_model=ApiResponse)
async def generate_poster_api(request: PosterRequest):
    """Main API endpoint: generate a complete poster (Vue code and PSD file).

    All artifacts are written to the session folder returned by
    ``get_session_folder``.

    Args:
        request: Carries the user's text and an optional session ID.

    Returns:
        ``ApiResponse`` with ``status="success"`` and the artifact paths,
        analysis result, image information and copy suggestions in ``data``;
        or ``status="error"`` with the failure message when any step raised.

    NOTE(review): every step below is a synchronous call made inside an
    ``async def`` handler, so the event loop is blocked while a poster is being
    generated. Consider a plain ``def`` handler or a worker thread once the
    helper modules are confirmed safe to run concurrently.
    """
    # Bound before anything can fail so the error response can always echo it
    # (replaces the former `'session_id' in locals()` check).
    session_id = request.session_id
    try:
        session_folder, session_id = get_session_folder(session_id)
        print(f"{Fore.BLUE}🎨 开始生成海报...{Style.RESET_ALL}")
        print(f"{Fore.CYAN}用户输入: {request.user_input}{Style.RESET_ALL}")
        # === Step 1: write and load the prompts config ===
        print(f"{Fore.BLUE}📋 步骤1: 生成配置文件{Style.RESET_ALL}")
        temp_prompts_path = os.path.join(session_folder, "temp_prompts.yaml")
        prompts_config = {
            "default_logo_text": "",
            "available_fonts": [
                {
                    "name": "Microsoft YaHei",
                    "displayName": "微软雅黑",
                    "tags": ["现代", "清晰"],
                    "roles": ["title", "subtitle", "content"]
                },
                {
                    "name": "SimHei",
                    "displayName": "黑体",
                    "tags": ["通用", "标准"],
                    "roles": ["title", "subtitle", "content"]
                }
            ]
        }
        with open(temp_prompts_path, 'w', encoding='utf-8') as f:
            yaml.dump(prompts_config, f, allow_unicode=True, default_flow_style=False)
        load_config_from_file(temp_prompts_path)
        # === Step 2: analyse the user input ===
        print(f"{Fore.BLUE}📋 步骤2: 分析用户输入{Style.RESET_ALL}")
        user_input_analysis_result = llm_user_analysis(request.user_input)
        # === Step 3: image information ===
        print(f"{Fore.BLUE}📋 步骤3: 生成图片信息{Style.RESET_ALL}")
        system_prompt = user_input_analysis_result["analyzed_prompt"]
        parse_imglist = comfyui_img_info(user_input_analysis_result, system_prompt)
        # === Step 4: copy suggestions ===
        print(f"{Fore.BLUE}📋 步骤4: 生成文案建议{Style.RESET_ALL}")
        suggestions = get_poster_content_suggestions(user_input_analysis_result["analyzed_prompt"])
        # Persist the copy suggestions in the session folder.
        suggestions_path = os.path.join(session_folder, "poster_content.json")
        with open(suggestions_path, "w", encoding="utf-8") as f:
            json.dump(suggestions, f, indent=2, ensure_ascii=False)
        # === Step 5: Vue layout ===
        print(f"{Fore.BLUE}📋 步骤5: 生成Vue排版{Style.RESET_ALL}")
        dynamic_prompt = generate_layout_prompt(user_input_analysis_result, parse_imglist, suggestions)
        vue_code = generate_vue_code(dynamic_prompt)
        vue_path = os.path.join(session_folder, "generated_code.vue")
        save_code(vue_code, file_path=vue_path)
        # === Step 6: PSD composition ===
        print(f"{Fore.BLUE}📋 步骤6: 合成PSD文件{Style.RESET_ALL}")
        img_list = [(pic["picture_name"], pic["picture_type"]) for pic in parse_imglist]
        psd_path = os.path.join(session_folder, "final_poster.psd")
        # Reuse the shared wrapper instead of duplicating its logic here. It
        # returns None when no source image exists or creation fails, and
        # never raises, so a PSD problem does not fail the whole request.
        psd_created = create_psd_from_images_wrapper(
            img_list=img_list,
            vue_layout_path=vue_path,
            output_path=psd_path
        ) is not None
        # Build the API response.
        return ApiResponse(
            status="success",
            message="海报生成完成",
            data={
                "vue_file": vue_path if os.path.exists(vue_path) else None,
                "psd_file": psd_path if psd_created else None,
                "content_file": suggestions_path,
                "analysis_result": user_input_analysis_result,
                "images_info": parse_imglist,
                "suggestions": suggestions
            },
            session_id=session_id
        )
    except Exception as e:
        print(f"{Fore.RED}❌ API错误: {str(e)}{Style.RESET_ALL}")
        import traceback
        traceback.print_exc()
        return ApiResponse(
            status="error",
            message=f"生成失败: {str(e)}",
            data=None,
            session_id=session_id
        )
@app.get("/api/download/{file_type}")
async def download_file(file_type: str, session_id: str = None):
    """Download one of the generated files.

    Args:
        file_type: ``"vue"``, ``"psd"`` or ``"json"``.
        session_id: Optional session whose folder is searched; without it the
            shared output folder is used.

    Returns:
        The file as a ``FileResponse``; a JSON error with status 400 for an
        unsupported file type or an invalid session ID, 404 when the file does
        not exist, and 500 on unexpected errors.
    """
    # file_type -> (file name inside the folder, media type)
    downloadable = {
        "vue": ("generated_code.vue", "text/plain"),
        "psd": ("final_poster.psd", "application/octet-stream"),
        "json": ("poster_content.json", "application/json"),
    }
    try:
        if file_type not in downloadable:
            return JSONResponse(
                status_code=400,
                content={"error": "不支持的文件类型"}
            )
        # session_id comes straight from the query string and becomes part of
        # a filesystem path, so refuse anything that could leave the output
        # folder.
        if session_id and any(ch in session_id for ch in ("/", "\\", "\x00")):
            return JSONResponse(
                status_code=400,
                content={"error": "非法的会话ID"}
            )
        if session_id:
            session_folder = os.path.join(config_paths["output_folder"], f"session_{session_id}")
        else:
            session_folder = config_paths["output_folder"]
        file_name, media_type = downloadable[file_type]
        file_path = os.path.join(session_folder, file_name)
        # isfile (not exists): a directory with that name cannot be served.
        if not os.path.isfile(file_path):
            return JSONResponse(
                status_code=404,
                content={"error": "文件不存在"}
            )
        return FileResponse(
            path=file_path,
            media_type=media_type,
            filename=os.path.basename(file_path)
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"下载失败: {str(e)}"}
        )
@app.get("/api/status/{session_id}")
async def get_session_status(session_id: str):
    """Report which generated files exist for a session.

    Args:
        session_id: Session whose folder is inspected.

    Returns:
        A dict with the session ID, a per-file existence map and the session
        folder path; a JSON error with status 404 when the session folder does
        not exist, and 500 on unexpected errors.
    """
    # response key -> file name inside the session folder
    tracked_files = (
        ("vue_file", "generated_code.vue"),
        ("psd_file", "final_poster.psd"),
        ("content_file", "poster_content.json"),
    )
    try:
        session_folder = os.path.join(config_paths["output_folder"], f"session_{session_id}")
        if os.path.exists(session_folder):
            files_status = {
                key: os.path.exists(os.path.join(session_folder, file_name))
                for key, file_name in tracked_files
            }
            return {
                "session_id": session_id,
                "files": files_status,
                "folder": session_folder
            }
        return JSONResponse(
            status_code=404,
            content={"error": "会话不存在"}
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"状态查询失败: {str(e)}"}
        )
if __name__ == "__main__":
    import uvicorn

    # Ask which mode to run: "1" runs the pipeline once locally, anything else
    # starts the API server.
    print(f"{Fore.BLUE}🔧 运行模式选择:{Style.RESET_ALL}")
    print(f"{Fore.YELLOW}1. 本地测试模式")
    print(f"2. API服务器模式{Style.RESET_ALL}")
    choice = input("请选择运行模式 (1/2): ").strip()
    if choice != "1":
        # API server mode
        print(f"{Fore.GREEN}🚀 启动API服务器...{Style.RESET_ALL}")
        uvicorn.run(app, host="0.0.0.0", port=8000)
    else:
        # Local test mode; an empty answer falls back to the default request.
        test_input = input("请输入海报需求 (留空使用默认): ").strip() or "端午节安康,"
        run_local_pipeline(test_input)