comfyui/test.py
2025-06-08 17:44:58 +08:00

220 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from websocket import create_connection, WebSocketTimeoutException, WebSocketConnectionClosedException
import uuid
import urllib.request
import urllib.parse
import random
import time
from datetime import datetime
def show_gif(fname):
    """Return an IPython HTML object that embeds the file *fname* as a GIF.

    The file is read as raw bytes, base64-encoded, and placed in an
    ``<img>`` tag through a ``data:image/gif;base64`` URI.

    Args:
        fname: Path of the file to embed.

    Returns:
        An ``IPython.display.HTML`` instance wrapping the ``<img>`` tag.
    """
    # Imports are function-local, so IPython is only required when this
    # helper is actually called.
    from IPython import display
    import base64

    with open(fname, 'rb') as gif_file:
        raw_bytes = gif_file.read()
    encoded = base64.b64encode(raw_bytes).decode('ascii')
    return display.HTML(f'<img src="data:image/gif;base64,{encoded}" />')
def queue_prompt(prompt):
    """Send a workflow to the server's ``/prompt`` endpoint.

    The request body is a JSON object holding the workflow and the
    module-level ``client_id``; the host is the module-level
    ``server_address``.

    Args:
        prompt: JSON-serializable workflow to queue.

    Returns:
        The server's response body, decoded from JSON.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
    # Use a context manager so the HTTP response is closed deterministically
    # instead of being left for the garbage collector.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download one file from the server's ``/view`` endpoint.

    Args:
        filename: Sent as the ``filename`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        folder_type: Sent as the ``type`` query parameter.

    Returns:
        The raw response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = "http://{}/view?{}".format(server_address, query)
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
def get_history(prompt_id):
    """Fetch ``/history/<prompt_id>`` from the server and decode it as JSON.

    Args:
        prompt_id: Identifier appended to the ``/history/`` URL.

    Returns:
        The parsed JSON response.
    """
    history_url = "http://{}/history/{}".format(server_address, prompt_id)
    with urllib.request.urlopen(history_url) as resp:
        body = resp.read()
    return json.loads(body)
def _download_outputs(entries, error_label):
    """Download every file described in *entries*, skipping the ones that fail.

    Args:
        entries: Iterable of dicts with ``filename``, ``subfolder`` and
            ``type`` keys, as found in a node's history output.
        error_label: Prefix printed in front of the exception text when a
            download fails.

    Returns:
        List of raw bytes, one item per successfully downloaded file.
    """
    downloaded = []
    for entry in entries:
        try:
            downloaded.append(
                get_image(entry['filename'], entry['subfolder'], entry['type'])
            )
        except Exception as e:
            # Best effort: report the failure and keep the remaining files.
            print(f"{error_label}: {str(e)}")
    return downloaded


def get_images(ws, prompt, timeout=1200):
    """Queue *prompt*, wait for it to finish, and download its outputs.

    Completion is detected by listening on *ws* for an ``executing``
    message whose ``node`` is ``None`` and whose ``prompt_id`` matches the
    queued prompt. The outputs are then read from the prompt's history.

    Args:
        ws: Open WebSocket connection; only its ``recv()`` method is used.
        prompt: Workflow passed to ``queue_prompt``.
        timeout: Maximum number of seconds to keep waiting for the
            completion message. The check runs between received messages,
            so a blocking ``recv()`` can delay it. The default keeps the
            previous hard-coded value of 1200 seconds (the old comment and
            log message incorrectly described it as 2 minutes / 120 s).

    Returns:
        Dict mapping node id to a list of raw bytes, one per downloaded
        image or video. Empty dict when the server has no history entry
        for the prompt.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print('Prompt: ', prompt)
    print('Prompt ID: ', prompt_id)

    # Wait for the server to report that execution finished.
    start_time = time.time()
    while True:
        if time.time() - start_time > timeout:
            # Report the limit that is actually enforced.
            print(f"超时等待执行完成超过{timeout}秒")
            break
        out = ws.recv()
        if not isinstance(out, str):
            continue  # non-text frames carry preview data; ignore them
        message = json.loads(out)
        if message['type'] == 'executing':
            data = message['data']
            if data['node'] is None and data['prompt_id'] == prompt_id:
                print('Execution complete')
                break

    # Look up the finished prompt in the server history.
    history = get_history(prompt_id).get(prompt_id, {})
    if not history:
        print("未找到该提示的历史记录")
        return {}

    output_images = {}
    # A history entry without 'outputs' is treated as having no outputs.
    for node_id, node_output in history.get('outputs', {}).items():
        # 'images' takes precedence over 'videos' when a node has both.
        if 'images' in node_output:
            output_images[node_id] = _download_outputs(
                node_output['images'], "获取图像错误"
            )
        elif 'videos' in node_output:
            output_images[node_id] = _download_outputs(
                node_output['videos'], "获取视频错误"
            )
    print(f'Obtained {len(output_images)} image/video sets')
    return output_images
def parse_workflow(prompt, seed, workflowfile):
    """Load a workflow JSON file and inject the prompt text and seed.

    The text is written to ``["6"]["inputs"]["text"]``. When the workflow
    has a top-level ``"Ksampler"`` entry, the seed is written to its
    ``seed`` input, or to ``noise_seed`` when ``seed`` is absent.

    Args:
        prompt: Text to place in node ``"6"``.
        seed: Seed value for the ``"Ksampler"`` entry, if present.
        workflowfile: Path of the UTF-8 encoded workflow JSON file.

    Returns:
        The modified workflow data, or an empty dict when loading or
        editing fails for any reason (the error is printed).
    """
    print(f'Workflow file: {workflowfile}')
    try:
        with open(workflowfile, 'r', encoding="utf-8") as workflow_fp:
            workflow = json.load(workflow_fp)

        # Inject the text prompt.
        workflow["6"]["inputs"]["text"] = prompt

        # Inject the seed into the first matching sampler input, if any.
        if "Ksampler" in workflow:
            sampler_inputs = workflow["Ksampler"]["inputs"]
            for seed_key in ("seed", "noise_seed"):
                if seed_key in sampler_inputs:
                    sampler_inputs[seed_key] = seed
                    break
    except Exception as e:
        print(f"工作流解析错误: {str(e)}")
        return {}
    return workflow
def generate_clip(prompt, seed, workflowfile, idx):
    """Run one prompt through the workflow and save every returned output.

    Opens a WebSocket to the server, builds the workflow with
    ``parse_workflow``, collects the results with ``get_images`` and writes
    each one to ``WORKING_DIR``. Errors are printed rather than raised.

    Args:
        prompt: Text prompt for the workflow.
        seed: Seed forwarded to ``parse_workflow``; also part of the
            output file names.
        workflowfile: Path of the workflow JSON file.
        idx: Prompt number, used in log output and output file names.

    Returns:
        None.

    NOTE: every output is written with a ``.png`` extension, including
    data that ``get_images`` collected from a node's ``videos`` entry.
    """
    # Function-local import: the file's import header does not provide os.
    import os

    print(f'Processing prompt #{idx}: "{prompt[:50]}{"..." if len(prompt) > 50 else ""}"')
    print(f'Using seed: {seed}')

    try:
        ws_url = f"ws://{server_address}/ws?clientId={client_id}"
        ws = create_connection(ws_url, timeout=600)
        try:
            workflow_data = parse_workflow(prompt, seed, workflowfile)
            if not workflow_data:
                print("工作流数据为空")
                return
            images = get_images(ws, workflow_data)
        finally:
            # Always release the socket, including on the early return
            # above and when get_images raises.
            ws.close()
    except WebSocketTimeoutException as e:
        print(f"WebSocket连接超时: {str(e)}")
        return
    except WebSocketConnectionClosedException as e:
        print(f"WebSocket连接已关闭: {str(e)}")
        return
    except Exception as e:
        print(f"WebSocket错误: {str(e)}")
        return

    saved_files = []
    if images:
        # Create the output directory up front so the writes below do not
        # fail merely because it does not exist yet.
        try:
            os.makedirs(WORKING_DIR, exist_ok=True)
        except OSError as e:
            print(f"文件保存错误: {str(e)}")

        # One running index across all nodes: a per-node index restarts at
        # 0 and lets two nodes overwrite each other within the same second.
        file_index = 0
        for image_list in images.values():
            for image_data in image_list:
                timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                filename = f"{idx}_{seed}_{timestamp}_{file_index}.png"
                file_index += 1
                file_path = f"{WORKING_DIR}/{filename}"
                print(f'Saving to: {file_path}')
                try:
                    with open(file_path, "wb") as f:
                        f.write(image_data)
                    saved_files.append(file_path)
                except Exception as e:
                    print(f"文件保存错误: {str(e)}")
    else:
        print("未获取到图像数据")

    if saved_files:
        print(f"成功生成 {len(saved_files)} 张图片")
    else:
        print("未保存任何图片")
# Prompt texts are defined directly in code; the script processes them in order.
PROMPTS = [
    "A beautiful sunset over the ocean, realistic, cinematic lighting",
    "A futuristic cityscape at night, cyberpunk style, neon lights",
    "An ancient forest with magical creatures, fantasy, photorealistic",
    "A steampunk laboratory with bubbling beakers and intricate machinery",
    "Abstract geometric patterns in vibrant colors, digital art",
    "A majestic lion in the African savannah, golden hour lighting",
    "A cozy cabin in the mountains during winter, warm lights inside",
    "A detailed close-up of a butterfly on a flower, macro photography",
    "Underwater scene with coral reef and tropical fish, crystal clear water",
]
if __name__ == "__main__":
    # Output directory, workflow definition and server endpoint for this run.
    WORKING_DIR = 'output'
    workflowfile = './workflows/flux_redux.json'
    COMFYUI_ENDPOINT = '127.0.0.1:8188'

    # Server settings. These are assigned at module scope, which is where
    # the helper functions read `server_address` and `client_id` from.
    server_address = COMFYUI_ENDPOINT
    client_id = str(uuid.uuid4())  # unique id for this client session

    # Seed policy: a fresh random seed per prompt, or the fixed base seed.
    USE_RANDOM_SEED = True  # set to False to always use base_seed
    base_seed = 15465856

    print(f"Starting image generation with ComfyUI at {server_address}")
    print(f"Working directory: {WORKING_DIR}")
    print(f"Workflow file: {workflowfile}")
    print(f"Client ID: {client_id}")

    total_prompts = len(PROMPTS)
    start_time = time.time()

    # Process the prompts one by one, numbering them from 1.
    for idx, prompt in enumerate(PROMPTS, start=1):
        if USE_RANDOM_SEED:
            current_seed = random.randint(1, 10**8)
        else:
            current_seed = base_seed

        print(f"\n===== Processing Prompt #{idx} of {total_prompts} =====")
        generate_clip(prompt, current_seed, workflowfile, idx)

        # Pause between prompts to avoid overloading the server.
        time.sleep(2)

    elapsed = time.time() - start_time
    print(f"\nProcessing completed in {elapsed:.2f} seconds")
    print(f"Generated images for {total_prompts} prompts")