"""Generate images with a ComfyUI server from a command-line prompt.

The script patches a workflow template with the given system prompt and
image parameters, queues it on a ComfyUI server, waits on the WebSocket
until execution of the queued prompt finishes, then downloads the produced
images and saves them to a local directory.

Usage: python test.py <system_prompt> <width> <height> <batch_size>
"""

import itertools
import json
import os
import random
import sys
import time
import urllib.parse
import urllib.request
import uuid
from datetime import datetime

from websocket import (
    WebSocketConnectionClosedException,
    WebSocketTimeoutException,
    create_connection,
)

# Configuration for the command-line entry point.
WORKING_DIR = 'output'
COMFYUI_ENDPOINT = '127.0.0.1:8188'
DEFAULT_WORKFLOW = './workflows/flux_work.json'
TEMP_WORKFLOW_DIR = './workflows/temp'


def preprocess_workflow(system_prompt, width, height, batch_size,
                        input_json='flux_work.json',
                        output_json='processed_workflow.json'):
    """Patch the workflow template with the prompt and image parameters.

    Args:
        system_prompt: Text written to input ``system_prompt`` of node '31'.
        width: Written (as a string) to input ``width`` of node '27'.
        height: Written (as a string) to input ``height`` of node '27'.
        batch_size: Written (as a string) to input ``batch_size`` of node '27'.
        input_json: Path of the workflow template to read (UTF-8 JSON).
        output_json: Path the patched workflow is written to.

    Returns:
        ``output_json``.

    On any error (missing file, invalid JSON, missing node or key) the error
    is printed and the process exits with status 1.
    """
    try:
        with open(input_json, 'r', encoding='utf-8') as f:
            workflow = json.load(f)

        workflow['31']['inputs']['system_prompt'] = system_prompt

        # Values are stored as strings, as this script has always done.
        # NOTE(review): confirm node '27' accepts string values for these.
        latent_inputs = workflow['27']['inputs']
        latent_inputs['width'] = str(width)
        latent_inputs['height'] = str(height)
        latent_inputs['batch_size'] = str(batch_size)

        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(workflow, f, indent=2)

        print(f"工作流已更新并保存到: {output_json}")
        return output_json
    except Exception as e:
        print(f"预处理工作流出错: {str(e)}")
        sys.exit(1)


def queue_prompt(prompt, server_address, client_id):
    """POST ``prompt`` to the server's ``/prompt`` endpoint.

    Returns:
        The decoded JSON response. Callers read its ``'prompt_id'`` key.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(
        f"http://{server_address}/prompt",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    # Close the response deterministically, like the other HTTP helpers.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())


def get_image(filename, subfolder, folder_type, server_address):
    """Download one generated image from ``/view`` and return its raw bytes."""
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type})
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as response:
        return response.read()


def get_history(prompt_id, server_address):
    """Fetch ``/history/<prompt_id>`` and return the decoded JSON."""
    with urllib.request.urlopen(f"http://{server_address}/history/{prompt_id}") as response:
        return json.loads(response.read())


def _wait_for_completion(ws, prompt_id, timeout):
    """Read WebSocket messages until ``prompt_id`` finishes or ``timeout`` elapses.

    Returns:
        True when the completion message was seen; False on timeout or on a
        receive/parse error (the error is printed).
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            print(f"超时:等待执行超过{timeout}秒")
            return False
        # Bound each recv() by the time left so a silent connection cannot
        # hold us past the overall deadline (the socket's own timeout may be
        # longer than `timeout`).
        ws.settimeout(remaining)
        try:
            out = ws.recv()
            # Only text frames are parsed as JSON status messages; any other
            # frame type is skipped.
            if not isinstance(out, str):
                continue
            message = json.loads(out)
            if message.get('type') != 'executing':
                continue
            data = message.get('data') or {}
            # An 'executing' message whose node is None, for our prompt id,
            # is the completion signal.
            if ('node' in data and data['node'] is None
                    and data.get('prompt_id') == prompt_id):
                print('执行完成')
                return True
        except WebSocketTimeoutException:
            print(f"超时:等待执行超过{timeout}秒")
            return False
        except Exception as e:
            print(f"接收消息出错: {str(e)}")
            return False


def get_images(ws, prompt, server_address, client_id, timeout=600):
    """Queue ``prompt``, wait for it to finish, and download its images.

    Args:
        ws: WebSocket connected to the server for ``client_id``.
        prompt: Workflow dict to queue.
        server_address: ``host:port`` of the ComfyUI server.
        client_id: Client id the WebSocket was opened with.
        timeout: Maximum seconds to wait for completion.

    Returns:
        A dict mapping node id to a list of dicts. Each dict has the keys
        ``'data'`` (image bytes), ``'filename'``, ``'subfolder'`` and
        ``'type'``. The dict is empty when no history exists for the prompt.
        The history is queried even after a timeout, so partial results may
        be returned.
    """
    prompt_id = queue_prompt(prompt, server_address, client_id)['prompt_id']
    print(f'提示ID: {prompt_id}')

    _wait_for_completion(ws, prompt_id, timeout)

    history = get_history(prompt_id, server_address).get(prompt_id, {})
    if not history:
        print("未找到该提示的历史记录")
        return {}

    output_images = {}
    for node_id, node_output in history.get('outputs', {}).items():
        if 'images' not in node_output:
            continue
        images_output = []
        for image in node_output['images']:
            try:
                image_data = get_image(image['filename'], image['subfolder'],
                                       image['type'], server_address)
                images_output.append({
                    'data': image_data,
                    'filename': image['filename'],
                    'subfolder': image['subfolder'],
                    'type': image['type'],
                })
            except Exception as e:
                # Best effort: one failed download should not drop the rest.
                print(f"获取图像错误: {str(e)}")
        output_images[node_id] = images_output

    print(f'获取到 {len(output_images)} 组图像输出')
    return output_images


def _save_images(images, seed, output_dir):
    """Write every downloaded image into ``output_dir``; return the saved paths.

    Files are named ``<seed>_<YYYYmmddHHMMSS>_<index>.png``. The index runs
    across all output nodes, so names stay unique even when several nodes
    return images within the same second.
    """
    saved_files = []
    if not images:
        return saved_files
    all_images = itertools.chain.from_iterable(images.values())
    for index, img in enumerate(all_images):
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        file_path = os.path.join(output_dir, f"{seed}_{timestamp}_{index}.png")
        try:
            with open(file_path, "wb") as f:
                f.write(img['data'])
            saved_files.append(file_path)
            print(f'已保存: {file_path}')
        except Exception as e:
            print(f"保存图像错误: {str(e)}")
    return saved_files


def generate_images(workflow_file, server_address, output_dir, client_id):
    """Run ``workflow_file`` on the server with a random seed and save results.

    The seed is written to input ``noise_seed`` of node '25' and is used as
    the filename prefix.

    Returns:
        The list of saved file paths, or ``[]`` if any step fails (the error
        is printed).
    """
    try:
        with open(workflow_file, 'r', encoding='utf-8') as f:
            workflow_data = json.load(f)

        seed = random.randint(1, 10**8)
        print(f'使用种子: {seed}')
        workflow_data['25']['inputs']['noise_seed'] = seed

        ws_url = f"ws://{server_address}/ws?clientId={client_id}"
        ws = create_connection(ws_url, timeout=600)
        try:
            images = get_images(ws, workflow_data, server_address, client_id)
        finally:
            # Always release the connection, even if get_images raises.
            ws.close()

        return _save_images(images, seed, output_dir)
    except Exception as e:
        print(f"生成图像出错: {str(e)}")
        return []


def _usage_and_exit():
    """Print command-line usage and exit with status 1."""
    print("用法: python test.py <system_prompt> <width> <height> <batch_size>")
    print("示例: python test.py \"南开大学图书馆,大雨天\" 2048 1024 1")
    sys.exit(1)


def _parse_args(argv):
    """Return ``(system_prompt, width, height, batch_size)`` parsed from ``argv``.

    Exits with usage when the argument count is wrong or any size is not a
    positive integer.
    """
    if len(argv) != 5:
        _usage_and_exit()
    try:
        width, height, batch_size = (int(value) for value in argv[2:5])
    except ValueError:
        print("宽度、高度和批次大小必须是整数")
        _usage_and_exit()
    if min(width, height, batch_size) <= 0:
        print("宽度、高度和批次大小必须是正整数")
        _usage_and_exit()
    return argv[1], width, height, batch_size


def main(argv=None):
    """Command-line entry point: preprocess the workflow, generate, and report."""
    argv = sys.argv if argv is None else argv
    system_prompt, width, height, batch_size = _parse_args(argv)

    # Each run writes its patched workflow to a uniquely named scratch file.
    os.makedirs(TEMP_WORKFLOW_DIR, exist_ok=True)
    processed_workflow = os.path.join(
        TEMP_WORKFLOW_DIR, f"processed_workflow_{uuid.uuid4().hex}.json")

    # 1. Preprocess the workflow (exits the process on failure).
    workflow_file = preprocess_workflow(
        system_prompt=system_prompt,
        width=width,
        height=height,
        batch_size=batch_size,
        input_json=DEFAULT_WORKFLOW,
        output_json=processed_workflow,
    )

    try:
        # 2. Prepare the output directory.
        os.makedirs(WORKING_DIR, exist_ok=True)

        client_id = str(uuid.uuid4())

        print(f"系统提示: {system_prompt}")
        print(f"图像尺寸: {width}x{height}")
        print(f"批次大小: {batch_size}")
        print(f"工作流文件: {workflow_file}")
        print(f"客户端ID: {client_id}")
        print(f"开始使用ComfyUI生成图像: {COMFYUI_ENDPOINT}")

        start_time = time.time()

        print("\n===== 开始生成图像 =====")
        saved_files = generate_images(
            workflow_file=workflow_file,
            server_address=COMFYUI_ENDPOINT,
            output_dir=WORKING_DIR,
            client_id=client_id,
        )

        elapsed = time.time() - start_time
        print(f"\n处理完成,耗时 {elapsed:.2f} 秒")
        print(f"共生成 {len(saved_files)} 张图像")
        print(f"保存位置: {WORKING_DIR}")
    finally:
        # The patched workflow is only needed for this run; do not let the
        # temp directory accumulate one file per invocation.
        try:
            os.remove(workflow_file)
        except OSError:
            pass


if __name__ == "__main__":
    main()