'''
* @file GenerateImages.py
* @brief Generate images from text descriptions.
*
* @author WuYingwen
* @GitHub @wuyingwen10
* @Contact 2211537@mail.nankai.edu.cn
* @date 2025-06-08
* @version v1.0.2
*
* @details
* Core functionality:
* - Generate images from user-provided text prompts;
* - Generate optimized prompt statements based on user input;
* - Provides a clean interface for generating images from text prompts.
*
* @note
* - For local server connections: update COMFYUI_ENDPOINT with the target address + port;
* - Timestamps prevent filename conflicts during image generation; perform simple string matching if necessary;
* - Filename format: "<file_prefix>_<timestamp>.png" (e.g., mountains_202507011210902.png);
* - Contact the author for technical support.
*
* @interface
* generate_images_interface(
*     user_topic: str,
*     width: int,
*     height: int,
*     batch_size: int,
*     file_prefix: str
* ) -> tuple
*
* @copyright
* (c) 2025 Nankai University
'''

import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import uuid
|
|
import random
|
|
from datetime import datetime
|
|
from websocket import create_connection, WebSocketTimeoutException, WebSocketConnectionClosedException
|
|
import urllib.request
|
|
import urllib.parse
|
|
from DeepSeekPromptGenerator import generate_prompt, save_deepseek_output
|
|
|
|
|
|
def generate_images_interface(user_topic, width, height, batch_size, file_prefix):
    """Generate images for a user topic via a local ComfyUI server.

    Runs best-effort prompt optimization (DeepSeek), injects settings into a
    workflow template, submits it to ComfyUI, and saves the resulting PNGs.

    Args:
        user_topic: Free-form text description of the desired image.
        width: Image width in pixels.
        height: Image height in pixels.
        batch_size: Number of images to generate in one run.
        file_prefix: Prefix for saved image filenames.

    Returns:
        Tuple of (deepseek_output_path, image_files) where deepseek_output_path
        is "" if prompt optimization failed, and image_files is a list of
        saved image file paths (empty on generation failure).
    """
    deepseek_output_path = ""
    try:
        # Best-effort: a failure here only costs us the saved optimized prompt.
        system_prompt = generate_prompt(user_topic)
        deepseek_output_path = save_deepseek_output(system_prompt, file_prefix)
    except Exception as e:
        print(f"Prompt optimization failed: {str(e)}")

    output_dir = 'output'
    comfyui_server = '127.0.0.1:8188'
    default_workflow = './workflows/flux_work.json'
    temp_dir = './workflows/temp'

    os.makedirs(temp_dir, exist_ok=True)
    # Unique per-call filename so concurrent calls never clobber each other.
    processed_workflow = os.path.join(temp_dir, f"processed_workflow_{uuid.uuid4().hex}.json")

    # NOTE(review): the raw user_topic — not the DeepSeek-optimized
    # system_prompt generated above — is injected into the workflow here;
    # confirm this is intentional.
    workflow_file = preprocess_workflow(
        system_prompt=user_topic,
        width=width,
        height=height,
        batch_size=batch_size,
        input_json=default_workflow,
        output_json=processed_workflow
    )

    os.makedirs(output_dir, exist_ok=True)

    client_id = str(uuid.uuid4())
    try:
        image_files = generate_images(
            workflow_file=workflow_file,
            server_address=comfyui_server,
            output_dir=output_dir,
            client_id=client_id,
            file_prefix=file_prefix
        )
    finally:
        # Fix: the per-call temp workflow file was never removed, leaking one
        # JSON file into ./workflows/temp on every invocation.
        try:
            os.remove(processed_workflow)
        except OSError:
            pass

    return (deepseek_output_path, image_files)
|
def generate_images_info(user_input_analysis, system_prompt):
    """Generate images from a pre-built prompt and return per-image metadata.

    Args:
        user_input_analysis: Dict with optional keys 'width' (default 1024),
            'height' (default 768), 'batch_size' (default 1) and
            'file_prefix' (default 'image').
        system_prompt: Final prompt text injected into the workflow and echoed
            back as each image's description.

    Returns:
        List of dicts, one per saved image, with keys 'image_name',
        'image_type', 'image_description' and 'image_dimensions'.
    """
    width = user_input_analysis.get('width', 1024)
    height = user_input_analysis.get('height', 768)
    batch_size = user_input_analysis.get('batch_size', 1)
    file_prefix = user_input_analysis.get('file_prefix', 'image')

    OUTPUT_DIR = 'output'
    COMFYUI_SERVER = '127.0.0.1:8188'
    DEFAULT_WORKFLOW = './workflows/flux_work.json'
    TEMP_DIR = './workflows/temp'

    os.makedirs(TEMP_DIR, exist_ok=True)
    # Unique per-call filename so concurrent calls never clobber each other.
    PROCESSED_WORKFLOW = os.path.join(TEMP_DIR, f"processed_workflow_{uuid.uuid4().hex}.json")

    workflow_file = preprocess_workflow(
        system_prompt=system_prompt,
        width=width,
        height=height,
        batch_size=batch_size,
        input_json=DEFAULT_WORKFLOW,
        output_json=PROCESSED_WORKFLOW
    )

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    client_id = str(uuid.uuid4())

    try:
        image_files = generate_images(
            workflow_file=workflow_file,
            server_address=COMFYUI_SERVER,
            output_dir=OUTPUT_DIR,
            client_id=client_id,
            file_prefix=file_prefix
        )
    finally:
        # Fix: the per-call temp workflow file was never removed, leaking one
        # JSON file into ./workflows/temp on every invocation.
        try:
            os.remove(PROCESSED_WORKFLOW)
        except OSError:
            pass

    return [
        {
            "image_name": os.path.splitext(os.path.basename(file_path))[0],
            "image_type": "png",
            "image_description": system_prompt,
            "image_dimensions": f"{width}x{height}"
        }
        for file_path in image_files
    ]
|
def preprocess_workflow(system_prompt, width, height, batch_size, input_json='flux_work.json', output_json='processed_workflow.json'):
    """Inject prompt and image settings into a ComfyUI workflow template.

    Loads the workflow JSON from *input_json*, writes *system_prompt* into
    node '31' and the image settings (as strings) into node '27', then saves
    the result to *output_json*.

    Args:
        system_prompt: Prompt text for the workflow's text node ('31').
        width: Image width in pixels.
        height: Image height in pixels.
        batch_size: Number of images per run.
        input_json: Path of the workflow template to read.
        output_json: Path where the processed workflow is written.

    Returns:
        The *output_json* path on success; exits the process with status 1
        on any error.
    """
    try:
        with open(input_json, 'r', encoding='utf-8') as fh:
            graph = json.load(fh)

        # Node '31' carries the prompt text; node '27' the latent-image settings.
        graph['31']['inputs']['system_prompt'] = system_prompt

        latent_inputs = graph['27']['inputs']
        latent_inputs['width'] = str(width)
        latent_inputs['height'] = str(height)
        latent_inputs['batch_size'] = str(batch_size)

        with open(output_json, 'w', encoding='utf-8') as fh:
            json.dump(graph, fh, indent=2, ensure_ascii=False)

        print(f"Workflow updated and saved to: {output_json}")
        return output_json

    except Exception as e:
        print(f"Error processing workflow: {str(e)}")
        sys.exit(1)
|
def queue_prompt(prompt, server_address, client_id):
    """Submit a workflow graph to the ComfyUI /prompt endpoint.

    Args:
        prompt: Workflow graph (dict) to execute.
        server_address: Host:port of the ComfyUI server.
        client_id: Client identifier tying the job to a websocket session.

    Returns:
        Parsed JSON response from the server (includes 'prompt_id').
    """
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode('utf-8')
    request = urllib.request.Request(f"http://{server_address}/prompt", data=data)
    # Fix: close the HTTP response when done (the sibling helpers get_image /
    # get_history already use `with`; the original left it unclosed).
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
|
def get_image(filename, subfolder, folder_type, server_address):
    """Download one rendered image's raw bytes via the ComfyUI /view endpoint.

    Args:
        filename: Image filename as reported by the server.
        subfolder: Server-side subfolder the image lives in.
        folder_type: ComfyUI folder type (e.g. output/temp) for the image.
        server_address: Host:port of the ComfyUI server.

    Returns:
        The image file content as bytes.
    """
    query = urllib.parse.urlencode({
        "filename": filename,
        "subfolder": subfolder,
        "type": folder_type,
    })
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as resp:
        return resp.read()
|
def get_history(prompt_id, server_address):
    """Fetch the execution history record for *prompt_id* from ComfyUI.

    Args:
        prompt_id: Identifier returned when the prompt was queued.
        server_address: Host:port of the ComfyUI server.

    Returns:
        Parsed JSON history payload for the prompt.
    """
    url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(url) as resp:
        return json.loads(resp.read())
|
def retrieve_images(ws, prompt, server_address, client_id, timeout=600):
    """Queue a workflow, wait for completion over the websocket, and fetch its images.

    Args:
        ws: An open websocket connection to the ComfyUI server.
        prompt: Workflow graph (dict) to execute.
        server_address: Host:port of the ComfyUI server.
        client_id: Client identifier matching the websocket session.
        timeout: Maximum seconds to wait for execution to finish.

    Returns:
        Dict mapping output node id -> list of image dicts, each with keys
        'data' (bytes), 'filename', 'subfolder' and 'type'. Returns {} when
        no history is found for the prompt.
    """
    prompt_id = queue_prompt(prompt, server_address, client_id)['prompt_id']
    print(f'Prompt ID: {prompt_id}')
    images_data = {}

    start_time = time.time()
    while True:
        # NOTE(review): this deadline is only checked between recv() calls, so a
        # silent server can block in ws.recv() up to the socket's own timeout.
        if time.time() - start_time > timeout:
            print(f"Timeout: Execution took longer than {timeout} seconds")
            break

        try:
            data = ws.recv()
            # Text frames carry JSON status messages; binary frames (e.g.
            # previews) are ignored here.
            if isinstance(data, str):
                message = json.loads(data)
                if message['type'] == 'executing':
                    content = message['data']
                    # node == None for our prompt_id signals the run is done.
                    if content['node'] is None and content['prompt_id'] == prompt_id:
                        print('Execution completed')
                        break
        except Exception as e:
            # Any websocket/JSON error ends the wait; results are still fetched
            # from history below (best-effort).
            print(f"Error receiving data: {str(e)}")
            break

    history = get_history(prompt_id, server_address).get(prompt_id, {})
    if not history:
        print("No history found for this prompt")
        return {}

    # Download every image each output node produced.
    for node_id, node_data in history['outputs'].items():
        if 'images' in node_data:
            image_collection = []
            for image in node_data['images']:
                try:
                    img_data = get_image(image['filename'], image['subfolder'], image['type'], server_address)
                    image_collection.append({
                        'data': img_data,
                        'filename': image['filename'],
                        'subfolder': image['subfolder'],
                        'type': image['type']
                    })
                except Exception as e:
                    # A failed download skips that image but keeps the rest.
                    print(f"Error retrieving image: {str(e)}")
            images_data[node_id] = image_collection

    print(f'Retrieved {len(images_data)} image outputs')
    return images_data
|
def generate_images(workflow_file, server_address, output_dir, client_id, file_prefix="image"):
    """Execute a workflow file against ComfyUI and save the resulting PNGs.

    Loads the processed workflow, randomizes the sampler seed, runs it over a
    websocket session, and writes every returned image to *output_dir* as
    "<file_prefix>_<timestamp>.png".

    Args:
        workflow_file: Path to the processed workflow JSON.
        server_address: Host:port of the ComfyUI server.
        output_dir: Directory to save images into (must exist).
        client_id: Client identifier for the websocket session.
        file_prefix: Prefix for saved image filenames.

    Returns:
        List of saved image file paths; [] if generation failed.
    """
    try:
        with open(workflow_file, 'r', encoding='utf-8') as f:
            workflow = json.load(f)

        # Randomize the sampler seed (node '25') so repeated runs differ.
        seed = random.randint(1, 10**8)
        print(f'Using seed: {seed}')
        workflow['25']['inputs']['noise_seed'] = seed

        ws_url = f"ws://{server_address}/ws?clientId={client_id}"
        ws = create_connection(ws_url, timeout=600)
        try:
            images = retrieve_images(ws, workflow, server_address, client_id)
        finally:
            # Fix: close the websocket even if retrieval raises; the original
            # leaked the connection on any exception out of retrieve_images.
            ws.close()

        saved_files = []
        if images:
            for node_id, img_list in images.items():
                for i, img in enumerate(img_list):
                    # Millisecond-resolution timestamp keeps names unique.
                    # NOTE(review): two images saved within the same millisecond
                    # would collide — confirm acceptable for large batches.
                    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]

                    filename = f"{file_prefix}_{timestamp}.png"
                    file_path = os.path.join(output_dir, filename)

                    try:
                        with open(file_path, "wb") as f:
                            f.write(img['data'])
                        saved_files.append(file_path)
                        print(f'Saved: {file_path}')
                    except Exception as e:
                        # A failed save skips that image but keeps the rest.
                        print(f"Error saving image: {str(e)}")

        return saved_files

    except Exception as e:
        print(f"Image generation failed: {str(e)}")
        return []