220 lines
8.2 KiB
Python
220 lines
8.2 KiB
Python
import json
|
||
from websocket import create_connection, WebSocketTimeoutException, WebSocketConnectionClosedException
|
||
import uuid
|
||
import urllib.request
|
||
import urllib.parse
|
||
import random
|
||
import time
|
||
from datetime import datetime
|
||
|
||
# Helper that displays a GIF inline (e.g. in a Jupyter notebook).
def show_gif(fname):
    """Return an IPython ``HTML`` object that renders the GIF stored at *fname*.

    The file's bytes are base64-encoded and embedded in an ``<img>`` tag as a
    ``data:`` URI, so the image displays without serving the file separately.

    Args:
        fname: Path to the GIF file to embed.

    Returns:
        ``IPython.display.HTML`` wrapping the ``<img>`` markup.
    """
    from IPython import display
    import base64

    with open(fname, 'rb') as gif_file:
        raw_bytes = gif_file.read()
    encoded = base64.b64encode(raw_bytes).decode('ascii')
    return display.HTML(f'<img src="data:image/gif;base64,{encoded}" />')
|
||
|
||
# Submit a workflow to the ComfyUI server's prompt queue.
def queue_prompt(prompt):
    """POST *prompt* to ``http://<server_address>/prompt`` and return the parsed reply.

    Reads the module-level globals ``server_address`` and ``client_id``. The
    same ``client_id`` is used in the WebSocket URL opened by
    ``generate_clip``.

    Args:
        prompt: Workflow dict (API format) to execute.

    Returns:
        The server's JSON response decoded into Python objects.

    Raises:
        urllib.error.URLError (including HTTPError): if the request fails.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode('utf-8')
    req = urllib.request.Request(
        "http://{}/prompt".format(server_address),
        data=data,
        # The body is JSON; without this urllib labels it form-urlencoded.
        headers={"Content-Type": "application/json"},
    )
    # Use the response as a context manager so the HTTP connection is
    # released even if reading or JSON decoding fails.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
|
||
|
||
# Download the raw bytes of one generated output file.
def get_image(filename, subfolder, folder_type):
    """Fetch one output file from the server's ``/view`` endpoint.

    Reads the module-level global ``server_address``.

    Args:
        filename: File name as reported in the prompt history.
        subfolder: Subfolder reported alongside the file.
        folder_type: Folder kind reported alongside the file (sent as ``type``).

    Returns:
        The response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
|
||
|
||
# Fetch the execution history recorded for a prompt.
def get_history(prompt_id):
    """Return the JSON body of ``http://<server_address>/history/<prompt_id>``.

    Reads the module-level global ``server_address``.

    Args:
        prompt_id: Id returned by the server when the prompt was queued.

    Returns:
        The decoded JSON response.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        body = resp.read()
    return json.loads(body)
|
||
|
||
# Queue a workflow and collect its outputs, using WebSocket messages to detect completion.
def get_images(ws, prompt, timeout=1200):
    """Queue *prompt*, wait for it to finish, then download its output files.

    Args:
        ws: Open WebSocket connection from which server messages are read.
        prompt: Workflow dict (API format) submitted via ``queue_prompt``.
        timeout: Seconds to wait for the completion message before giving up
            and reading whatever history exists. Defaults to 1200 (the value
            previously hard-coded here).

    Returns:
        Dict mapping output node id -> list of downloaded file bytes, for
        nodes whose output lists ``images`` (or, failing that, ``videos``).
        Empty dict if the server has no history for the prompt.

    Exceptions from ``queue_prompt``, ``get_history`` and ``ws.recv()`` are
    not caught here; they propagate to the caller.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print('Prompt: ', prompt)
    print('Prompt ID: ', prompt_id)

    _wait_for_completion(ws, prompt_id, timeout)

    # Fetch the history record of the finished prompt.
    history = get_history(prompt_id).get(prompt_id, {})
    if not history:
        print("未找到该提示的历史记录")
        return {}

    output_images = {}
    for node_id, node_output in history.get('outputs', {}).items():
        # The image branch wins if a node reports both kinds.
        if 'images' in node_output:
            output_images[node_id] = _download_files(node_output['images'], "获取图像错误")
        elif 'videos' in node_output:
            output_images[node_id] = _download_files(node_output['videos'], "获取视频错误")

    print(f'Obtained {len(output_images)} image/video sets')
    return output_images


def _wait_for_completion(ws, prompt_id, timeout):
    """Read *ws* until an 'executing' message with node None arrives for *prompt_id*, or *timeout* elapses."""
    # monotonic() is unaffected by wall-clock adjustments during long runs.
    deadline = time.monotonic() + timeout
    while True:
        if time.monotonic() > deadline:
            print(f"超时:等待执行完成超过{timeout}秒")
            return
        out = ws.recv()
        if not isinstance(out, str):
            continue  # binary frames carry previews; ignore them
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message.get('data', {})
        if data.get('node') is None and data.get('prompt_id') == prompt_id:
            print('Execution complete')
            return


def _download_files(entries, error_label):
    """Download each file described in *entries*; print and skip any that fail."""
    downloaded = []
    for entry in entries:
        try:
            downloaded.append(get_image(entry['filename'], entry['subfolder'], entry['type']))
        except Exception as e:
            # Best-effort: one failed download should not discard the others.
            print(f"{error_label}: {str(e)}")
    return downloaded
|
||
|
||
# Node class types whose seed input parse_workflow overwrites.
_SEED_NODE_TYPES = frozenset({"KSampler", "KSamplerAdvanced", "SamplerCustom", "RandomNoise"})


# Load an API-format workflow and inject the prompt text and seed.
def parse_workflow(prompt, seed, workflowfile, text_node_id="6"):
    """Load the workflow JSON at *workflowfile* and patch in *prompt* and *seed*.

    Args:
        prompt: Text written to ``inputs["text"]`` of node *text_node_id*.
        seed: Written to the ``seed`` input (or, if absent, ``noise_seed``) of
            every node whose ``class_type`` is in ``_SEED_NODE_TYPES`` and of a
            node keyed literally ``"Ksampler"``. Inputs wired to another node
            (list values) are left untouched.
        workflowfile: Path to the UTF-8 workflow JSON file.
        text_node_id: Id of the node holding the prompt text. Defaults to
            ``"6"``, the id previously hard-coded here.

    Returns:
        The patched workflow dict, or ``{}`` if anything fails (missing or
        invalid file, missing text node); the error is printed.
    """
    print(f'Workflow file: {workflowfile}')
    try:
        with open(workflowfile, 'r', encoding="utf-8") as f:
            prompt_data = json.load(f)

        # Set the text prompt.
        prompt_data[text_node_id]["inputs"]["text"] = prompt

        # Set the seed. Workflows are keyed by node id (like the "6" above),
        # so a literal "Ksampler" key rarely exists; match on class_type too.
        for node_key, node in prompt_data.items():
            if not isinstance(node, dict):
                continue
            if node_key != "Ksampler" and node.get("class_type") not in _SEED_NODE_TYPES:
                continue
            inputs = node.get("inputs", {})
            # Same precedence as before: "seed" first, else "noise_seed".
            for seed_key in ("seed", "noise_seed"):
                if seed_key in inputs:
                    # A list value is a link to another node's output; keep it wired.
                    if not isinstance(inputs[seed_key], list):
                        inputs[seed_key] = seed
                    break
        return prompt_data
    except Exception as e:
        print(f"工作流解析错误: {str(e)}")
        return {}
|
||
|
||
# Generate the outputs for one prompt and save them under WORKING_DIR.
def generate_clip(prompt, seed, workflowfile, idx):
    """Run the workflow for one prompt and write every returned file to disk.

    Reads the module-level globals ``server_address``, ``client_id`` and
    ``WORKING_DIR``. Failures are printed rather than raised, so a batch loop
    can move on to the next prompt.

    Args:
        prompt: Text prompt injected into the workflow.
        seed: Seed injected into the workflow; also embedded in file names.
        workflowfile: Path to the workflow JSON file.
        idx: Prompt index, used as the file-name prefix.
    """
    import os

    print(f'Processing prompt #{idx}: "{prompt[:50]}{"..." if len(prompt) > 50 else ""}"')
    print(f'Using seed: {seed}')

    # Parse the workflow before connecting, so a bad file never opens a socket.
    workflow_data = parse_workflow(prompt, seed, workflowfile)
    if not workflow_data:
        print("工作流数据为空")
        return

    ws_url = f"ws://{server_address}/ws?clientId={client_id}"
    try:
        ws = create_connection(ws_url, timeout=600)
        try:
            images = get_images(ws, workflow_data)
        finally:
            # Always release the socket, even if get_images raises.
            ws.close()
    except WebSocketTimeoutException as e:
        print(f"WebSocket连接超时: {str(e)}")
        return
    except WebSocketConnectionClosedException as e:
        print(f"WebSocket连接已关闭: {str(e)}")
        return
    except Exception as e:
        print(f"WebSocket错误: {str(e)}")
        return

    saved_files = []
    if images:
        # Create the output directory; otherwise every write below would fail.
        try:
            os.makedirs(WORKING_DIR, exist_ok=True)
        except OSError as e:
            print(f"文件保存错误: {str(e)}")

        # Number files across ALL nodes: a per-node index would let two output
        # nodes saved within the same second overwrite each other.
        all_files = [data for file_list in images.values() for data in file_list]
        for file_index, image_data in enumerate(all_files):
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            # NOTE(review): files from 'videos' outputs also get a .png
            # extension; the bytes' real format is not checked here.
            filename = f"{idx}_{seed}_{timestamp}_{file_index}.png"
            file_path = os.path.join(WORKING_DIR, filename)

            print(f'Saving to: {file_path}')
            try:
                with open(file_path, "wb") as f:
                    f.write(image_data)
                saved_files.append(file_path)
            except Exception as e:
                print(f"文件保存错误: {str(e)}")
    else:
        print("未获取到图像数据")

    if saved_files:
        print(f"成功生成 {len(saved_files)} 张图片")
    else:
        print("未保存任何图片")
|
||
|
||
# Prompts defined directly in code; the main loop processes them in order.
PROMPTS = [
    "A beautiful sunset over the ocean, realistic, cinematic lighting",
    "A futuristic cityscape at night, cyberpunk style, neon lights",
    "An ancient forest with magical creatures, fantasy, photorealistic",
    "A steampunk laboratory with bubbling beakers and intricate machinery",
    "Abstract geometric patterns in vibrant colors, digital art",
    "A majestic lion in the African savannah, golden hour lighting",
    "A cozy cabin in the mountains during winter, warm lights inside",
    "A detailed close-up of a butterfly on a flower, macro photography",
    "Underwater scene with coral reef and tropical fish, crystal clear water",
]
|
||
|
||
# Working directory and server endpoint.
WORKING_DIR = 'output'
COMFYUI_ENDPOINT = '127.0.0.1:8188'

# Server configuration. The request helpers and generate_clip read these as
# module globals, so they are defined at import time (not only under
# __main__); importing this module, e.g. from a notebook, then works too.
server_address = COMFYUI_ENDPOINT
client_id = str(uuid.uuid4())  # unique client id for this process


if __name__ == "__main__":
    workflowfile = './workflows/flux_redux.json'

    # Seed settings: a fresh random seed per prompt, or one fixed seed.
    USE_RANDOM_SEED = True  # set to False to use base_seed for every prompt
    base_seed = 15465856

    print(f"Starting image generation with ComfyUI at {server_address}")
    print(f"Working directory: {WORKING_DIR}")
    print(f"Workflow file: {workflowfile}")
    print(f"Client ID: {client_id}")

    # perf_counter is the appropriate clock for measuring elapsed time.
    start_time = time.perf_counter()

    # Process each prompt.
    for idx, prompt in enumerate(PROMPTS, start=1):
        current_seed = random.randint(1, 10**8) if USE_RANDOM_SEED else base_seed

        print(f"\n===== Processing Prompt #{idx} of {len(PROMPTS)} =====")
        generate_clip(prompt, current_seed, workflowfile, idx)

        # Pause between prompts to avoid overloading the server;
        # there is nothing to wait for after the last one.
        if idx < len(PROMPTS):
            time.sleep(2)

    elapsed = time.perf_counter() - start_time
    print(f"\nProcessing completed in {elapsed:.2f} seconds")
    print(f"Generated images for {len(PROMPTS)} prompts")