# Listing metadata: 282 lines, 9.6 KiB, Python
import json
import os
import random
import sys
import time
import urllib.parse
import urllib.request
import uuid
from datetime import datetime

from websocket import (
    WebSocketConnectionClosedException,
    WebSocketTimeoutException,
    create_connection,
)


def comfyui_img_info(user_input_analysis_result, system_prompt):
    """Generate images for a prompt and describe every saved file.

    Args:
        user_input_analysis_result: Dict holding the analysed user input. The
            optional keys ``width``, ``height`` and ``batch_size`` are read and
            default to 1024, 768 and 1.
        system_prompt: Prompt text written into the workflow and echoed back
            as each picture's description.

    Returns:
        list[dict]: One entry per saved image with the keys ``picture_name``
        (file name without extension), ``picture_type`` (always ``"png"``),
        ``picture_description`` (the prompt) and ``picture_size``
        (``"<width>x<height>"`` built from the requested size). The list is
        empty when no image was saved.

    Raises:
        SystemExit: Propagated from ``preprocess_workflow`` when the workflow
            file cannot be prepared.
    """
    # Image parameters requested by the caller.
    width = user_input_analysis_result.get('width', 1024)
    height = user_input_analysis_result.get('height', 768)
    batch_size = user_input_analysis_result.get('batch_size', 1)

    # Fixed configuration.
    working_dir = 'output'
    comfyui_endpoint = '127.0.0.1:8188'
    default_workflow = './workflows/flux_work.json'
    temp_workflow_dir = './workflows/temp'

    # Every call gets its own uniquely named scratch copy of the workflow.
    os.makedirs(temp_workflow_dir, exist_ok=True)
    processed_workflow = os.path.join(
        temp_workflow_dir, f"processed_workflow_{uuid.uuid4().hex}.json"
    )

    try:
        # 1. Write the prompt and image parameters into the scratch workflow.
        workflow_file = preprocess_workflow(
            system_prompt=system_prompt,
            width=width,
            height=height,
            batch_size=batch_size,
            input_json=default_workflow,
            output_json=processed_workflow,
        )

        # 2. Make sure the output directory exists, then run the workflow.
        os.makedirs(working_dir, exist_ok=True)
        saved_files = generate_images(
            workflow_file=workflow_file,
            server_address=comfyui_endpoint,
            output_dir=working_dir,
            client_id=str(uuid.uuid4()),
        )
    finally:
        # The scratch workflow is only needed while generating; remove it so
        # the temp directory does not grow with every call.
        try:
            os.remove(processed_workflow)
        except OSError:
            # Best-effort cleanup: the file may never have been written.
            pass

    # Describe each saved image.
    return [
        {
            "picture_name": os.path.splitext(os.path.basename(file_path))[0],
            "picture_type": "png",
            "picture_description": system_prompt,
            "picture_size": f"{width}x{height}",
        }
        for file_path in saved_files
    ]


def preprocess_workflow(system_prompt, width, height, batch_size, input_json='flux_work.json', output_json='processed_workflow.json'):
    """Write the prompt and image parameters into a copy of a workflow file.

    The workflow JSON is loaded from ``input_json``, node ``'31'`` receives the
    prompt and node ``'27'`` receives the image parameters, and the result is
    saved to ``output_json``.

    Args:
        system_prompt: Text stored in ``workflow['31']['inputs']['system_prompt']``.
        width: Image width; stored as a string in node ``'27'``.
        height: Image height; stored as a string in node ``'27'``.
        batch_size: Batch size; stored as a string in node ``'27'``.
        input_json: Path of the workflow template to read.
        output_json: Path the updated workflow is written to.

    Returns:
        str: ``output_json``, the path of the written workflow.

    Raises:
        SystemExit: With status 1 when reading, updating or writing the
            workflow fails; the error is printed first.
    """
    try:
        # Read as UTF-8 explicitly so a workflow containing non-ASCII text
        # loads the same way regardless of the platform's locale encoding.
        with open(input_json, 'r', encoding='utf-8') as f:
            workflow = json.load(f)

        # Update the prompt.
        workflow['31']['inputs']['system_prompt'] = system_prompt

        # Update the image parameters (the values are stored as strings).
        image_inputs = workflow['27']['inputs']
        image_inputs['width'] = str(width)
        image_inputs['height'] = str(height)
        image_inputs['batch_size'] = str(batch_size)

        # Save the updated workflow.
        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(workflow, f, indent=2)

        print(f"工作流已更新并保存到: {output_json}")
        return output_json

    except Exception as e:
        print(f"预处理工作流出错: {str(e)}")
        sys.exit(1)


def queue_prompt(prompt, server_address, client_id):
    """Submit a prompt to the server's ``/prompt`` endpoint.

    Args:
        prompt: Workflow object sent as the ``prompt`` field of the request.
        server_address: ``host:port`` of the server.
        client_id: Client identifier sent alongside the prompt.

    Returns:
        The server's JSON response, decoded into Python objects.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(f"http://{server_address}/prompt", data=payload)
    # Close the HTTP response deterministically instead of leaving it open.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())


def get_image(filename, subfolder, folder_type, server_address):
    """Download one image from the server's ``/view`` endpoint.

    Args:
        filename: Value of the ``filename`` query parameter.
        subfolder: Value of the ``subfolder`` query parameter.
        folder_type: Value of the ``type`` query parameter.
        server_address: ``host:port`` of the server.

    Returns:
        bytes: The raw response body.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return payload


def get_history(prompt_id, server_address):
    """Fetch the server's ``/history/<prompt_id>`` record.

    Args:
        prompt_id: Identifier of the queued prompt.
        server_address: ``host:port`` of the server.

    Returns:
        The JSON response body, decoded into Python objects.
    """
    url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(url) as response:
        body = response.read()
    return json.loads(body)


def get_images(ws, prompt, server_address, client_id, timeout=600):
    """Queue a prompt, wait for it to finish, and download its images.

    Args:
        ws: Open websocket whose ``recv()`` yields server messages.
        prompt: Workflow object to queue.
        server_address: ``host:port`` of the server.
        client_id: Client identifier sent with the prompt.
        timeout: Seconds to wait for completion. The deadline is checked
            before each receive, not while a receive is blocking.

    Returns:
        dict: Maps each output node id that has an ``images`` entry to a list
        of dicts with the keys ``data`` (image bytes), ``filename``,
        ``subfolder`` and ``type``. Empty when the server has no history for
        the prompt.
    """
    prompt_id = queue_prompt(prompt, server_address, client_id)['prompt_id']
    print(f'提示ID: {prompt_id}')

    # Wait until the server reports that execution of this prompt finished.
    began = time.time()
    while True:
        if time.time() - began > timeout:
            print(f"超时:等待执行超过{timeout}秒")
            break

        try:
            raw = ws.recv()
            # Only text frames carry JSON status messages.
            if not isinstance(raw, str):
                continue
            message = json.loads(raw)
            if message['type'] != 'executing':
                continue
            data = message['data']
            # An 'executing' message with no node for our prompt ends the wait.
            if data['node'] is None and data['prompt_id'] == prompt_id:
                print('执行完成')
                break
        except Exception as e:
            # Any receive or parse error stops waiting; the history lookup
            # below still runs.
            print(f"接收消息出错: {str(e)}")
            break

    history = get_history(prompt_id, server_address).get(prompt_id, {})
    if not history:
        print("未找到该提示的历史记录")
        return {}

    output_images = {}
    for node_id, node_output in history['outputs'].items():
        if 'images' not in node_output:
            continue

        fetched = []
        for image in node_output['images']:
            try:
                image_data = get_image(image['filename'], image['subfolder'], image['type'], server_address)
                fetched.append({
                    'data': image_data,
                    'filename': image['filename'],
                    'subfolder': image['subfolder'],
                    'type': image['type'],
                })
            except Exception as e:
                # A failed download is reported and skipped.
                print(f"获取图像错误: {str(e)}")
        output_images[node_id] = fetched

    print(f'获取到 {len(output_images)} 组图像输出')
    return output_images


def generate_images(workflow_file, server_address, output_dir, client_id):
    """Run a workflow on the server and save the resulting images as PNG files.

    The workflow is loaded from ``workflow_file``, a random seed is written to
    ``workflow['25']['inputs']['noise_seed']``, and the images returned by
    ``get_images`` are written to ``output_dir`` as
    ``<seed>_<timestamp>_<index>.png``.

    Args:
        workflow_file: Path of the workflow JSON to run (read as UTF-8).
        server_address: ``host:port`` of the server.
        output_dir: Existing directory the images are written to.
        client_id: Client identifier used for the websocket and the prompt.

    Returns:
        list[str]: Paths of the files that were written. Empty when nothing
        was produced or when any step before saving failed (the error is
        printed, not raised).
    """
    try:
        # Load the workflow.
        with open(workflow_file, 'r', encoding='utf-8') as f:
            workflow_data = json.load(f)

        # Use a fresh random seed for every run.
        seed = random.randint(1, 10**8)
        print(f'使用种子: {seed}')
        workflow_data['25']['inputs']['noise_seed'] = seed

        # Open the websocket used to follow execution progress.
        ws_url = f"ws://{server_address}/ws?clientId={client_id}"
        ws = create_connection(ws_url, timeout=600)
        try:
            images = get_images(ws, workflow_data, server_address, client_id)
        finally:
            # Release the socket even when fetching the images raises.
            ws.close()

        # Flatten the per-node lists so the file index runs across ALL nodes;
        # a per-node index would give images from different nodes the same
        # file name (same seed, same second) and overwrite each other.
        all_images = [
            img
            for image_list in (images or {}).values()
            for img in image_list
        ]

        saved_files = []
        for index, img in enumerate(all_images):
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            file_path = os.path.join(output_dir, f"{seed}_{timestamp}_{index}.png")

            try:
                with open(file_path, "wb") as f:
                    f.write(img['data'])
                saved_files.append(file_path)
                print(f'已保存: {file_path}')
            except Exception as e:
                # A failed save is reported and the remaining images are
                # still attempted.
                print(f"保存图像错误: {str(e)}")

        return saved_files

    except Exception as e:
        print(f"生成图像出错: {str(e)}")
        return []


if __name__ == "__main__":
    # Command-line entry point: generate images for a prompt given on argv.

    # Fixed configuration for a script run.
    WORKING_DIR = 'output'
    COMFYUI_ENDPOINT = '127.0.0.1:8188'
    DEFAULT_WORKFLOW = './workflows/flux_work.json'
    TEMP_WORKFLOW_DIR = './workflows/temp'

    # Exactly four arguments are required: prompt, width, height, batch size.
    if len(sys.argv) != 5:
        for usage_line in (
            "用法: python test.py <prompt> <width> <height> <batch_size>",
            "示例: python test.py \"南开大学图书馆,大雨天\" 2048 1024 1",
        ):
            print(usage_line)
        sys.exit(1)

    system_prompt = sys.argv[1]
    width, height, batch_size = (int(arg) for arg in sys.argv[2:5])

    # Each run writes its own uniquely named processed workflow file.
    os.makedirs(TEMP_WORKFLOW_DIR, exist_ok=True)
    PROCESSED_WORKFLOW = os.path.join(
        TEMP_WORKFLOW_DIR, f"processed_workflow_{uuid.uuid4().hex}.json"
    )

    # 1. Write the prompt and image parameters into the workflow.
    workflow_file = preprocess_workflow(
        system_prompt,
        width,
        height,
        batch_size,
        input_json=DEFAULT_WORKFLOW,
        output_json=PROCESSED_WORKFLOW,
    )

    # 2. Prepare the output directory and a client id for this run.
    os.makedirs(WORKING_DIR, exist_ok=True)
    client_id = str(uuid.uuid4())

    # Report the run configuration.
    for line in (
        f"系统提示: {system_prompt}",
        f"图像尺寸: {width}x{height}",
        f"批次大小: {batch_size}",
        f"工作流文件: {workflow_file}",
        f"客户端ID: {client_id}",
        f"开始使用ComfyUI生成图像: {COMFYUI_ENDPOINT}",
    ):
        print(line)

    start_time = time.time()

    # Generate the images.
    print("\n===== 开始生成图像 =====")
    saved_files = generate_images(
        workflow_file=workflow_file,
        server_address=COMFYUI_ENDPOINT,
        output_dir=WORKING_DIR,
        client_id=client_id,
    )

    # Summarise the run.
    elapsed = time.time() - start_time
    print(f"\n处理完成,耗时 {elapsed:.2f} 秒")
    print(f"共生成 {len(saved_files)} 张图像")
    print(f"保存位置: {WORKING_DIR}")