comfyui/flux_con.py
2025-06-08 17:44:58 +08:00

282 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import sys
import time
import uuid
import random
from datetime import datetime
from websocket import create_connection, WebSocketTimeoutException, WebSocketConnectionClosedException
import urllib.request
import urllib.parse
def comfyui_img_info(user_input_analysis_result, system_prompt):
    """Generate images with ComfyUI and describe them as a ``parse_imglist``.

    A copy of the default Flux workflow is written to a uniquely named
    temporary file with the prompt and image parameters filled in. It is
    submitted to the local ComfyUI server, and every saved image is summarised
    in a dict. The temporary workflow file is removed afterwards.

    Args:
        user_input_analysis_result: Analysis result dict. The optional keys
            ``width`` (default 1024), ``height`` (default 768) and
            ``batch_size`` (default 1) are read from it.
        system_prompt: Prompt text written into the workflow. It is also used
            as every image's description.

    Returns:
        list[dict]: One dict per saved image with the keys:

        - ``picture_name``: file name without extension.
        - ``picture_type``: always ``"png"``.
        - ``picture_description``: ``system_prompt``.
        - ``picture_size``: ``"{width}x{height}"``.

        The list is empty when ``generate_images`` saved nothing.
    """
    width = user_input_analysis_result.get('width', 1024)
    height = user_input_analysis_result.get('height', 768)
    batch_size = user_input_analysis_result.get('batch_size', 1)

    # Configuration.
    working_dir = 'output'
    comfyui_endpoint = '127.0.0.1:8188'
    default_workflow = './workflows/flux_work.json'
    temp_workflow_dir = './workflows/temp'

    os.makedirs(temp_workflow_dir, exist_ok=True)
    # Unique name per call so concurrent calls never share a workflow file.
    processed_workflow = os.path.join(
        temp_workflow_dir, f"processed_workflow_{uuid.uuid4().hex}.json"
    )
    try:
        # 1. Fill the prompt and image parameters into a copy of the workflow.
        workflow_file = preprocess_workflow(
            system_prompt=system_prompt,
            width=width,
            height=height,
            batch_size=batch_size,
            input_json=default_workflow,
            output_json=processed_workflow,
        )
        # 2. Make sure the output directory exists, then generate.
        os.makedirs(working_dir, exist_ok=True)
        saved_files = generate_images(
            workflow_file=workflow_file,
            server_address=comfyui_endpoint,
            output_dir=working_dir,
            client_id=str(uuid.uuid4()),
        )
    finally:
        # The processed workflow is only needed while generate_images loads
        # it. Previously every call left one JSON file behind in the temp dir.
        try:
            os.remove(processed_workflow)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not hide the result.
            pass

    return [
        {
            "picture_name": os.path.splitext(os.path.basename(file_path))[0],
            "picture_type": "png",
            "picture_description": system_prompt,
            "picture_size": f"{width}x{height}",
        }
        for file_path in saved_files
    ]
def preprocess_workflow(system_prompt, width, height, batch_size,
                        input_json='flux_work.json',
                        output_json='processed_workflow.json'):
    """Write a copy of a ComfyUI workflow with the prompt and image size filled in.

    Node ``'31'`` receives ``system_prompt``. Node ``'27'`` receives
    ``width``, ``height`` and ``batch_size``, each stored as a string. All
    other nodes are copied unchanged.

    Args:
        system_prompt: Text stored in ``workflow['31']['inputs']['system_prompt']``.
        width: Image width, stored as ``str(width)`` in node ``'27'``.
        height: Image height, stored as ``str(height)`` in node ``'27'``.
        batch_size: Batch size, stored as ``str(batch_size)`` in node ``'27'``.
        input_json: Path of the source workflow JSON file.
        output_json: Path the updated workflow is written to.

    Returns:
        str: ``output_json``.

    Raises:
        SystemExit: Exit status 1 if the workflow cannot be read, lacks the
            expected nodes, or cannot be written. The error is printed first.
    """
    try:
        # Decode explicitly as UTF-8 (the JSON standard encoding, and what
        # generate_images uses). Otherwise the platform default is used, which
        # may be e.g. GBK on Chinese-locale Windows and breaks non-ASCII prompts.
        with open(input_json, 'r', encoding='utf-8') as f:
            workflow = json.load(f)

        # Update the system prompt.
        workflow['31']['inputs']['system_prompt'] = system_prompt

        # Update the image parameters.
        size_inputs = workflow['27']['inputs']
        size_inputs['width'] = str(width)
        size_inputs['height'] = str(height)
        size_inputs['batch_size'] = str(batch_size)

        # Save the updated workflow.
        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(workflow, f, indent=2)
        print(f"工作流已更新并保存到: {output_json}")
        return output_json
    except Exception as e:
        print(f"预处理工作流出错: {str(e)}")
        sys.exit(1)
def queue_prompt(prompt, server_address, client_id):
    """Submit a workflow to the ComfyUI ``/prompt`` endpoint.

    Args:
        prompt: Workflow dict to execute.
        server_address: ``host:port`` of the ComfyUI server.
        client_id: Client id sent with the request. get_images passes the same
            id that its websocket was opened with.

    Returns:
        dict: The server's decoded JSON reply. get_images reads its
        ``prompt_id`` key.

    Raises:
        urllib.error.URLError: Propagated from ``urlopen``. This includes
            ``HTTPError`` for non-2xx responses.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(
        f"http://{server_address}/prompt",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    # Close the response deterministically, as get_image/get_history do.
    # Previously the response object was left open.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type, server_address):
    """Download one generated image from ComfyUI's ``/view`` endpoint.

    Args:
        filename: Image file name as listed in the prompt history.
        subfolder: Subfolder listed alongside the file name.
        folder_type: Folder type listed alongside the file name. It is sent as
            the ``type`` query parameter.
        server_address: ``host:port`` of the ComfyUI server.

    Returns:
        bytes: The raw response body.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
def get_history(prompt_id, server_address):
    """Fetch and decode the ComfyUI history for ``prompt_id``.

    Args:
        prompt_id: Id returned by queue_prompt.
        server_address: ``host:port`` of the ComfyUI server.

    Returns:
        The JSON-decoded response of ``/history/<prompt_id>``. get_images
        looks the prompt id up in it as a dict key.
    """
    history_url = "http://{}/history/{}".format(server_address, prompt_id)
    with urllib.request.urlopen(history_url) as resp:
        body = resp.read()
    return json.loads(body)
def get_images(ws, prompt, server_address, client_id, timeout=600):
    """Queue ``prompt``, wait for it to finish, and download every output image.

    Args:
        ws: Websocket already connected to the server. Only ``ws.recv()`` is
            used.
        prompt: Workflow dict to submit.
        server_address: ``host:port`` of the ComfyUI server.
        client_id: Client id passed to queue_prompt.
        timeout: Seconds after which waiting stops. The check runs between
            messages, so a blocking ``ws.recv()`` may overrun it by up to the
            websocket's own receive timeout.

    Returns:
        dict: Maps each output node id that reported ``images`` to a list of
        dicts with the keys ``data`` (bytes), ``filename``, ``subfolder`` and
        ``type``. Returns ``{}`` when no history entry exists for the prompt.
    """
    prompt_id = queue_prompt(prompt, server_address, client_id)['prompt_id']
    print(f'提示ID: {prompt_id}')
    start_time = time.time()
    while True:
        if time.time() - start_time > timeout:
            print(f"超时:等待执行超过{timeout}")
            break
        try:
            out = ws.recv()
            # Only text frames are parsed as JSON; binary frames are skipped.
            if isinstance(out, str):
                message = json.loads(out)
                # Use .get so a message without 'type'/'data' is ignored.
                # Previously it raised KeyError and ended the wait prematurely.
                if message.get('type') == 'executing':
                    data = message.get('data') or {}
                    # An 'executing' message with node None for our prompt
                    # marks the end of execution.
                    if ('node' in data and data['node'] is None
                            and data.get('prompt_id') == prompt_id):
                        print('执行完成')
                        break
        except Exception as e:
            print(f"接收消息出错: {str(e)}")
            break

    history = get_history(prompt_id, server_address).get(prompt_id, {})
    if not history:
        print("未找到该提示的历史记录")
        return {}

    output_images = {}
    # A history entry without 'outputs' yields no images instead of a KeyError.
    for node_id, node_output in history.get('outputs', {}).items():
        if 'images' not in node_output:
            continue
        images_output = []
        for image in node_output['images']:
            try:
                image_data = get_image(image['filename'], image['subfolder'],
                                       image['type'], server_address)
            except Exception as e:
                # Skip this image but keep collecting the others.
                print(f"获取图像错误: {str(e)}")
                continue
            images_output.append({
                'data': image_data,
                'filename': image['filename'],
                'subfolder': image['subfolder'],
                'type': image['type'],
            })
        output_images[node_id] = images_output
    print(f'获取到 {len(output_images)} 组图像输出')
    return output_images
def generate_images(workflow_file, server_address, output_dir, client_id):
    """Run a processed workflow on ComfyUI and save the resulting images.

    A random seed in ``[1, 10**8]`` is written to
    ``workflow['25']['inputs']['noise_seed']`` before submission.

    Args:
        workflow_file: Path of the workflow JSON, read as UTF-8.
        server_address: ``host:port`` of the ComfyUI server.
        output_dir: Existing directory the PNG files are written into.
        client_id: Id used for the websocket connection and the queued prompt.

    Returns:
        list[str]: Paths of the saved files, named
        ``{seed}_{timestamp}_{index}.png``. Any error before saving is printed
        and yields ``[]``. Individual save failures are printed and skipped.
    """
    try:
        # Load the workflow.
        with open(workflow_file, 'r', encoding='utf-8') as f:
            workflow_data = json.load(f)

        # Use a random seed.
        seed = random.randint(1, 10**8)
        print(f'使用种子: {seed}')
        workflow_data['25']['inputs']['noise_seed'] = seed

        ws_url = f"ws://{server_address}/ws?clientId={client_id}"
        ws = create_connection(ws_url, timeout=600)
        try:
            images = get_images(ws, workflow_data, server_address, client_id)
        finally:
            # Close the socket even if get_images raises (e.g. queueing fails).
            ws.close()

        saved_files = []
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        # Use one running index across all output nodes. A per-node index
        # restarted at 0 for every node, so images from different nodes saved
        # in the same second overwrote each other.
        all_images = (img for image_list in images.values() for img in image_list)
        for index, img in enumerate(all_images):
            file_path = os.path.join(output_dir, f"{seed}_{timestamp}_{index}.png")
            try:
                with open(file_path, "wb") as f:
                    f.write(img['data'])
            except Exception as e:
                print(f"保存图像错误: {str(e)}")
                continue
            saved_files.append(file_path)
            print(f'已保存: {file_path}')
        return saved_files
    except Exception as e:
        print(f"生成图像出错: {str(e)}")
        return []
if __name__ == "__main__":
    # Command-line entry point:
    #   python <script> <prompt> <width> <height> <batch_size>
    # Configuration (the same values comfyui_img_info uses).
    WORKING_DIR = 'output'
    COMFYUI_ENDPOINT = '127.0.0.1:8188'
    DEFAULT_WORKFLOW = './workflows/flux_work.json'
    TEMP_WORKFLOW_DIR = './workflows/temp'

    def _usage_and_exit():
        """Print usage for the actual script name and exit with status 1."""
        # Use the real script name; the old message hard-coded "test.py".
        script_name = os.path.basename(sys.argv[0])
        print(f"用法: python {script_name} <prompt> <width> <height> <batch_size>")
        print(f"示例: python {script_name} \"南开大学图书馆,大雨天\" 2048 1024 1")
        sys.exit(1)

    if len(sys.argv) != 5:
        _usage_and_exit()

    system_prompt = sys.argv[1]
    try:
        width, height, batch_size = (int(arg) for arg in sys.argv[2:5])
    except ValueError:
        # Non-integer size arguments: show usage instead of a traceback.
        _usage_and_exit()
    if min(width, height, batch_size) <= 0:
        _usage_and_exit()

    os.makedirs(TEMP_WORKFLOW_DIR, exist_ok=True)
    # Unique temp path so parallel runs do not clobber each other's workflow.
    PROCESSED_WORKFLOW = os.path.join(
        TEMP_WORKFLOW_DIR, f"processed_workflow_{uuid.uuid4().hex}.json"
    )

    # 1. Preprocess the workflow (exits with status 1 on failure).
    workflow_file = preprocess_workflow(
        system_prompt=system_prompt,
        width=width,
        height=height,
        batch_size=batch_size,
        input_json=DEFAULT_WORKFLOW,
        output_json=PROCESSED_WORKFLOW,
    )

    # 2. Prepare the output directory and a client id.
    os.makedirs(WORKING_DIR, exist_ok=True)
    client_id = str(uuid.uuid4())

    print(f"系统提示: {system_prompt}")
    print(f"图像尺寸: {width}x{height}")
    print(f"批次大小: {batch_size}")
    print(f"工作流文件: {workflow_file}")
    print(f"客户端ID: {client_id}")
    print(f"开始使用ComfyUI生成图像: {COMFYUI_ENDPOINT}")

    start_time = time.time()
    print("\n===== 开始生成图像 =====")
    saved_files = generate_images(
        workflow_file=workflow_file,
        server_address=COMFYUI_ENDPOINT,
        output_dir=WORKING_DIR,
        client_id=client_id,
    )

    elapsed = time.time() - start_time
    print(f"\n处理完成,耗时 {elapsed:.2f}")
    print(f"共生成 {len(saved_files)} 张图像")
    print(f"保存位置: {WORKING_DIR}")

    # Signal failure to the shell when nothing was produced. This matches the
    # exit status 1 used when preprocessing fails.
    if not saved_files:
        sys.exit(1)