ai_service/scripts/utils.py
Wang Xiuqiang 20802db28a 重构run_pipeline和generate_layout
将辅助函数移动整合到utils.py当中
2025-07-03 11:47:48 +08:00

267 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
工具函数模块
整合所有辅助函数和公共功能
"""
import os
import uuid
import json
import yaml
from datetime import datetime
from dotenv import load_dotenv
from colorama import init, Fore, Style
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# 初始化colorama
init(autoreset=True)
# 加载环境变量
load_dotenv()
# 配置路径
CONFIG_PATHS = {
"font": "../configs/font.yaml",
"output_folder": "../outputs/",
"workflows": "../workflows/",
"templates": "../configs/templates.yaml"
}
def print_step(step_num: int, description: str, status: str = "进行中"):
"""打印带颜色的步骤信息"""
if status == "进行中":
print(f"{Fore.BLUE}📋 步骤{step_num}: {description}...{Style.RESET_ALL}")
elif status == "完成":
print(f"{Fore.GREEN}✅ 步骤{step_num}: {description} - 完成{Style.RESET_ALL}")
elif status == "错误":
print(f"{Fore.RED}❌ 步骤{step_num}: {description} - 出错{Style.RESET_ALL}")
def print_result(key: str, value: str):
"""打印结果信息"""
print(f"{Fore.CYAN}📊 {key}: {value}{Style.RESET_ALL}")
def get_session_folder(session_id: Optional[str] = None) -> Tuple[str, str]:
"""获取会话专用的输出文件夹"""
if not session_id:
session_id = str(uuid.uuid4())
session_folder = os.path.join(CONFIG_PATHS["output_folder"], f"session_{session_id}")
os.makedirs(session_folder, exist_ok=True)
return session_folder, session_id
def load_config_file(config_type: str) -> Dict:
"""加载配置文件"""
config_path = CONFIG_PATHS.get(config_type)
if not config_path or not os.path.exists(config_path):
print(f"{Fore.YELLOW}⚠️ 配置文件不存在: {config_path},使用默认配置{Style.RESET_ALL}")
return {}
try:
with open(config_path, "r", encoding="utf-8") as f:
if config_path.endswith('.yaml') or config_path.endswith('.yml'):
config = yaml.safe_load(f)
else:
config = json.load(f)
print(f"{Fore.GREEN}✅ 配置文件加载成功: {config_path}{Style.RESET_ALL}")
return config
except Exception as e:
print(f"{Fore.RED}❌ 配置文件加载失败: {e}{Style.RESET_ALL}")
return {}
def save_json_file(data: Dict, file_path: str):
"""保存JSON文件"""
try:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
print(f"{Fore.GREEN}✅ JSON文件已保存: {file_path}{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}❌ JSON文件保存失败: {e}{Style.RESET_ALL}")
def save_vue_file(vue_code: str, file_path: str):
"""保存Vue文件"""
try:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
f.write(vue_code)
if os.path.exists(file_path):
file_size = os.path.getsize(file_path)
print(f"{Fore.GREEN}✅ Vue文件已保存: {file_path} ({file_size} 字节){Style.RESET_ALL}")
else:
print(f"{Fore.RED}❌ Vue文件保存失败{Style.RESET_ALL}")
except Exception as e:
print(f"{Fore.RED}❌ Vue文件保存失败: {e}{Style.RESET_ALL}")
def create_temp_config(user_input: str = None) -> str:
"""动态生成临时配置文件"""
if not user_input:
user_input = "默认海报配置"
temp_config = {
"default_logo_text": "",
"available_fonts": [
{
"name": "Microsoft YaHei",
"displayName": "微软雅黑",
"tags": ["现代", "清晰"],
"roles": ["title", "subtitle", "content"]
},
{
"name": "SimHei",
"displayName": "黑体",
"tags": ["通用", "标准"],
"roles": ["title", "subtitle", "content"]
}
],
"vue_templates": {
"lotus.jpg": {
"theme": "荷花主题",
"style": "传统优雅",
"template_name": "lotus_template"
},
"nku.png": {
"theme": "南开大学",
"style": "学术正式",
"template_name": "nku_template"
},
"stamp.jpg": {
"theme": "印章装饰",
"style": "传统文化",
"template_name": "stamp_template"
},
"background.png": {
"theme": "通用背景",
"style": "简约现代",
"template_name": "background_template"
}
}
}
temp_config_path = os.path.join(CONFIG_PATHS["output_folder"], "temp_config.yaml")
with open(temp_config_path, 'w', encoding='utf-8') as f:
yaml.dump(temp_config, f, allow_unicode=True, default_flow_style=False)
print(f"{Fore.GREEN}✅ 临时配置文件已生成: {temp_config_path}{Style.RESET_ALL}")
return temp_config_path
# === 用户输入分析模块 (从 prompt_analysis.py 整合) ===
from generate_layout import call_deepseek
def llm_user_analysis(user_input: str) -> Dict:
"""
使用DeepSeek动态分析用户输入提取海报设计参数
"""
if not user_input:
user_input = "端午节海报,包含背景、活动亮点和图标"
print(f"{Fore.CYAN}🔍 正在分析用户输入: {Style.BRIGHT}{user_input}{Style.RESET_ALL}")
# 构建分析提示词
analysis_prompt = f"""
请分析以下用户输入的海报需求提取关键信息并以JSON格式返回
用户输入:{user_input}
请严格按照以下JSON格式返回
{{
"analyzed_prompt": "原始用户输入",
"keywords": ["提取的关键词1", "关键词2", "关键词3"],
"width": 1080,
"height": 1920,
"batch_size": 2,
"poster_type": "海报类型(如:节日海报、活动海报、产品海报等)",
"main_theme": "主要主题",
"style_preference": "风格偏好(如:现代、传统、简约等)",
"color_preference": "颜色偏好(如:暖色调、冷色调、单色等)"
}}
注意:
1. keywords应该包含3-5个最重要的关键词
2. 根据用户输入推断合适的海报类型
3. 尺寸默认为1080x1920除非用户明确指定
4. batch_size根据需求调整通常为1-4
5. 分析用户的风格和颜色偏好
"""
system_prompt = "你是一个专业的设计需求分析师擅长从用户描述中提取海报设计的关键参数。请严格按照JSON格式返回结果确保输出的JSON格式正确且完整。"
try:
result, _ = call_deepseek(prompt=analysis_prompt, system_prompt=system_prompt, temperature=0.3)
# 解析JSON
json_str = result.strip()
if "```json" in json_str:
json_str = json_str.split("```json")[1].split("```")[0].strip()
elif json_str.startswith("```") and json_str.endswith("```"):
json_str = json_str[3:-3].strip()
analysis_result = json.loads(json_str)
print(f"{Fore.GREEN}✅ 分析完成! {Style.RESET_ALL}")
print(f"{Fore.YELLOW}📊 主题: {Style.BRIGHT}{analysis_result.get('main_theme', '未知')}{Style.RESET_ALL}")
print(f"{Fore.YELLOW}🎨 风格: {analysis_result.get('style_preference', '未设置')}")
print(f"{Fore.YELLOW}🔖 关键词: {', '.join(analysis_result.get('keywords', []))}{Style.RESET_ALL}")
return analysis_result
except Exception as e:
print(f"{Fore.RED}❌ 分析失败: {str(e)}{Style.RESET_ALL}")
# 返回默认值
return {
"analyzed_prompt": user_input,
"keywords": ["海报", "设计", "活动"],
"width": 1080,
"height": 1920,
"batch_size": 2,
"poster_type": "通用海报",
"main_theme": "默认主题",
"style_preference": "现代",
"color_preference": "暖色调"
}
def validate_file_exists(file_path: str) -> bool:
"""验证文件是否存在"""
exists = os.path.exists(file_path)
if exists:
print(f"{Fore.GREEN}✓ 文件存在: {os.path.basename(file_path)}{Style.RESET_ALL}")
else:
print(f"{Fore.YELLOW}⚠️ 文件不存在: {os.path.basename(file_path)}{Style.RESET_ALL}")
return exists
def get_file_info(file_path: str) -> Dict:
"""获取文件信息"""
if not os.path.exists(file_path):
return {"exists": False}
stat = os.stat(file_path)
return {
"exists": True,
"size": stat.st_size,
"size_mb": round(stat.st_size / (1024 * 1024), 2),
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat()
}
if __name__ == "__main__":
# 测试工具函数
print(f"{Fore.MAGENTA}🧪 测试工具函数{Style.RESET_ALL}")
# 测试用户输入分析
test_input = "春节海报,红色背景,现代风格"
result = llm_user_analysis(test_input)
print(f"\n{Fore.GREEN}分析结果:")
print(json.dumps(result, indent=2, ensure_ascii=False))