diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index c56283c2..1735fd83 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -17,7 +17,7 @@ jobs: path: "ComfyUI" - uses: actions/setup-python@v4 with: - python-version: '3.9' + python-version: '3.10' - name: Install requirements run: | python -m pip install --upgrade pip diff --git a/.github/workflows/update-api-stubs.yml b/.github/workflows/update-api-stubs.yml index 2ae99b67..c99ec9fc 100644 --- a/.github/workflows/update-api-stubs.yml +++ b/.github/workflows/update-api-stubs.yml @@ -22,10 +22,19 @@ jobs: run: | python -m pip install --upgrade pip pip install 'datamodel-code-generator[http]' + npm install @redocly/cli + + - name: Download OpenAPI spec + run: | + curl -o openapi.yaml https://api.comfy.org/openapi + + - name: Filter OpenAPI spec with Redocly + run: | + npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - name: Generate API models run: | - datamodel-codegen --use-subclass-enum --url https://api.comfy.org/openapi --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel + datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel - name: Check for changes id: git-check @@ -44,4 +53,4 @@ jobs: Generated automatically by the a Github workflow. 
branch: update-api-stubs delete-branch: true - base: main + base: master diff --git a/.gitignore b/.gitignore index 61881b8a..4e8cea71 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,6 @@ venv/ *.log web_custom_versions/ .DS_Store +openapi.yaml +filtered-openapi.yaml +uv.lock diff --git a/comfy/cli_args.py b/comfy/cli_args.py index f89a7aab..ef5ab627 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -192,6 +192,13 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.") +parser.add_argument( + "--comfy-api-base", + type=str, + default="https://api.comfy.org", + help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", +) + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md new file mode 100644 index 00000000..e2633a76 --- /dev/null +++ b/comfy_api_nodes/README.md @@ -0,0 +1,41 @@ +# ComfyUI API Nodes + +## Introduction + +Below is a collection of nodes that work by calling external APIs. More information is available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes). + +## Development + +While developing, you should be testing against the Staging environment. To test against staging: + +**Install ComfyUI_frontend** + +Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication. + +> **Hint:** If you use the `--front-end-version` argument for ComfyUI, it will use production authentication. + +```bash +python main.py --comfy-api-base https://stagingapi.comfy.org +``` + +API stubs are generated through automatic codegen tools from OpenAPI definitions. 
Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to keep only the paths relevant for API nodes. + +### Redocly Instructions + +**Tip** +When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. + +Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. + +```bash +# Download the OpenAPI file from the staging server. +curl -o openapi.yaml https://stagingapi.comfy.org/openapi + +# Filter out unneeded API definitions. +npm install -g @redocly/cli +redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components + +# Generate the pydantic datamodels for validation. +datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel + +``` diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py new file mode 100644 index 00000000..bd3b8908 --- /dev/null +++ b/comfy_api_nodes/apinode_utils.py @@ -0,0 +1,575 @@ +import io +import logging +from typing import Optional +from comfy.utils import common_upscale +from comfy_api.input_impl import VideoFromFile +from comfy_api.util import VideoContainer, VideoCodec +from comfy_api.input.video_types import VideoInput +from comfy_api.input.basic_types import AudioInput +from comfy_api_nodes.apis.client import ( + ApiClient, + ApiEndpoint, + HttpMethod, + SynchronousOperation, + UploadRequest, + UploadResponse, +) + + +import numpy as np +from PIL import Image +import requests +import torch +import math +import base64 +import uuid +from io import BytesIO +import av + + +def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: + """Downloads a video from a URL and returns a 
`VIDEO` output. + + Args: + video_url: The URL of the video to download. + + Returns: + A Comfy node `VIDEO` output. + """ + video_io = download_url_to_bytesio(video_url, timeout) + if video_io is None: + error_msg = f"Failed to download video from {video_url}" + logging.error(error_msg) + raise ValueError(error_msg) + return VideoFromFile(video_io) + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: + """Validates and casts a response to a torch.Tensor. + + Args: + response: The response to validate and cast. + timeout: Request timeout in seconds. Defaults to None (no timeout). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + ValueError: If the response is not valid. 
def validate_aspect_ratio(
    aspect_ratio: str,
    minimum_ratio: float,
    maximum_ratio: float,
    minimum_ratio_str: str,
    maximum_ratio_str: str,
) -> str:
    """Validate an ``X:Y`` aspect ratio string against inclusive bounds.

    Args:
        aspect_ratio: The aspect ratio string to validate, e.g. ``"16:9"``.
        minimum_ratio: The smallest allowed value of ``X / Y``.
        maximum_ratio: The largest allowed value of ``X / Y``.
        minimum_ratio_str: Human-readable form of ``minimum_ratio`` used in
            the error message (e.g. ``"1:4"``).
        maximum_ratio_str: Human-readable form of ``maximum_ratio`` used in
            the error message (e.g. ``"4:1"``).

    Returns:
        The original ``aspect_ratio`` string, unchanged, once it has been
        validated.

    Raises:
        TypeError: If the string is not two integers separated by ``:``, if
            the second integer is zero, or if the resulting ratio lies
            outside ``[minimum_ratio, maximum_ratio]``.
    """
    # get ratio values
    numbers = aspect_ratio.split(":")
    if len(numbers) != 2:
        raise TypeError(
            f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
        )
    try:
        numerator = int(numbers[0])
        denominator = int(numbers[1])
    except ValueError as exc:
        raise TypeError(
            f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
        ) from exc
    # Report a zero denominator as a format error like the other bad inputs,
    # rather than letting a bare ZeroDivisionError escape.
    if denominator == 0:
        raise TypeError(
            f"Aspect ratio must not have a zero denominator, but was {aspect_ratio}."
        )
    calculated_ratio = numerator / denominator
    # A ratio within floating-point tolerance of either bound is accepted as
    # being on the bound; only ratios close to neither bound are range-checked.
    close_to_bound = math.isclose(calculated_ratio, minimum_ratio) or math.isclose(
        calculated_ratio, maximum_ratio
    )
    if not close_to_bound:
        if calculated_ratio < minimum_ratio:
            raise TypeError(
                f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
            )
        if calculated_ratio > maximum_ratio:
            raise TypeError(
                f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
            )
    return aspect_ratio
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
    """Convert a single image tensor to a PIL image, downscaling if needed.

    Args:
        image: Image tensor. When it has more than three dimensions only the
            first item along dimension 0 is used.
            NOTE(review): presumably [H, W, C] with values in [0, 1] - confirm
            against callers.
        total_pixels: Pixel budget forwarded to ``downscale_image_tensor``.

    Returns:
        A PIL image built from the tensor scaled by 255 and cast to uint8.
    """
    if len(image.shape) > 3:
        image = image[0]
    # TODO: remove alpha if not allowed and present
    input_tensor = image.cpu()
    # Drop only the batch dimension added for downscaling. A bare squeeze()
    # would also remove a height or width of 1 and corrupt the image shape.
    input_tensor = downscale_image_tensor(
        input_tensor.unsqueeze(0), total_pixels=total_pixels
    )[0]
    # Keep the previous handling of single-channel input: collapse a trailing
    # channel dimension of size 1 so PIL receives a 2-D array.
    if input_tensor.shape[-1] == 1:
        input_tensor = input_tensor[..., 0]
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    image_np = (input_tensor.numpy() * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(image_np)
+ total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = _tensor_to_pil(image, total_pixels=total_pixels) + img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = ( + f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + ) + return img_binary + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. + """ + pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). 
+ """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def upload_file_to_comfyapi( + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: str, + auth_token: Optional[str] = None, +) -> str: + """ + Uploads a single file to ComfyUI API and returns its download URL. + + Args: + file_bytes_io: BytesIO object containing the file data. + filename: The filename of the file. + upload_mime_type: MIME type of the file. + auth_token: Optional authentication token. + + Returns: + The download URL for the uploaded file. + """ + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/customers/storage", + method=HttpMethod.POST, + request_model=UploadRequest, + response_model=UploadResponse, + ), + request=request_object, + auth_token=auth_token, + ) + + response: UploadResponse = operation.execute() + upload_response = ApiClient.upload_file( + response.upload_url, file_bytes_io, content_type=upload_mime_type + ) + upload_response.raise_for_status() + + return response.download_url + + +def upload_video_to_comfyapi( + video: VideoInput, + auth_token: Optional[str] = None, + container: VideoContainer = VideoContainer.MP4, + codec: VideoCodec = VideoCodec.H264, + max_duration: Optional[int] = None, +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + + Args: + video: VideoInput object (Comfy VIDEO type). + auth_token: Optional authentication token. + container: The video container format to use (default: MP4). + codec: The video codec to use (default: H264). + max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. + + Returns: + The download URL for the uploaded video file. 
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
    """
    Prepares audio waveform for av library by converting to a contiguous numpy array.

    Args:
        waveform: a tensor of shape (batch, channels, samples) derived from a
            Comfy `AUDIO` type. If the batch holds more than one item, only
            the first item is used.

    Returns:
        Contiguous float32 numpy array of shape (channels, samples).

    Raises:
        ValueError: If the tensor is not 3-dimensional or the batch is empty.
    """
    if waveform.ndim != 3 or waveform.shape[0] < 1:
        raise ValueError(
            "Expected waveform tensor shape (batch, channels, samples) with batch >= 1"
        )

    # Indexing the batch dimension (rather than squeeze) keeps the channel
    # dimension intact for mono audio and selects the first item of a batch.
    first_item = waveform[0]

    # Prepare for av: move to CPU, make contiguous, convert to numpy array
    audio_data_np = first_item.cpu().contiguous().numpy()
    if audio_data_np.dtype != np.float32:
        audio_data_np = audio_data_np.astype(np.float32)

    return audio_data_np
def upload_images_to_comfyapi(
    image: torch.Tensor, max_images=8, auth_token=None, mime_type: Optional[str] = None
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
    To upload multiple images, stack them in the batch dimension first.

    Args:
        image: Input torch.Tensor image. A tensor with more than three
            dimensions is treated as a batch along dimension 0.
        max_images: Maximum number of batch items to upload. A value of 0 or
            less uploads the whole batch. Ignored for a non-batched image.
        auth_token: Optional authentication token.
        mime_type: Optional MIME type for the image.

    Returns:
        One download URL per uploaded image, in batch order. An empty batch
        yields an empty list.

    Raises:
        ValueError: If the storage upload of any image fails with an HTTP error.
    """
    # Decide up front which frames are uploaded, so the upload loop is a plain
    # bounded iteration and an empty batch uploads nothing.
    if len(image.shape) > 3:
        upload_count = image.shape[0]
        if max_images > 0:
            upload_count = min(upload_count, max_images)
        frames = [image[idx] for idx in range(upload_count)]
    else:
        frames = [image]

    download_urls: list[str] = []
    for curr_image in frames:
        # get BytesIO version of image
        img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
        # first, request upload/download urls from comfy API
        if not mime_type:
            request_object = UploadRequest(file_name=img_binary.name)
        else:
            request_object = UploadRequest(
                file_name=img_binary.name, content_type=mime_type
            )
        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/customers/storage",
                method=HttpMethod.POST,
                request_model=UploadRequest,
                response_model=UploadResponse,
            ),
            request=request_object,
            auth_token=auth_token,
        )
        response = operation.execute()

        upload_response = ApiClient.upload_file(
            response.upload_url, img_binary, content_type=mime_type
        )
        # verify success
        try:
            upload_response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ValueError(f"Could not upload one or more images: {e}") from e
        # add download_url to list
        download_urls.append(response.download_url)
    return download_urls
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
    """Validate a text field's length and non-emptiness.

    Args:
        string: The text to validate.
        strip_whitespace: If true, leading and trailing whitespace is removed
            before any check is applied.
        field_name: Name of the field, used in error messages.
        min_length: Optional minimum number of characters; a falsy value
            disables the check.
        max_length: Optional maximum number of characters; a falsy value
            disables the check.

    Raises:
        ValueError: If the text is shorter than ``min_length``, longer than
            ``max_length``, or empty after optional stripping.
    """
    if strip_whitespace:
        string = string.strip()
    if min_length and len(string) < min_length:
        raise ValueError(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
    if max_length and len(string) > max_length:
        raise ValueError(f"Field '{field_name}' cannot be longer than {max_length} characters; was {len(string)} characters long.")
    if not string:
        raise ValueError(f"Field '{field_name}' cannot be empty.")
+30,10 @@ class V2OpenAPIT2VReq(BaseModel): description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', examples=['normal'], ) - negative_prompt: Optional[constr(max_length=2048)] = Field( - None, description='Negative prompt\n' + negative_prompt: Optional[str] = Field( + None, description='Negative prompt\n', max_length=2048 ) - prompt: constr(max_length=2048) = Field(..., description='Prompt') + prompt: str = Field(..., description='Prompt', max_length=2048) quality: str = Field( ..., description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index e7ea9b33..aa1c4ce0 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1,127 +1,455 @@ # generated by datamodel-codegen: -# filename: https://api.comfy.org/openapi -# timestamp: 2025-04-23T15:56:33+00:00 +# filename: filtered-openapi.yaml +# timestamp: 2025-05-04T04:12:39+00:00 from __future__ import annotations from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Union +from uuid import UUID -from pydantic import AnyUrl, BaseModel, Field, confloat, conint - -class Customer(BaseModel): - createdAt: Optional[datetime] = Field( - None, description='The date and time the user was created' - ) - email: Optional[str] = Field(None, description='The email address for this user') - id: str = Field(..., description='The firebase UID of the user') - name: Optional[str] = Field(None, description='The name for this user') - updatedAt: Optional[datetime] = Field( - None, description='The date and time the user was last updated' - ) +from pydantic import AnyUrl, BaseModel, Field, RootModel, StrictBytes -class Error(BaseModel): - details: Optional[List[str]] = Field( +class PersonalAccessToken(BaseModel): + id: Optional[UUID] = Field(None, 
description='Unique identifier for the GitCommit') + name: Optional[str] = Field( None, - description='Optional detailed information about the error or hints for resolving it.', + description='Required. The name of the token. Can be a simple description.', ) - message: Optional[str] = Field( - None, description='A clear and concise description of the error.' + description: Optional[str] = Field( + None, + description="Optional. A more detailed description of the token's intended use.", ) + createdAt: Optional[datetime] = Field( + None, description='[Output Only]The date and time the token was created.' + ) + token: Optional[str] = Field( + None, + description='[Output Only]. The personal access token. Only returned during creation.', + ) + + +class GitCommitSummary(BaseModel): + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_name: Optional[str] = Field(None, description='The name of the commit') + branch_name: Optional[str] = Field( + None, description='The branch where the commit was made' + ) + author: Optional[str] = Field(None, description='The author of the commit') + timestamp: Optional[datetime] = Field( + None, description='The timestamp when the commit was made' + ) + status_summary: Optional[Dict[str, str]] = Field( + None, description='A map of operating system to status pairs' + ) + + +class User(BaseModel): + id: Optional[str] = Field(None, description='The unique id for this user.') + email: Optional[str] = Field(None, description='The email address for this user.') + name: Optional[str] = Field(None, description='The name for this user.') + isApproved: Optional[bool] = Field( + None, description='Indicates if the user is approved.' + ) + isAdmin: Optional[bool] = Field( + None, description='Indicates if the user has admin privileges.' 
+ ) + + +class PublisherUser(BaseModel): + id: Optional[str] = Field(None, description='The unique id for this user.') + email: Optional[str] = Field(None, description='The email address for this user.') + name: Optional[str] = Field(None, description='The name for this user.') class ErrorResponse(BaseModel): error: str message: str + +class StorageFile(BaseModel): + id: Optional[UUID] = Field( + None, description='Unique identifier for the storage file' + ) + file_path: Optional[str] = Field(None, description='Path to the file in storage') + public_url: Optional[str] = Field(None, description='Public URL') + + +class PublisherMember(BaseModel): + id: Optional[str] = Field( + None, description='The unique identifier for the publisher member.' + ) + user: Optional[PublisherUser] = Field( + None, description='The user associated with this publisher member.' + ) + role: Optional[str] = Field( + None, description='The role of the user in the publisher.' + ) + + +class ComfyNode(BaseModel): + comfy_node_name: Optional[str] = Field( + None, description='Unique identifier for the node' + ) + category: Optional[str] = Field( + None, + description='UI category where the node is listed, used for grouping nodes.', + ) + description: Optional[str] = Field( + None, description="Brief description of the node's functionality or purpose." + ) + input_types: Optional[str] = Field(None, description='Defines input parameters') + deprecated: Optional[bool] = Field( + None, + description='Indicates if the node is deprecated. Deprecated nodes are hidden in the UI.', + ) + experimental: Optional[bool] = Field( + None, + description='Indicates if the node is experimental, subject to changes or removal.', + ) + output_is_list: Optional[List[bool]] = Field( + None, description='Boolean values indicating if each output is a list.' + ) + return_names: Optional[str] = Field( + None, description='Names of the outputs for clarity in workflows.' 
+ ) + return_types: Optional[str] = Field( + None, description='Specifies the types of outputs produced by the node.' + ) + function: Optional[str] = Field( + None, description='Name of the entry-point function to execute the node.' + ) + + +class ComfyNodeCloudBuildInfo(BaseModel): + project_id: Optional[str] = None + project_number: Optional[str] = None + location: Optional[str] = None + build_id: Optional[str] = None + + +class Error(BaseModel): + message: Optional[str] = Field( + None, description='A clear and concise description of the error.' + ) + details: Optional[List[str]] = Field( + None, + description='Optional detailed information about the error or hints for resolving it.', + ) + + +class NodeVersionUpdateRequest(BaseModel): + changelog: Optional[str] = Field( + None, description='The changelog describing the version changes.' + ) + deprecated: Optional[bool] = Field( + None, description='Whether the version is deprecated.' + ) + + +class NodeStatus(str, Enum): + NodeStatusActive = 'NodeStatusActive' + NodeStatusDeleted = 'NodeStatusDeleted' + NodeStatusBanned = 'NodeStatusBanned' + + +class NodeVersionStatus(str, Enum): + NodeVersionStatusActive = 'NodeVersionStatusActive' + NodeVersionStatusDeleted = 'NodeVersionStatusDeleted' + NodeVersionStatusBanned = 'NodeVersionStatusBanned' + NodeVersionStatusPending = 'NodeVersionStatusPending' + NodeVersionStatusFlagged = 'NodeVersionStatusFlagged' + + +class PublisherStatus(str, Enum): + PublisherStatusActive = 'PublisherStatusActive' + PublisherStatusBanned = 'PublisherStatusBanned' + + +class WorkflowRunStatus(str, Enum): + WorkflowRunStatusStarted = 'WorkflowRunStatusStarted' + WorkflowRunStatusFailed = 'WorkflowRunStatusFailed' + WorkflowRunStatusCompleted = 'WorkflowRunStatusCompleted' + + +class MachineStats(BaseModel): + machine_name: Optional[str] = Field(None, description='Name of the machine.') + os_version: Optional[str] = Field( + None, description='The operating system version. eg. 
Ubuntu Linux 20.04' + ) + gpu_type: Optional[str] = Field( + None, description='The GPU type. eg. NVIDIA Tesla K80' + ) + cpu_capacity: Optional[str] = Field(None, description='Total CPU on the machine.') + initial_cpu: Optional[str] = Field( + None, description='Initial CPU available before the job starts.' + ) + memory_capacity: Optional[str] = Field( + None, description='Total memory on the machine.' + ) + initial_ram: Optional[str] = Field( + None, description='Initial RAM available before the job starts.' + ) + vram_time_series: Optional[Dict[str, Any]] = Field( + None, description='Time series of VRAM usage.' + ) + disk_capacity: Optional[str] = Field( + None, description='Total disk capacity on the machine.' + ) + initial_disk: Optional[str] = Field( + None, description='Initial disk available before the job starts.' + ) + pip_freeze: Optional[str] = Field(None, description='The pip freeze output') + + +class Customer(BaseModel): + id: str = Field(..., description='The firebase UID of the user') + email: Optional[str] = Field(None, description='The email address for this user') + name: Optional[str] = Field(None, description='The name for this user') + createdAt: Optional[datetime] = Field( + None, description='The date and time the user was created' + ) + updatedAt: Optional[datetime] = Field( + None, description='The date and time the user was last updated' + ) + + +class MagicPrompt(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class StyleType(str, Enum): + GENERAL = 'GENERAL' + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: 
Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + class ImageRequest(BaseModel): + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) aspect_ratio: Optional[str] = Field( None, description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", ) - color_palette: Optional[Dict[str, Any]] = Field( - None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' - ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") magic_prompt_option: Optional[str] = Field( None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." ) - model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") - negative_prompt: Optional[str] = Field( + seed: Optional[int] = Field( None, - description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', - ) - num_images: Optional[conint(ge=1, le=8)] = Field( - 1, description='Optional. Number of images to generate (1-8). Defaults to 1.' - ) - prompt: str = Field( - ..., description='Required. The prompt to use to generate the image.' - ) - resolution: Optional[str] = Field( - None, - description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", - ) - seed: Optional[conint(ge=0, le=2147483647)] = Field( - None, description='Optional. 
A number between 0 and 2147483647.' + description='Optional. A number between 0 and 2147483647.', + ge=0, + le=2147483647, ) style_type: Optional[str] = Field( None, description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", ) + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) class Datum(BaseModel): - is_image_safe: Optional[bool] = Field( - None, description='Indicates whether the image is considered safe.' - ) prompt: Optional[str] = Field( None, description='The prompt used to generate this image.' ) resolution: Optional[str] = Field( None, description="The resolution of the generated image (e.g., '1024x1024')." ) + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) seed: Optional[int] = Field( None, description='The seed value used for this generation.' 
) + url: Optional[str] = Field(None, description='URL to the generated image.') style_type: Optional[str] = Field( None, description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", ) - url: Optional[str] = Field(None, description='URL to the generated image.') -class Code(Enum): - int_1100 = 1100 - int_1101 = 1101 - int_1102 = 1102 - int_1103 = 1103 +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) -class Code1(Enum): - int_1000 = 1000 - int_1001 = 1001 - int_1002 = 1002 - int_1003 = 1003 - int_1004 = 1004 +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' -class AspectRatio(str, Enum): +class MagicPrompt1(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + image: Optional[StrictBytes] = None + prompt: str + image_weight: Optional[int] = Field(50, ge=1, le=100) + seed: Optional[int] = Field(None, ge=0, le=2147483647) + resolution: Optional[str] = None + aspect_ratio: Optional[str] = None + rendering_speed: Optional[RenderingSpeed1] = None + magic_prompt: Optional[MagicPrompt1] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + color_palette: Optional[Dict[str, Any]] = None + style_codes: Optional[List[str]] = None + style_type: Optional[StyleType1] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class Datum1(BaseModel): + prompt: Optional[str] = None + resolution: Optional[str] = None + is_image_safe: Optional[bool] = None + seed: Optional[int] = None + url: Optional[str] = None + style_type: Optional[str] = None + + +class 
IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class IdeogramV3ReframeRequest(BaseModel): + image: Optional[StrictBytes] = None + resolution: str + num_images: Optional[int] = Field(None, ge=1, le=8) + seed: Optional[int] = Field(None, ge=0, le=2147483647) + rendering_speed: Optional[RenderingSpeed1] = None + color_palette: Optional[Dict[str, Any]] = None + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + image: Optional[StrictBytes] = None + prompt: str + magic_prompt: Optional[MagicPrompt1] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + seed: Optional[int] = Field(None, ge=0, le=2147483647) + rendering_speed: Optional[RenderingSpeed1] = None + color_palette: Optional[Dict[str, Any]] = None + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class KlingTaskStatus(str, Enum): + submitted = 'submitted' + processing = 'processing' + succeed = 'succeed' + failed = 'failed' + + +class KlingVideoGenModelName(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_5 = 'kling-v1-5' + kling_v1_6 = 'kling-v1-6' + kling_v2_master = 'kling-v2-master' + + +class KlingVideoGenMode(str, Enum): + std = 'std' + pro = 'pro' + + +class KlingVideoGenAspectRatio(str, Enum): field_16_9 = '16:9' field_9_16 = '9:16' field_1_1 = '1:1' -class Config(BaseModel): - horizontal: Optional[confloat(ge=-10.0, le=10.0)] = None - pan: Optional[confloat(ge=-10.0, le=10.0)] = None - roll: Optional[confloat(ge=-10.0, le=10.0)] = None - tilt: Optional[confloat(ge=-10.0, le=10.0)] = None - vertical: Optional[confloat(ge=-10.0, le=10.0)] = None - zoom: Optional[confloat(ge=-10.0, le=10.0)] = None +class KlingVideoGenDuration(str, Enum): + field_5 = '5' + field_10 = '10' -class Type(str, Enum): +class KlingVideoGenCfgScale(RootModel[float]): + root: float = 
Field( + ..., + description="Flexibility in video generation. The higher the value, the lower the model's degree of flexibility, and the stronger the relevance to the user's prompt.", + ge=0.0, + le=1.0, + ) + + +class KlingCameraControlType(str, Enum): simple = 'simple' down_back = 'down_back' forward_up = 'forward_up' @@ -129,52 +457,99 @@ class Type(str, Enum): left_turn_forward = 'left_turn_forward' -class CameraControl(BaseModel): - config: Optional[Config] = None - type: Optional[Type] = Field(None, description='Predefined camera movements type') +class KlingCameraConfig(BaseModel): + horizontal: Optional[float] = Field( + None, + description="Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right.", + ge=-10.0, + le=10.0, + ) + vertical: Optional[float] = Field( + None, + description="Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward.", + ge=-10.0, + le=10.0, + ) + pan: Optional[float] = Field( + None, + description="Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", + ge=-10.0, + le=10.0, + ) + tilt: Optional[float] = Field( + None, + description="Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", + ge=-10.0, + le=10.0, + ) + roll: Optional[float] = Field( + None, + description="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + ge=-10.0, + le=10.0, + ) + zoom: Optional[float] = Field( + None, + description="Controls change in camera's focal length. 
Negative indicates narrower field of view, positive indicates wider field of view.", + ge=-10.0, + le=10.0, + ) -class Duration(str, Enum): - field_5 = 5 - field_10 = 10 - - -class Mode(str, Enum): - std = 'std' - pro = 'pro' - - -class TaskInfo(BaseModel): - external_task_id: Optional[str] = None - - -class Video(BaseModel): - duration: Optional[str] = Field(None, description='Total video duration') +class KlingVideoResult(BaseModel): id: Optional[str] = Field(None, description='Generated video ID') url: Optional[AnyUrl] = Field(None, description='URL for generated video') + duration: Optional[str] = Field(None, description='Total video duration') -class TaskResult(BaseModel): - videos: Optional[List[Video]] = None +class KlingAudioUploadType(str, Enum): + file = 'file' + url = 'url' -class TaskStatus(str, Enum): - submitted = 'submitted' - processing = 'processing' - succeed = 'succeed' - failed = 'failed' +class KlingLipSyncMode(str, Enum): + text2video = 'text2video' + audio2video = 'audio2video' -class Data(BaseModel): - created_at: Optional[int] = Field(None, description='Task creation time') - task_id: Optional[str] = Field(None, description='Task ID') - task_info: Optional[TaskInfo] = None - task_result: Optional[TaskResult] = None - task_status: Optional[TaskStatus] = None - updated_at: Optional[int] = Field(None, description='Task update time') +class KlingLipSyncVoiceLanguage(str, Enum): + zh = 'zh' + en = 'en' -class AspectRatio1(str, Enum): +class KlingDualCharacterEffectsScene(str, Enum): + hug = 'hug' + kiss = 'kiss' + heart_gesture = 'heart_gesture' + + +class KlingSingleImageEffectsScene(str, Enum): + bloombloom = 'bloombloom' + dizzydizzy = 'dizzydizzy' + fuzzyfuzzy = 'fuzzyfuzzy' + squish = 'squish' + expansion = 'expansion' + + +class KlingCharacterEffectModelName(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_5 = 'kling-v1-5' + kling_v1_6 = 'kling-v1-6' + + +class KlingSingleImageEffectModelName(str, Enum): + kling_v1_6 = 'kling-v1-6' + + 
+class KlingSingleImageEffectDuration(str, Enum): + field_5 = '5' + + +class KlingDualCharacterImages(RootModel[List[str]]): + root: List[str] = Field(..., max_length=2, min_length=2) + + +class KlingImageGenAspectRatio(str, Enum): field_16_9 = '16:9' field_9_16 = '9:16' field_1_1 = '1:1' @@ -185,63 +560,289 @@ class AspectRatio1(str, Enum): field_21_9 = '21:9' -class ImageReference(str, Enum): +class KlingImageGenImageReferenceType(str, Enum): subject = 'subject' face = 'face' -class Image(BaseModel): +class KlingImageGenModelName(str, Enum): + kling_v1 = 'kling-v1' + kling_v1_5 = 'kling-v1-5' + kling_v2 = 'kling-v2' + + +class KlingImageResult(BaseModel): index: Optional[int] = Field(None, description='Image Number (0-9)') url: Optional[AnyUrl] = Field(None, description='URL for generated image') -class TaskResult1(BaseModel): - images: Optional[List[Image]] = None +class KlingVirtualTryOnModelName(str, Enum): + kolors_virtual_try_on_v1 = 'kolors-virtual-try-on-v1' + kolors_virtual_try_on_v1_5 = 'kolors-virtual-try-on-v1-5' + + +class TaskInfo(BaseModel): + external_task_id: Optional[str] = None + + +class TaskResult(BaseModel): + videos: Optional[List[KlingVideoResult]] = None + + +class Data(BaseModel): + task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None + task_info: Optional[TaskInfo] = None + created_at: Optional[int] = Field(None, description='Task creation time') + updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult] = None + + +class KlingText2VideoResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data] = None + + +class Trajectory(BaseModel): + x: Optional[int] = Field( + None, + description='The horizontal coordinate of trajectory point. 
Based on bottom-left corner of image as origin (0,0).', + ) + y: Optional[int] = Field( + None, + description='The vertical coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).', + ) + + +class DynamicMask(BaseModel): + mask: Optional[AnyUrl] = Field( + None, + description='Dynamic Brush Application Area (Mask image created by users using the motion brush). The aspect ratio must match the input image.', + ) + trajectories: Optional[List[Trajectory]] = None class Data1(BaseModel): - created_at: Optional[int] = Field(None, description='Task creation time') task_id: Optional[str] = Field(None, description='Task ID') - task_result: Optional[TaskResult1] = None - task_status: Optional[TaskStatus] = None - task_status_msg: Optional[str] = Field(None, description='Task status information') + task_status: Optional[KlingTaskStatus] = None + task_info: Optional[TaskInfo] = None + created_at: Optional[int] = Field(None, description='Task creation time') updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult] = None -class AspectRatio2(str, Enum): - field_16_9 = '16:9' - field_9_16 = '9:16' - field_1_1 = '1:1' +class KlingImage2VideoResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data1] = None -class CameraControl1(BaseModel): - config: Optional[Config] = None - type: Optional[Type] = Field(None, description='Predefined camera movements type') - - -class ModelName2(str, Enum): - kling_v1 = 'kling-v1' - kling_v1_6 = 'kling-v1-6' - - -class TaskResult2(BaseModel): - videos: Optional[List[Video]] = None +class KlingVideoExtendRequest(BaseModel): + video_id: Optional[str] = Field( + None, + description='The ID of the video to be extended. 
Supports videos generated by text-to-video, image-to-video, and previous video extension operations. Cannot exceed 3 minutes total duration after extension.', + ) + prompt: Optional[str] = Field( + None, + description='Positive text prompt for guiding the video extension', + max_length=2500, + ) + negative_prompt: Optional[str] = Field( + None, + description='Negative text prompt for elements to avoid in the extended video', + max_length=2500, + ) + cfg_scale: Optional[KlingVideoGenCfgScale] = Field( + default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + ) + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address. Server will notify when the task status changes.', + ) class Data2(BaseModel): - created_at: Optional[int] = Field(None, description='Task creation time') task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None task_info: Optional[TaskInfo] = None - task_result: Optional[TaskResult2] = None - task_status: Optional[TaskStatus] = None + created_at: Optional[int] = Field(None, description='Task creation time') updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult] = None -class Code2(Enum): - int_1200 = 1200 - int_1201 = 1201 - int_1202 = 1202 - int_1203 = 1203 +class KlingVideoExtendResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data2] = None + + +class KlingLipSyncInputObject(BaseModel): + video_id: Optional[str] = Field( + None, + description='The ID of the video generated by Kling AI. Only supports 5-second and 10-second videos generated within the last 30 days.', + ) + video_url: Optional[str] = Field( + None, + description='Get link for uploaded video. 
Video files support .mp4/.mov, file size does not exceed 100MB, video length between 2-10s.', + ) + mode: KlingLipSyncMode + text: Optional[str] = Field( + None, + description='Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.', + ) + voice_id: Optional[str] = Field( + None, + description='Voice ID. Required when mode is text2video. The system offers a variety of voice options to choose from.', + ) + voice_language: Optional[KlingLipSyncVoiceLanguage] = 'en' + voice_speed: Optional[float] = Field( + 1, + description='Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.', + ge=0.8, + le=2.0, + ) + audio_type: Optional[KlingAudioUploadType] = None + audio_file: Optional[str] = Field( + None, + description='Local Path of Audio File. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB. Base64 code.', + ) + audio_url: Optional[str] = Field( + None, + description='Audio File Download URL. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB.', + ) + + +class KlingLipSyncRequest(BaseModel): + input: KlingLipSyncInputObject + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address. 
Server will notify when the task status changes.', + ) + + +class Data3(BaseModel): + task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None + task_info: Optional[TaskInfo] = None + created_at: Optional[int] = Field(None, description='Task creation time') + updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult] = None + + +class KlingLipSyncResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data3] = None + + +class KlingSingleImageEffectInput(BaseModel): + model_name: KlingSingleImageEffectModelName + image: str = Field( + ..., + description='Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1.', + ) + duration: KlingSingleImageEffectDuration + + +class KlingDualCharacterEffectInput(BaseModel): + model_name: Optional[KlingCharacterEffectModelName] = 'kling-v1' + mode: Optional[KlingVideoGenMode] = 'std' + images: KlingDualCharacterImages + duration: KlingVideoGenDuration + + +class Data4(BaseModel): + task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None + task_info: Optional[TaskInfo] = None + created_at: Optional[int] = Field(None, description='Task creation time') + updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult] = None + + +class KlingVideoEffectsResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data4] = None + + +class KlingImageGenerationsRequest(BaseModel): 
+ model_name: Optional[KlingImageGenModelName] = 'kling-v1' + prompt: str = Field(..., description='Positive text prompt', max_length=500) + negative_prompt: Optional[str] = Field( + None, description='Negative text prompt', max_length=200 + ) + image: Optional[str] = Field( + None, description='Reference Image - Base64 encoded string or image URL' + ) + image_reference: Optional[KlingImageGenImageReferenceType] = None + image_fidelity: Optional[float] = Field( + 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0 + ) + human_fidelity: Optional[float] = Field( + 0.45, description='Subject reference similarity', ge=0.0, le=1.0 + ) + n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9) + aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9' + callback_url: Optional[AnyUrl] = Field( + None, description='The callback notification address' + ) + + +class TaskResult5(BaseModel): + images: Optional[List[KlingImageResult]] = None + + +class Data5(BaseModel): + task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None + task_status_msg: Optional[str] = Field(None, description='Task status information') + created_at: Optional[int] = Field(None, description='Task creation time') + updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult5] = None + + +class KlingImageGenerationsResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data5] = None + + +class KlingVirtualTryOnRequest(BaseModel): + model_name: Optional[KlingVirtualTryOnModelName] = 'kolors-virtual-try-on-v1' + human_image: str = Field( + ..., description='Reference human image - Base64 encoded string or image URL' + ) + cloth_image: Optional[str] = Field( + None, + 
description='Reference clothing image - Base64 encoded string or image URL', + ) + callback_url: Optional[AnyUrl] = Field( + None, description='The callback notification address' + ) + + +class Data6(BaseModel): + task_id: Optional[str] = Field(None, description='Task ID') + task_status: Optional[KlingTaskStatus] = None + task_status_msg: Optional[str] = Field(None, description='Task status information') + created_at: Optional[int] = Field(None, description='Task creation time') + updated_at: Optional[int] = Field(None, description='Task update time') + task_result: Optional[TaskResult5] = None + + +class KlingVirtualTryOnResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + request_id: Optional[str] = Field(None, description='Request ID') + data: Optional[Data6] = None class ResourcePackType(str, Enum): @@ -257,87 +858,1140 @@ class Status(str, Enum): class ResourcePackSubscribeInfo(BaseModel): + resource_pack_name: Optional[str] = Field(None, description='Resource package name') + resource_pack_id: Optional[str] = Field(None, description='Resource package ID') + resource_pack_type: Optional[ResourcePackType] = Field( + None, + description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', + ) + total_quantity: Optional[float] = Field(None, description='Total quantity') + remaining_quantity: Optional[float] = Field( + None, description='Remaining quantity (updated with a 12-hour delay)' + ) + purchase_time: Optional[int] = Field( + None, description='Purchase time, Unix timestamp in ms' + ) effective_time: Optional[int] = Field( None, description='Effective time, Unix timestamp in ms' ) invalid_time: Optional[int] = Field( None, description='Expiration time, Unix timestamp in ms' ) - purchase_time: Optional[int] = Field( - None, description='Purchase time, Unix timestamp in ms' - ) - remaining_quantity: Optional[float] 
= Field( - None, description='Remaining quantity (updated with a 12-hour delay)' - ) - resource_pack_id: Optional[str] = Field(None, description='Resource package ID') - resource_pack_name: Optional[str] = Field(None, description='Resource package name') - resource_pack_type: Optional[ResourcePackType] = Field( - None, - description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)', - ) status: Optional[Status] = Field(None, description='Resource Package Status') - total_quantity: Optional[float] = Field(None, description='Total quantity') - -class Background(str, Enum): - transparent = 'transparent' - opaque = 'opaque' -class Moderation(str, Enum): - low = 'low' - auto = 'auto' +class Data7(BaseModel): + code: Optional[int] = Field(None, description='Error code; 0 indicates success') + msg: Optional[str] = Field(None, description='Error information') + resource_pack_subscribe_infos: Optional[List[ResourcePackSubscribeInfo]] = Field( + None, description='Resource package list' + ) + + +class KlingResourcePackageResponse(BaseModel): + code: Optional[int] = Field(None, description='Error code; 0 indicates success') + message: Optional[str] = Field(None, description='Error information') + request_id: Optional[str] = Field( + None, + description='Request ID, generated by the system, used to track requests and troubleshoot problems', + ) + data: Optional[Data7] = None + + +class Object(str, Enum): + event = 'event' + + +class Type(str, Enum): + payment_intent_succeeded = 'payment_intent.succeeded' + + +class StripeRequestInfo(BaseModel): + id: Optional[str] = None + idempotency_key: Optional[str] = None + + +class Object1(str, Enum): + payment_intent = 'payment_intent' + + +class StripeAmountDetails(BaseModel): + tip: Optional[Dict[str, Any]] = None + + +class Object2(str, Enum): + charge = 'charge' + + +class StripeAddress(BaseModel): + city: Optional[str] = None + country: Optional[str] = None + line1: Optional[str] = 
None + line2: Optional[str] = None + postal_code: Optional[str] = None + state: Optional[str] = None + + +class StripeOutcome(BaseModel): + advice_code: Optional[Any] = None + network_advice_code: Optional[Any] = None + network_decline_code: Optional[Any] = None + network_status: Optional[str] = None + reason: Optional[Any] = None + risk_level: Optional[str] = None + risk_score: Optional[int] = None + seller_message: Optional[str] = None + type: Optional[str] = None + + +class Checks(BaseModel): + address_line1_check: Optional[Any] = None + address_postal_code_check: Optional[Any] = None + cvc_check: Optional[str] = None + + +class ExtendedAuthorization(BaseModel): + status: Optional[str] = None + + +class IncrementalAuthorization(BaseModel): + status: Optional[str] = None + + +class Multicapture(BaseModel): + status: Optional[str] = None + + +class NetworkToken(BaseModel): + used: Optional[bool] = None + + +class Overcapture(BaseModel): + maximum_amount_capturable: Optional[int] = None + status: Optional[str] = None + + +class StripeCardDetails(BaseModel): + amount_authorized: Optional[int] = None + authorization_code: Optional[Any] = None + brand: Optional[str] = None + checks: Optional[Checks] = None + country: Optional[str] = None + exp_month: Optional[int] = None + exp_year: Optional[int] = None + extended_authorization: Optional[ExtendedAuthorization] = None + fingerprint: Optional[str] = None + funding: Optional[str] = None + incremental_authorization: Optional[IncrementalAuthorization] = None + installments: Optional[Any] = None + last4: Optional[str] = None + mandate: Optional[Any] = None + multicapture: Optional[Multicapture] = None + network: Optional[str] = None + network_token: Optional[NetworkToken] = None + network_transaction_id: Optional[str] = None + overcapture: Optional[Overcapture] = None + regulated_status: Optional[str] = None + three_d_secure: Optional[Any] = None + wallet: Optional[Any] = None + + +class StripeRefundList(BaseModel): + 
object: Optional[str] = None + data: Optional[List[Dict[str, Any]]] = None + has_more: Optional[bool] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class Card(BaseModel): + installments: Optional[Any] = None + mandate_options: Optional[Any] = None + network: Optional[Any] = None + request_three_d_secure: Optional[str] = None + + +class StripePaymentMethodOptions(BaseModel): + card: Optional[Card] = None + + +class StripeShipping(BaseModel): + address: Optional[StripeAddress] = None + carrier: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tracking_number: Optional[str] = None + + +class Model(str, Enum): + T2V_01_Director = 'T2V-01-Director' + I2V_01_Director = 'I2V-01-Director' + S2V_01 = 'S2V-01' + I2V_01 = 'I2V-01' + I2V_01_live = 'I2V-01-live' + T2V_01 = 'T2V-01' + + +class SubjectReferenceItem(BaseModel): + image: Optional[str] = Field( + None, description='URL or base64 encoding of the subject reference image.' + ) + mask: Optional[str] = Field( + None, + description='URL or base64 encoding of the mask for the subject reference image.', + ) + + +class MinimaxVideoGenerationRequest(BaseModel): + model: Model = Field( + ..., + description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', + ) + prompt: Optional[str] = Field( + None, + description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].', + max_length=2000, + ) + prompt_optimizer: Optional[bool] = Field( + True, + description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. 
Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) + subject_reference: Optional[List[SubjectReferenceItem]] = Field( + None, + description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', + ) + callback_url: Optional[str] = Field( + None, + description='Optional. URL to receive real-time status updates about the video generation task.', + ) + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class MinimaxVideoGenerationResponse(BaseModel): + task_id: str = Field( + ..., description='The task ID for the asynchronous video generation task.' + ) + base_resp: MinimaxBaseResponse + + +class File(BaseModel): + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + + +class MinimaxFileRetrieveResponse(BaseModel): + file: File + base_resp: MinimaxBaseResponse + + +class Status1(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + task_id: str = Field(..., description='The task ID being queried.') + status: Status1 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 
'Fail' (task failed).", + ) + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + base_resp: MinimaxBaseResponse class OutputFormat(str, Enum): - png = 'png' - webp = 'webp' jpeg = 'jpeg' + png = 'png' + + +class BFLFluxPro11GenerateRequest(BaseModel): + prompt: str = Field(..., description='The main text prompt for image generation') + image_prompt: Optional[str] = Field(None, description='Optional image prompt') + width: int = Field(..., description='Width of the generated image') + height: int = Field(..., description='Height of the generated image') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to use prompt upsampling' + ) + seed: Optional[int] = Field(None, description='Random seed for reproducibility') + safety_tolerance: Optional[int] = Field(None, description='Safety tolerance level') + output_format: Optional[OutputFormat] = Field( + None, description='Output image format' + ) + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for async processing' + ) + webhook_secret: Optional[str] = Field( + None, description='Optional webhook secret for async processing' + ) + + +class BFLFluxPro11GenerateResponse(BaseModel): + id: str = Field(..., description='Job ID for tracking') + polling_url: str = Field(..., description='URL to poll for results') + + +class BFLFluxProGenerateRequest(BaseModel): + prompt: str = Field(..., description='The text prompt for image generation.') + negative_prompt: Optional[str] = Field( + None, description='The negative prompt for image generation.' 
+ ) + width: int = Field( + ..., description='The width of the image to generate.', ge=64, le=2048 + ) + height: int = Field( + ..., description='The height of the image to generate.', ge=64, le=2048 + ) + num_inference_steps: Optional[int] = Field( + None, description='The number of inference steps.', ge=1, le=100 + ) + guidance_scale: Optional[float] = Field( + None, description='The guidance scale for generation.', ge=1.0, le=20.0 + ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + num_images: Optional[int] = Field( + None, description='The number of images to generate.', ge=1, le=4 + ) + + +class BFLFluxProGenerateResponse(BaseModel): + id: str = Field(..., description='The unique identifier for the generation task.') + polling_url: str = Field(..., description='URL to poll for the generation result.') + + +class Steps(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + examples=[50], + ge=15, + le=50, + title='Steps', + ) + + +class Guidance(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.5, + le=100.0, + title='Guidance', + ) + + +class WebhookUrl(RootModel[AnyUrl]): + root: AnyUrl = Field( + ..., description='URL to receive webhook notifications', title='Webhook Url' + ) + + +class BFLAsyncResponse(BaseModel): + id: str = Field(..., title='Id') + polling_url: str = Field(..., title='Polling Url') + + +class BFLAsyncWebhookResponse(BaseModel): + id: str = Field(..., title='Id') + status: str = Field(..., title='Status') + webhook_url: str = Field(..., title='Webhook Url') + + +class Top(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the top of the image', + ge=0, + le=2048, + title='Top', + ) + + +class Bottom(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand at the bottom of the image', + ge=0, + le=2048, 
+ title='Bottom', + ) + + +class Left(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the left side of the image', + ge=0, + le=2048, + title='Left', + ) + + +class Right(RootModel[int]): + root: int = Field( + ..., + description='Number of pixels to expand on the right side of the image', + ge=0, + le=2048, + title='Right', + ) + + +class CannyLowThreshold(RootModel[int]): + root: int = Field( + ..., + description='Low threshold for Canny edge detection', + ge=0, + le=500, + title='Canny Low Threshold', + ) + + +class CannyHighThreshold(RootModel[int]): + root: int = Field( + ..., + description='High threshold for Canny edge detection', + ge=0, + le=500, + title='Canny High Threshold', + ) + + +class Steps2(RootModel[int]): + root: int = Field( + ..., + description='Number of steps for the image generation process', + ge=15, + le=50, + title='Steps', + ) + + +class Guidance2(RootModel[float]): + root: float = Field( + ..., + description='Guidance strength for the image generation process', + ge=1.0, + le=100.0, + title='Guidance', + ) + + +class BFLOutputFormat(str, Enum): + jpeg = 'jpeg' + png = 'png' + + +class BFLValidationError(BaseModel): + loc: List[Union[str, int]] = Field(..., title='Location') + msg: str = Field(..., title='Message') + type: str = Field(..., title='Error Type') + + +class Datum2(BaseModel): + image_id: Optional[str] = Field( + None, description='Unique identifier for the generated image' + ) + url: Optional[str] = Field(None, description='URL to access the generated image') + + +class RecraftImageGenerationResponse(BaseModel): + created: int = Field( + ..., description='Unix timestamp when the generation was created' + ) + credits: int = Field(..., description='Number of credits used for the generation') + data: List[Datum2] = Field(..., description='Array of generated image information') + + +class RecraftImageFeatures(BaseModel): + nsfw_score: Optional[float] = None + + +class 
RecraftTextLayoutItem(BaseModel): + bbox: List[List[float]] + text: str + + +class RecraftImageColor(BaseModel): + rgb: Optional[List[int]] = None + std: Optional[List[float]] = None + weight: Optional[float] = None + + +class RecraftImageStyle(str, Enum): + digital_illustration = 'digital_illustration' + icon = 'icon' + realistic_image = 'realistic_image' + vector_illustration = 'vector_illustration' + + +class RecraftImageSubStyle(str, Enum): + field_2d_art_poster = '2d_art_poster' + field_3d = '3d' + field_80s = '80s' + glow = 'glow' + grain = 'grain' + hand_drawn = 'hand_drawn' + infantile_sketch = 'infantile_sketch' + kawaii = 'kawaii' + pixel_art = 'pixel_art' + psychedelic = 'psychedelic' + seamless = 'seamless' + voxel = 'voxel' + watercolor = 'watercolor' + broken_line = 'broken_line' + colored_outline = 'colored_outline' + colored_shapes = 'colored_shapes' + colored_shapes_gradient = 'colored_shapes_gradient' + doodle_fill = 'doodle_fill' + doodle_offset_fill = 'doodle_offset_fill' + offset_fill = 'offset_fill' + outline = 'outline' + outline_gradient = 'outline_gradient' + uneven_fill = 'uneven_fill' + field_70s = '70s' + cartoon = 'cartoon' + doodle_line_art = 'doodle_line_art' + engraving = 'engraving' + flat_2 = 'flat_2' + kawaii_1 = 'kawaii' + line_art = 'line_art' + linocut = 'linocut' + seamless_1 = 'seamless' + b_and_w = 'b_and_w' + enterprise = 'enterprise' + hard_flash = 'hard_flash' + hdr = 'hdr' + motion_blur = 'motion_blur' + natural_light = 'natural_light' + studio_portrait = 'studio_portrait' + line_circuit = 'line_circuit' + field_2d_art_poster_2 = '2d_art_poster_2' + engraving_color = 'engraving_color' + flat_air_art = 'flat_air_art' + hand_drawn_outline = 'hand_drawn_outline' + handmade_3d = 'handmade_3d' + stickers_drawings = 'stickers_drawings' + plastic = 'plastic' + pictogram = 'pictogram' + + +class RecraftTransformModel(str, Enum): + refm1 = 'refm1' + recraft20b = 'recraft20b' + recraftv2 = 'recraftv2' + recraftv3 = 'recraftv3' + 
flux1_1pro = 'flux1_1pro' + flux1dev = 'flux1dev' + imagen3 = 'imagen3' + hidream_i1_dev = 'hidream_i1_dev' + + +class RecraftImageFormat(str, Enum): + webp = 'webp' + png = 'png' + + +class RecraftResponseFormat(str, Enum): + url = 'url' + b64_json = 'b64_json' + + +class RecraftImage(BaseModel): + b64_json: Optional[str] = None + features: Optional[RecraftImageFeatures] = None + image_id: UUID + revised_prompt: Optional[str] = None + url: Optional[str] = None + + +class RecraftUserControls(BaseModel): + artistic_level: Optional[int] = None + background_color: Optional[RecraftImageColor] = None + colors: Optional[List[RecraftImageColor]] = None + no_text: Optional[bool] = None + + +class RecraftTextLayout(RootModel[List[RecraftTextLayoutItem]]): + root: List[RecraftTextLayoutItem] + + +class RecraftProcessImageRequest(BaseModel): + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + response_format: Optional[RecraftResponseFormat] = None + + +class RecraftProcessImageResponse(BaseModel): + created: int + credits: int + image: RecraftImage + + +class RecraftImageToImageRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + controls: Optional[RecraftUserControls] = None + image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + random_seed: Optional[int] = None + response_format: Optional[RecraftResponseFormat] = None + strength: float + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class RecraftGenerateImageResponse(BaseModel): + created: int + credits: int + data: List[RecraftImage] + + +class RecraftTransformImageWithMaskRequest(BaseModel): + block_nsfw: Optional[bool] = None + calculate_features: Optional[bool] = None + 
image: StrictBytes + image_format: Optional[RecraftImageFormat] = None + mask: StrictBytes + model: Optional[RecraftTransformModel] = None + n: Optional[int] = None + negative_prompt: Optional[str] = None + prompt: str + random_seed: Optional[int] = None + response_format: Optional[RecraftResponseFormat] = None + style: Optional[RecraftImageStyle] = None + style_id: Optional[UUID] = None + substyle: Optional[RecraftImageSubStyle] = None + text_layout: Optional[RecraftTextLayout] = None + + +class KlingErrorResponse(BaseModel): + code: int = Field( + ..., + description='- 1000: Authentication failed\n- 1001: Authorization is empty\n- 1002: Authorization is invalid\n- 1003: Authorization is not yet valid\n- 1004: Authorization has expired\n- 1100: Account exception\n- 1101: Account in arrears (postpaid scenario)\n- 1102: Resource pack depleted or expired (prepaid scenario)\n- 1103: Unauthorized access to requested resource\n- 1200: Invalid request parameters\n- 1201: Invalid parameters\n- 1202: Invalid request method\n- 1203: Requested resource does not exist\n- 1300: Trigger platform strategy\n- 1301: Trigger content security policy\n- 1302: API request too frequent\n- 1303: Concurrency/QPS exceeds limit\n- 1304: Trigger IP whitelist policy\n- 5000: Internal server error\n- 5001: Service temporarily unavailable\n- 5002: Server internal timeout\n', + ) + message: str = Field(..., description='Human-readable error message') + request_id: str = Field( + ..., description='Request ID for tracking and troubleshooting' + ) + + +class LumaAspectRatio(str, Enum): + field_1_1 = '1:1' + field_16_9 = '16:9' + field_9_16 = '9:16' + field_4_3 = '4:3' + field_3_4 = '3:4' + field_21_9 = '21:9' + field_9_21 = '9:21' + + +class LumaVideoModel(str, Enum): + ray_2 = 'ray-2' + ray_flash_2 = 'ray-flash-2' + ray_1_6 = 'ray-1-6' + + +class LumaVideoModelOutputResolution1(str, Enum): + field_540p = '540p' + field_720p = '720p' + field_1080p = '1080p' + field_4k = '4k' + + +class 
LumaVideoModelOutputResolution( + RootModel[Union[LumaVideoModelOutputResolution1, str]] +): + root: Union[LumaVideoModelOutputResolution1, str] + + +class LumaVideoModelOutputDuration1(str, Enum): + field_5s = '5s' + field_9s = '9s' + + +class LumaVideoModelOutputDuration( + RootModel[Union[LumaVideoModelOutputDuration1, str]] +): + root: Union[LumaVideoModelOutputDuration1, str] + + +class LumaImageModel(str, Enum): + photon_1 = 'photon-1' + photon_flash_1 = 'photon-flash-1' + + +class LumaImageRef(BaseModel): + url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') + weight: Optional[float] = Field( + None, description='The weight of the image reference' + ) + + +class LumaImageIdentity(BaseModel): + images: Optional[List[AnyUrl]] = Field( + None, description='The URLs of the image identity' + ) + + +class LumaModifyImageRef(BaseModel): + url: Optional[AnyUrl] = Field(None, description='The URL of the image reference') + weight: Optional[float] = Field( + None, description='The weight of the modify image reference' + ) + + +class Type1(str, Enum): + generation = 'generation' + + +class LumaGenerationReference(BaseModel): + type: Literal['generation'] + id: UUID = Field(..., description='The ID of the generation') + + +class Type2(str, Enum): + image = 'image' + + +class LumaImageReference(BaseModel): + type: Literal['image'] + url: AnyUrl = Field(..., description='The URL of the image') + + +class LumaKeyframe(RootModel[Union[LumaGenerationReference, LumaImageReference]]): + root: Union[LumaGenerationReference, LumaImageReference] = Field( + ..., + description='A keyframe can be either a Generation reference, an Image, or a Video', + discriminator='type', + ) + + +class LumaGenerationType(str, Enum): + video = 'video' + image = 'image' + + +class LumaState(str, Enum): + queued = 'queued' + dreaming = 'dreaming' + completed = 'completed' + failed = 'failed' + + +class LumaAssets(BaseModel): + video: Optional[AnyUrl] = Field(None, 
description='The URL of the video') + image: Optional[AnyUrl] = Field(None, description='The URL of the image') + progress_video: Optional[AnyUrl] = Field( + None, description='The URL of the progress video' + ) + + +class GenerationType(str, Enum): + video = 'video' + + +class GenerationType1(str, Enum): + image = 'image' + + +class CharacterRef(BaseModel): + identity0: Optional[LumaImageIdentity] = None + + +class LumaImageGenerationRequest(BaseModel): + generation_type: Optional[GenerationType1] = 'image' + model: Optional[LumaImageModel] = 'photon-1' + prompt: Optional[str] = Field(None, description='The prompt of the generation') + aspect_ratio: Optional[LumaAspectRatio] = '16:9' + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the generation' + ) + image_ref: Optional[List[LumaImageRef]] = None + style_ref: Optional[List[LumaImageRef]] = None + character_ref: Optional[CharacterRef] = None + modify_image_ref: Optional[LumaModifyImageRef] = None + + +class GenerationType2(str, Enum): + upscale_video = 'upscale_video' + + +class LumaUpscaleVideoGenerationRequest(BaseModel): + generation_type: Optional[GenerationType2] = 'upscale_video' + resolution: Optional[LumaVideoModelOutputResolution] = None + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the upscale' + ) + + +class GenerationType3(str, Enum): + add_audio = 'add_audio' + + +class LumaAudioGenerationRequest(BaseModel): + generation_type: Optional[GenerationType3] = 'add_audio' + prompt: Optional[str] = Field(None, description='The prompt of the audio') + negative_prompt: Optional[str] = Field( + None, description='The negative prompt of the audio' + ) + callback_url: Optional[AnyUrl] = Field( + None, description='The callback URL for the audio' + ) + + +class LumaError(BaseModel): + detail: Optional[str] = Field(None, description='The error message') + + +class AspectRatio(str, Enum): + field_16_9 = '16:9' + field_4_3 = '4:3' + field_1_1 
= '1:1' + field_3_4 = '3:4' + field_9_16 = '9:16' + + +class Duration(int, Enum): + integer_5 = 5 + integer_8 = 8 + + +class Model1(str, Enum): + v3_5 = 'v3.5' + + +class MotionMode(str, Enum): + normal = 'normal' + fast = 'fast' class Quality(str, Enum): - low = 'low' - medium = 'medium' - high = 'high' + field_360p = '360p' + field_540p = '540p' + field_720p = '720p' + field_1080p = '1080p' -class OpenAIImageEditRequest(BaseModel): - background: Optional[str] = Field( - None, description='Background transparency', examples=['opaque'] - ) - model: str = Field( - ..., description='The model to use for image editing', examples=['gpt-image-1'] - ) - moderation: Optional[Moderation] = Field( - None, description='Content moderation setting', examples=['auto'] - ) - n: Optional[int] = Field( - None, description='The number of images to generate', examples=[1] - ) - output_compression: Optional[int] = Field( - None, description='Compression level for JPEG or WebP (0-100)', examples=[100] - ) - output_format: Optional[OutputFormat] = Field( - None, description='Format of the output image', examples=['png'] - ) - prompt: str = Field( - ..., - description='A text description of the desired edit', - examples=['Give the rocketship rainbow coloring'], - ) - quality: Optional[str] = Field( - None, description='The quality of the edited image', examples=['low'] - ) - size: Optional[str] = Field( - None, description='Size of the output image', examples=['1024x1024'] - ) - user: Optional[str] = Field( +class Style(str, Enum): + anime = 'anime' + field_3d_animation = '3d_animation' + clay = 'clay' + comic = 'comic' + cyberpunk = 'cyberpunk' + + +class PixverseTextVideoRequest(BaseModel): + aspect_ratio: AspectRatio + duration: Duration + model: Model1 + motion_mode: Optional[MotionMode] = None + negative_prompt: Optional[str] = None + prompt: str + quality: Quality + seed: Optional[int] = None + style: Optional[Style] = None + template_id: Optional[int] = None + water_mark: 
Optional[bool] = None + + +class Resp(BaseModel): + video_id: Optional[int] = None + + +class PixverseVideoResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp_1: Optional[Resp] = Field(None, alias='Resp') + + +class Resp1(BaseModel): + img_id: Optional[int] = None + + +class PixverseImageUploadResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp: Optional[Resp1] = None + + +class PixverseImageVideoRequest(BaseModel): + img_id: int + model: Model1 + prompt: str + duration: Duration + quality: Quality + motion_mode: Optional[MotionMode] = None + seed: Optional[int] = None + style: Optional[Style] = None + template_id: Optional[int] = None + water_mark: Optional[bool] = None + + +class PixverseTransitionVideoRequest(BaseModel): + first_frame_img: int + last_frame_img: int + model: Model1 + duration: Duration + quality: Quality + motion_mode: MotionMode + seed: int + prompt: str + style: Optional[Style] = None + template_id: Optional[int] = None + water_mark: Optional[bool] = None + + +class Status2(int, Enum): + integer_1 = 1 + integer_5 = 5 + integer_6 = 6 + integer_7 = 7 + integer_8 = 8 + + +class Resp2(BaseModel): + create_time: Optional[str] = None + id: Optional[int] = None + modify_time: Optional[str] = None + negative_prompt: Optional[str] = None + outputHeight: Optional[int] = None + outputWidth: Optional[int] = None + prompt: Optional[str] = None + resolution_ratio: Optional[int] = None + seed: Optional[int] = None + size: Optional[int] = None + status: Optional[Status2] = Field( None, - description='A unique identifier for end-user monitoring', - examples=['user-1234'], + description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Contents moderation failed\n* 8 - Generation failed\n', + ) + style: Optional[str] = None + url: Optional[str] = None + + +class PixverseVideoResultResponse(BaseModel): + ErrCode: Optional[int] = None + 
ErrMsg: Optional[str] = None + Resp: Optional[Resp2] = None + + +class Image(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image1(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance(BaseModel): + prompt: str = Field(..., description='Text description of the video') + image: Optional[Union[Image, Image1]] = Field( + None, description='Optional image to guide video generation' ) -class Quality1(str, Enum): +class PersonGeneration(str, Enum): + ALLOW = 'ALLOW' + BLOCK = 'BLOCK' + + +class Parameters(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + + +class Veo2GenVidRequest(BaseModel): + instances: Optional[List[Instance]] = None + parameters: Optional[Parameters] = None + + +class Veo2GenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class Veo2GenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + mimeType: Optional[str] = Field(None, description='Video MIME 
type') + + +class Response(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[List[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[List[Video]] = None + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Veo2GenVidPollResponse(BaseModel): + name: Optional[str] = None + done: Optional[bool] = None + response: Optional[Response] = Field( + None, description='The actual prediction response if done is true' + ) + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + position: Position = Field( + ..., + description="The position of the image in the output video. 
'last' is currently supported for gen3a_turbo only.", + ) + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class Datum3(BaseModel): + b64_json: Optional[str] = Field(None, description='Base64 encoded image data') + url: Optional[str] = Field(None, description='URL of the image') + revised_prompt: Optional[str] = Field(None, description='Revised prompt') + + +class InputTokensDetails(BaseModel): + text_tokens: Optional[int] = None + image_tokens: Optional[int] = None + + +class Usage(BaseModel): + input_tokens: Optional[int] = None + input_tokens_details: Optional[InputTokensDetails] = None + output_tokens: Optional[int] = None + total_tokens: Optional[int] = None + + +class OpenAIImageGenerationResponse(BaseModel): + data: Optional[List[Datum3]] = None + usage: Optional[Usage] = None + + +class Quality3(str, Enum): low = 'low' medium = 'medium' high = 'high' @@ -345,54 +1999,70 @@ class Quality1(str, Enum): hd = 'hd' +class OutputFormat1(str, Enum): + png = 'png' + webp = 'webp' + jpeg = 'jpeg' + + +class Moderation(str, Enum): + low = 'low' + auto = 'auto' + + +class Background(str, Enum): + transparent = 'transparent' + opaque = 'opaque' + + class ResponseFormat(str, Enum): url = 'url' b64_json = 'b64_json' -class Style(str, Enum): +class Style3(str, Enum): vivid = 'vivid' natural = 'natural' class OpenAIImageGenerationRequest(BaseModel): - background: 
Optional[Background] = Field( - None, description='Background transparency', examples=['opaque'] - ) model: Optional[str] = Field( None, description='The model to use for image generation', examples=['dall-e-3'] ) - moderation: Optional[Moderation] = Field( - None, description='Content moderation setting', examples=['auto'] - ) - n: Optional[int] = Field( - None, - description='The number of images to generate (1-10). Only 1 supported for dall-e-3.', - examples=[1], - ) - output_compression: Optional[int] = Field( - None, description='Compression level for JPEG or WebP (0-100)', examples=[100] - ) - output_format: Optional[OutputFormat] = Field( - None, description='Format of the output image', examples=['png'] - ) prompt: str = Field( ..., description='A text description of the desired image', examples=['Draw a rocket in front of a blackhole in deep space'], ) - quality: Optional[Quality1] = Field( - None, description='The quality of the generated image', examples=['high'] + n: Optional[int] = Field( + None, + description='The number of images to generate (1-10). 
Only 1 supported for dall-e-3.', + examples=[1], ) - response_format: Optional[ResponseFormat] = Field( - None, description='Response format of image data', examples=['b64_json'] + quality: Optional[Quality3] = Field( + None, description='The quality of the generated image', examples=['high'] ) size: Optional[str] = Field( None, description='Size of the image (e.g., 1024x1024, 1536x1024, auto)', examples=['1024x1536'], ) - style: Optional[Style] = Field( + output_format: Optional[OutputFormat1] = Field( + None, description='Format of the output image', examples=['png'] + ) + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + background: Optional[Background] = Field( + None, description='Background transparency', examples=['opaque'] + ) + response_format: Optional[ResponseFormat] = Field( + None, description='Response format of image data', examples=['b64_json'] + ) + style: Optional[Style3] = Field( None, description='Style of the image (only for dall-e-3)', examples=['vivid'] ) user: Optional[str] = Field( @@ -402,21 +2072,1758 @@ class OpenAIImageGenerationRequest(BaseModel): ) -class Datum1(BaseModel): - b64_json: Optional[str] = Field(None, description='Base64 encoded image data') - revised_prompt: Optional[str] = Field(None, description='Revised prompt') - url: Optional[str] = Field(None, description='URL of the image') +class OpenAIImageEditRequest(BaseModel): + model: str = Field( + ..., description='The model to use for image editing', examples=['gpt-image-1'] + ) + prompt: str = Field( + ..., + description='A text description of the desired edit', + examples=['Give the rocketship rainbow coloring'], + ) + n: Optional[int] = Field( + None, description='The number of images to generate', examples=[1] + ) + quality: Optional[str] = Field( + None, description='The 
quality of the edited image', examples=['low'] + ) + size: Optional[str] = Field( + None, description='Size of the output image', examples=['1024x1024'] + ) + output_format: Optional[OutputFormat1] = Field( + None, description='Format of the output image', examples=['png'] + ) + output_compression: Optional[int] = Field( + None, description='Compression level for JPEG or WebP (0-100)', examples=[100] + ) + moderation: Optional[Moderation] = Field( + None, description='Content moderation setting', examples=['auto'] + ) + background: Optional[str] = Field( + None, description='Background transparency', examples=['opaque'] + ) + user: Optional[str] = Field( + None, + description='A unique identifier for end-user monitoring', + examples=['user-1234'], + ) -class OpenAIImageGenerationResponse(BaseModel): - data: Optional[List[Datum1]] = None -class User(BaseModel): - email: Optional[str] = Field(None, description='The email address for this user.') - id: Optional[str] = Field(None, description='The unique id for this user.') - isAdmin: Optional[bool] = Field( - None, description='Indicates if the user has admin privileges.' +class CustomerStorageResourceResponse(BaseModel): + download_url: Optional[str] = Field( + None, + description='The signed URL to use for downloading the file from the specified path', ) - isApproved: Optional[bool] = Field( - None, description='Indicates if the user is approved.' 
+ upload_url: Optional[str] = Field( + None, + description='The signed URL to use for uploading the file to the specified path', ) - name: Optional[str] = Field(None, description='The name for this user.') + expires_at: Optional[datetime] = Field( + None, description='When the signed URL will expire' + ) + existing_file: Optional[bool] = Field( + None, description='Whether an existing file with the same hash was found' + ) + + +class Pikaffect(str, Enum): + Cake_ify = 'Cake-ify' + Crumble = 'Crumble' + Crush = 'Crush' + Decapitate = 'Decapitate' + Deflate = 'Deflate' + Dissolve = 'Dissolve' + Explode = 'Explode' + Eye_pop = 'Eye-pop' + Inflate = 'Inflate' + Levitate = 'Levitate' + Melt = 'Melt' + Peel = 'Peel' + Poke = 'Poke' + Squish = 'Squish' + Ta_da = 'Ta-da' + Tear = 'Tear' + + +class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel): + image: Optional[StrictBytes] = Field(None, title='Image') + pikaffect: Optional[Pikaffect] = Field(None, title='Pikaffect') + promptText: Optional[str] = Field(None, title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + + +class PikaGenerateResponse(BaseModel): + video_id: str = Field(..., title='Video Id') + + +class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel): + video: Optional[StrictBytes] = Field(None, title='Video') + image: Optional[StrictBytes] = Field(None, title='Image') + promptText: Optional[str] = Field(None, title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + + +class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel): + video: Optional[StrictBytes] = Field(None, title='Video') + image: Optional[StrictBytes] = Field(None, title='Image') + promptText: Optional[str] = Field(None, title='Prompttext') + modifyRegionMask: Optional[StrictBytes] = Field( + None, + description='A mask image that specifies the 
region to modify, where the mask is white and the background is black', + title='Modifyregionmask', + ) + modifyRegionRoi: Optional[str] = Field( + None, + description='Plaintext description of the object / region to modify', + title='Modifyregionroi', + ) + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + + +class IngredientsMode(str, Enum): + creative = 'creative' + precise = 'precise' + + +class AspectRatio1(RootModel[float]): + root: float = Field( + ..., + description='Aspect ratio (width / height)', + ge=0.4, + le=2.5, + title='Aspectratio', + ) + + +class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel): + images: Optional[List[StrictBytes]] = Field(None, title='Images') + ingredientsMode: IngredientsMode = Field(..., title='Ingredientsmode') + promptText: Optional[str] = Field(None, title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + resolution: Optional[str] = Field('1080p', title='Resolution') + duration: Optional[int] = Field(5, title='Duration') + aspectRatio: Optional[AspectRatio1] = Field( + None, description='Aspect ratio (width / height)', title='Aspectratio' + ) + + +class PikaStatusEnum(str, Enum): + queued = 'queued' + started = 'started' + finished = 'finished' + + +class PikaValidationError(BaseModel): + loc: List[Union[str, int]] = Field(..., title='Location') + msg: str = Field(..., title='Message') + type: str = Field(..., title='Error Type') + + +class PikaResolutionEnum(str, Enum): + field_1080p = '1080p' + field_720p = '720p' + + +class PikaDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RgbItem(RootModel[int]): + root: int = Field(..., ge=0, le=255) + + +class RGBColor(BaseModel): + rgb: List[RgbItem] = Field(..., max_length=3, min_length=3) + + +class StabilityStabilityClientID(RootModel[str]): + root: str = Field( + ..., + description='The name of 
your application, used to help us communicate app-specific debugging or moderation issues to you.', + examples=['my-awesome-app'], + max_length=256, + ) + + +class StabilityStabilityClientUserID(RootModel[str]): + root: str = Field( + ..., + description='A unique identifier for your end user. Used to help us communicate user-specific debugging or moderation issues to you. Feel free to obfuscate this value to protect user privacy.', + examples=['DiscordUser#9999'], + max_length=256, + ) + + +class StabilityStabilityClientVersion(RootModel[str]): + root: str = Field( + ..., + description='The version of your application, used to help us communicate version-specific debugging or moderation issues to you.', + examples=['1.2.1'], + max_length=256, + ) + + +class Name(str, Enum): + content_moderation = 'content_moderation' + + +class StabilityContentModerationResponse(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: Name = Field( + ..., + description='Our content moderation system has flagged some part of your request and subsequently denied it. You were not charged for this request. While this may at times be frustrating, it is necessary to maintain the integrity of our platform and ensure a safe experience for all users. 
If you would like to provide feedback, please use the [Support Form](https://kb.stability.ai/knowledge-base/kb-tickets/new).', + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class RenderingSpeed(str, Enum): + BALANCED = 'BALANCED' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class StabilityCreativity(RootModel[float]): + root: float = Field( + ..., + description='Controls the likelihood of creating additional details not heavily conditioned by the init image.', + ge=0.2, + le=0.5, + ) + + +class StabilityGenerationID(RootModel[str]): + root: str = Field( + ..., + description='The `id` of a generation, typically used for async generations, that can be used to check the status of the generation or retrieve the result.', + examples=['a6dc6c6e20acda010fe14d71f180658f2896ed9b4ec25aa99a6ff06c796987c4'], + max_length=64, + min_length=64, + ) + + +class Mode(str, Enum): + text_to_image = 'text-to-image' + image_to_image = 'image-to-image' + + +class AspectRatio2(str, Enum): + field_21_9 = '21:9' + field_16_9 = '16:9' + field_3_2 = '3:2' + field_5_4 = '5:4' + field_1_1 = '1:1' + field_4_5 = '4:5' + field_2_3 = '2:3' + field_9_16 = '9:16' + field_9_21 = '9:21' + + +class Model4(str, Enum): + sd3_5_large = 'sd3.5-large' + sd3_5_large_turbo = 'sd3.5-large-turbo' + sd3_5_medium = 'sd3.5-medium' + + +class OutputFormat3(str, Enum): + png = 'png' + jpeg = 'jpeg' + + +class StylePreset(str, Enum): + enhance = 'enhance' + anime = 'anime' + photographic = 'photographic' + digital_art = 'digital-art' + comic_book = 'comic-book' + fantasy_art = 'fantasy-art' + line_art = 'line-art' + analog_film = 'analog-film' + neon_punk = 'neon-punk' + isometric = 'isometric' + low_poly = 'low-poly' + origami = 'origami' + modeling_compound = 'modeling-compound' + cinematic = 'cinematic' + field_3d_model = '3d-model' + pixel_art = 'pixel-art' + tile_texture = 
'tile-texture' + + +class StabilityImageGenrationSD3Request(BaseModel): + prompt: str = Field( + ..., + description='What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.', + max_length=10000, + min_length=1, + ) + mode: Optional[Mode] = Field( + 'text-to-image', + description='Controls whether this is a text-to-image or image-to-image generation, which affects which parameters are required:\n- **text-to-image** requires only the `prompt` parameter\n- **image-to-image** requires the `prompt`, `image`, and `strength` parameters', + title='GenerationMode', + ) + image: Optional[StrictBytes] = Field( + None, + description='The image to use as the starting point for the generation.\n\nSupported formats:\n\n\n\n - jpeg\n - png\n - webp\n\nSupported dimensions:\n\n\n\n - Every side must be at least 64 pixels\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ) + strength: Optional[float] = Field( + None, + description='Sometimes referred to as _denoising_, this parameter controls how much influence the\n`image` parameter has on the generated image. A value of 0 would yield an image that\nis identical to the input. A value of 1 would be as if you passed in no image at all.\n\n> **Important:** This parameter is only valid for **image-to-image** requests.', + ge=0.0, + le=1.0, + ) + aspect_ratio: Optional[AspectRatio2] = Field( + '1:1', + description='Controls the aspect ratio of the generated image. 
Defaults to 1:1.\n\n> **Important:** This parameter is only valid for **text-to-image** requests.', + ) + model: Optional[Model4] = Field( + 'sd3.5-large', + description='The model to use for generation.\n\n- `sd3.5-large` requires 6.5 credits per generation\n- `sd3.5-large-turbo` requires 4 credits per generation\n- `sd3.5-medium` requires 3.5 credits per generation\n- As of the April 17, 2025, `sd3-large`, `sd3-large-turbo` and `sd3-medium`\n\n\n\n are re-routed to their `sd3.5-[model version]` equivalent, at the same price.', + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + output_format: Optional[OutputFormat3] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' + ) + negative_prompt: Optional[str] = Field( + None, + description='Keywords of what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + cfg_scale: Optional[float] = Field( + None, + description='How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt). The _Large_ and _Medium_ models use a default of `4`. 
The _Turbo_ model uses a default of `1`.', + ge=1.0, + le=10.0, + ) + + +class FinishReason(str, Enum): + SUCCESS = 'SUCCESS' + CONTENT_FILTERED = 'CONTENT_FILTERED' + + +class StabilityImageGenrationSD3Response200(BaseModel): + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + + +class StabilityImageGenrationSD3Response400(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response413(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response422(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response429(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationSD3Response500(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class OutputFormat4(str, Enum): + jpeg = 'jpeg' + png = 'png' + webp = 'webp' + + +class StabilityImageGenrationUpscaleConservativeRequest(BaseModel): + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 9,437,184 pixels\n- The aspect ratio must be between 1:2.5 and 2.5:1', + examples=['./some/image.png'], + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. 
A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + creativity: Optional[StabilityCreativity] = Field( + default_factory=lambda: StabilityCreativity.model_validate(0.35) + ) + + +class StabilityImageGenrationUpscaleConservativeResponse200(BaseModel): + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + + +class StabilityImageGenrationUpscaleConservativeResponse400(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse413(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse422(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse429(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleConservativeResponse500(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeRequest(BaseModel): + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Every side must be at least 64 pixels\n- Total pixel count must be between 4,096 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + prompt: str = Field( + ..., + description="What you wish to see in the output image. A strong, descriptive prompt that clearly defines\nelements, colors, and subjects will lead to better results.\n\nTo control the weight of a given word use the format `(word:weight)`,\nwhere `word` is the word you'd like to control the weight of and `weight`\nis a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`\nwould convey a sky that was blue and green, but more green than blue.", + max_length=10000, + min_length=1, + ) + negative_prompt: Optional[str] = Field( + None, + description='A blurb of text describing what you **do not** wish to see in the output image.\nThis is an advanced feature.', + max_length=10000, + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' + ) + seed: Optional[float] = Field( + 0, + description="A specific value that is used to guide the 'randomness' of the generation. 
(Omit this parameter or pass `0` to use a random seed.)", + ge=0.0, + le=4294967294.0, + ) + creativity: Optional[float] = Field( + 0.3, + description='Indicates how creative the model should be when upscaling an image.\nHigher values will result in more details being added to the image during upscaling.', + ge=0.1, + le=0.5, + ) + style_preset: Optional[StylePreset] = Field( + None, description='Guides the image model towards a particular style.' + ) + + +class StabilityImageGenrationUpscaleCreativeResponse200(BaseModel): + id: StabilityGenerationID + + +class StabilityImageGenrationUpscaleCreativeResponse400(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse413(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse422(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse429(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleCreativeResponse500(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastRequest(BaseModel): + image: StrictBytes = Field( + ..., + description='The image you wish to upscale.\n\nSupported Formats:\n- jpeg\n- png\n- webp\n\nValidation Rules:\n- Width must be between 32 and 1,536 pixels\n- Height must be between 32 and 1,536 pixels\n- Total pixel count must be between 1,024 and 1,048,576 pixels', + examples=['./some/image.png'], + ) + output_format: Optional[OutputFormat4] = Field( + 'png', description='Dictates the `content-type` of the generated image.' 
+ ) + + +class StabilityImageGenrationUpscaleFastResponse200(BaseModel): + image: str = Field( + ..., + description='The generated image, encoded to base64.', + examples=['AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1...'], + ) + seed: Optional[float] = Field( + 0, + description='The seed used as random noise for this generation.', + examples=[343940597], + ge=0.0, + le=4294967294.0, + ) + finish_reason: FinishReason = Field( + ..., + description='The reason the generation finished.\n\n- `SUCCESS` = successful generation.\n- `CONTENT_FILTERED` = successful generation, however the output violated our content moderation\npolicy and has been blurred as a result.', + examples=['SUCCESS'], + ) + + +class StabilityImageGenrationUpscaleFastResponse400(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse413(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse422(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse429(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. 
Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class StabilityImageGenrationUpscaleFastResponse500(BaseModel): + id: str = Field( + ..., + description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new)\nyou file, as it will greatly assist us in diagnosing the root cause of the problem.', + examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'], + min_length=1, + ) + name: str = Field( + ..., + description='Short-hand name for an error, useful for discriminating between errors with the same status code.', + examples=['bad_request'], + min_length=1, + ) + errors: List[str] = Field( + ..., + description='One or more error messages indicating what went wrong.', + examples=[['some-field: is required']], + min_length=1, + ) + + +class ActionJobResult(BaseModel): + id: Optional[UUID] = Field(None, description='Unique identifier for the job result') + workflow_name: Optional[str] = Field(None, description='Name of the workflow') + operating_system: Optional[str] = Field(None, description='Operating system used') + python_version: Optional[str] = Field(None, description='Python version used') + pytorch_version: Optional[str] = Field(None, description='PyTorch version used') + action_run_id: Optional[str] = Field( + None, description='Identifier of the run this result belongs to' + ) + action_job_id: Optional[str] = Field( 
+ None, description='Identifier of the job this result belongs to' + ) + cuda_version: Optional[str] = Field(None, description='CUDA version used') + branch_name: Optional[str] = Field( + None, description='Name of the relevant git branch' + ) + commit_hash: Optional[str] = Field(None, description='The hash of the commit') + commit_id: Optional[str] = Field(None, description='The ID of the commit') + commit_time: Optional[int] = Field( + None, description='The Unix timestamp when the commit was made' + ) + commit_message: Optional[str] = Field(None, description='The message of the commit') + comfy_run_flags: Optional[str] = Field( + None, description='The comfy run flags. E.g. `--low-vram`' + ) + git_repo: Optional[str] = Field(None, description='The repository name') + pr_number: Optional[str] = Field(None, description='The pull request number') + start_time: Optional[int] = Field( + None, description='The start time of the job as a Unix timestamp.' + ) + end_time: Optional[int] = Field( + None, description='The end time of the job as a Unix timestamp.' + ) + avg_vram: Optional[int] = Field( + None, description='The average VRAM used by the job' + ) + peak_vram: Optional[int] = Field(None, description='The peak VRAM used by the job') + job_trigger_user: Optional[str] = Field( + None, description='The user who triggered the job.' + ) + author: Optional[str] = Field(None, description='The author of the commit') + machine_stats: Optional[MachineStats] = None + status: Optional[WorkflowRunStatus] = None + storage_file: Optional[StorageFile] = None + + +class Publisher(BaseModel): + name: Optional[str] = None + id: Optional[str] = Field( + None, + description="The unique identifier for the publisher. It's akin to a username. 
Should be lowercase.", + ) + description: Optional[str] = None + website: Optional[str] = None + support: Optional[str] = None + source_code_repo: Optional[str] = None + logo: Optional[str] = Field(None, description="URL to the publisher's logo.") + createdAt: Optional[datetime] = Field( + None, description='The date and time the publisher was created.' + ) + members: Optional[List[PublisherMember]] = Field( + None, description='A list of members in the publisher.' + ) + status: Optional[PublisherStatus] = Field( + None, description='The status of the publisher.' + ) + + +class NodeVersion(BaseModel): + id: Optional[str] = None + version: Optional[str] = Field( + None, + description='The version identifier, following semantic versioning. Must be unique for the node.', + ) + createdAt: Optional[datetime] = Field( + None, description='The date and time the version was created.' + ) + changelog: Optional[str] = Field( + None, description='Summary of changes made in this version' + ) + dependencies: Optional[List[str]] = Field( + None, description='A list of pip dependencies required by the node.' + ) + downloadUrl: Optional[str] = Field( + None, description='[Output Only] URL to download this version of the node' + ) + deprecated: Optional[bool] = Field( + None, description='Indicates if this version is deprecated.' + ) + status: Optional[NodeVersionStatus] = Field( + None, description='The status of the node version.' + ) + status_reason: Optional[str] = Field( + None, description='The reason for the status change.' + ) + node_id: Optional[str] = Field( + None, description='The unique identifier of the node.' + ) + comfy_node_extract_status: Optional[str] = Field( + None, description='The status of comfy node extraction process.' 
+ ) + + +class IdeogramV3Request(BaseModel): + prompt: str = Field(..., description='The text prompt for image generation') + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + rendering_speed: RenderingSpeed + magic_prompt: Optional[MagicPrompt] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + color_palette: Optional[ColorPalette] = None + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_type: Optional[StyleType] = Field( + None, description='The type of style to apply' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + + +class IdeogramV3EditRequest(BaseModel): + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' 
+ ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should be used in generating the request or not.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + rendering_speed: RenderingSpeed + color_palette: Optional[IdeogramColorPalette] = Field( + None, + description='A color palette for generation, must EITHER be specified via one of the presets (name) or explicitly via hexadecimal representations of the color with optional weights (members). Not supported by V_1, V_1_TURBO, V_2A and V_2A_TURBO models.', + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). 
The images should be in JPEG, PNG or WebP format.', + ) + + +class KlingCameraControl(BaseModel): + type: Optional[KlingCameraControlType] = None + config: Optional[KlingCameraConfig] = None + + +class KlingText2VideoRequest(BaseModel): + model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master' + prompt: Optional[str] = Field( + None, description='Positive text prompt', max_length=2500 + ) + negative_prompt: Optional[str] = Field( + None, description='Negative text prompt', max_length=2500 + ) + cfg_scale: Optional[KlingVideoGenCfgScale] = Field( + default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + ) + mode: Optional[KlingVideoGenMode] = 'std' + camera_control: Optional[KlingCameraControl] = None + aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9' + duration: Optional[KlingVideoGenDuration] = '5' + callback_url: Optional[AnyUrl] = Field( + None, description='The callback notification address' + ) + external_task_id: Optional[str] = Field(None, description='Customized Task ID') + + +class KlingImage2VideoRequest(BaseModel): + model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master' + image: Optional[str] = Field( + None, + description='Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.', + ) + image_tail: Optional[str] = Field( + None, + description='Reference Image - End frame control. URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px. 
Base64 should not include data:image prefix.', + ) + prompt: Optional[str] = Field( + None, description='Positive text prompt', max_length=2500 + ) + negative_prompt: Optional[str] = Field( + None, description='Negative text prompt', max_length=2500 + ) + cfg_scale: Optional[KlingVideoGenCfgScale] = Field( + default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5) + ) + mode: Optional[KlingVideoGenMode] = 'std' + static_mask: Optional[str] = Field( + None, + description='Static Brush Application Area (Mask image created by users using the motion brush). The aspect ratio must match the input image.', + ) + dynamic_masks: Optional[List[DynamicMask]] = Field( + None, + description='Dynamic Brush Configuration List (up to 6 groups). For 5-second videos, trajectory length must not exceed 77 coordinates.', + ) + camera_control: Optional[KlingCameraControl] = None + aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9' + duration: Optional[KlingVideoGenDuration] = '5' + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address. Server will notify when the task status changes.', + ) + external_task_id: Optional[str] = Field( + None, + description='Customized Task ID. Must be unique within a single user account.', + ) + + +class KlingVideoEffectsInput( + RootModel[Union[KlingSingleImageEffectInput, KlingDualCharacterEffectInput]] +): + root: Union[KlingSingleImageEffectInput, KlingDualCharacterEffectInput] + + +class StripeBillingDetails(BaseModel): + address: Optional[StripeAddress] = None + email: Optional[str] = None + name: Optional[str] = None + phone: Optional[str] = None + tax_id: Optional[Any] = None + + +class StripePaymentMethodDetails(BaseModel): + card: Optional[StripeCardDetails] = None + type: Optional[str] = None + + +class BFLFluxProFillInputs(BaseModel): + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to modify. 
Can contain alpha mask if desired.', + title='Image', + ) + mask: Optional[str] = Field( + None, + description='A Base64-encoded string representing a mask for the areas you want to modify in the image. The mask should be the same dimensions as the image and in black and white. Black areas (0%) indicate no modification, while white areas (100%) specify areas for inpainting. Optional if you provide an alpha mask in the original image. Validation: The endpoint verifies that the dimensions of the mask match the original image.', + title='Mask', + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the inpainting process, allowing you to specify features, styles, or modifications for the masked area.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + + +class BFLHTTPValidationError(BaseModel): + detail: Optional[List[BFLValidationError]] = Field(None, title='Detail') + + +class BFLFluxProExpandInputs(BaseModel): + image: str = Field( + ..., + description='A Base64-encoded string representing the image you wish to expand.', + title='Image', + ) + top: Optional[Top] = Field( + 0, description='Number of pixels to expand at the top of the image', title='Top' + ) + bottom: Optional[Bottom] = Field( + 0, + description='Number of pixels to expand at the bottom of the image', + title='Bottom', + ) + left: Optional[Left] = Field( + 0, + description='Number of pixels to expand on the left side of the image', + title='Left', + ) + right: Optional[Right] = Field( + 0, + description='Number of pixels to expand on the right side of the image', + title='Right', + ) + prompt: Optional[str] = Field( + '', + description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.', + examples=['ein fantastisches bild'], + title='Prompt', + ) + steps: Optional[Steps] = Field( + default_factory=lambda: Steps.model_validate(50), + description='Number of steps for the image generation process', + examples=[50], + title='Steps', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation', + title='Prompt Upsampling', + ) + seed: Optional[int] = Field( + None, description='Optional seed for reproducibility', title='Seed' + ) + guidance: Optional[Guidance] = Field( + default_factory=lambda: Guidance.model_validate(60), + description='Guidance strength for the image generation process', + title='Guidance', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + examples=[2], + ge=0, + le=6, + title='Safety Tolerance', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + + +class BFLCannyInputs(BaseModel): + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input if no preprocessed image is provided', + title='Control Image', + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + canny_low_threshold: Optional[CannyLowThreshold] = Field( + default_factory=lambda: CannyLowThreshold.model_validate(50), + description='Low threshold for Canny edge detection', + title='Canny Low Threshold', + ) + canny_high_threshold: Optional[CannyHighThreshold] = Field( + default_factory=lambda: CannyHighThreshold.model_validate(200), + description='High threshold for Canny edge detection', + 
title='Canny High Threshold', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + seed: Optional[int] = Field( + None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(30), + description='Guidance strength for the image generation process', + title='Guidance', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + + +class BFLDepthInputs(BaseModel): + prompt: str = Field( + ..., + description='Text prompt for image generation', + examples=['ein fantastisches bild'], + title='Prompt', + ) + control_image: Optional[str] = Field( + None, + description='Base64 encoded image to use as control input', + title='Control Image', + ) + preprocessed_image: Optional[str] = Field( + None, + description='Optional pre-processed image that will bypass the control preprocessing step', + title='Preprocessed Image', + ) + prompt_upsampling: Optional[bool] = Field( + False, + description='Whether to perform upsampling on the prompt', + title='Prompt Upsampling', + ) + seed: Optional[int] = Field( + 
None, + description='Optional seed for reproducibility', + examples=[42], + title='Seed', + ) + steps: Optional[Steps2] = Field( + default_factory=lambda: Steps2.model_validate(50), + description='Number of steps for the image generation process', + title='Steps', + ) + output_format: Optional[BFLOutputFormat] = Field( + 'jpeg', + description="Output format for the generated image. Can be 'jpeg' or 'png'.", + ) + guidance: Optional[Guidance2] = Field( + default_factory=lambda: Guidance2.model_validate(15), + description='Guidance strength for the image generation process', + title='Guidance', + ) + safety_tolerance: Optional[int] = Field( + 2, + description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ge=0, + le=6, + title='Safety Tolerance', + ) + webhook_url: Optional[WebhookUrl] = Field( + None, description='URL to receive webhook notifications', title='Webhook Url' + ) + webhook_secret: Optional[str] = Field( + None, + description='Optional secret for webhook signature verification', + title='Webhook Secret', + ) + + +class Controls(BaseModel): + artistic_level: Optional[int] = Field( + None, + description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. 
Dynamic and eccentric levels introduce movement and creativity.', + ge=0, + le=5, + ) + colors: Optional[List[RGBColor]] = Field( + None, description='An array of preferable colors' + ) + background_color: Optional[RGBColor] = Field( + None, description='Use given color as a desired background color' + ) + no_text: Optional[bool] = Field(None, description='Do not embed text layouts') + + +class RecraftImageGenerationRequest(BaseModel): + prompt: str = Field( + ..., description='The text prompt describing the image to generate' + ) + model: str = Field( + ..., description='The model to use for generation (e.g., "recraftv3")' + ) + style: Optional[str] = Field( + None, + description='The style to apply to the generated image (e.g., "digital_illustration")', + ) + style_id: Optional[str] = Field( + None, + description='The style ID to apply to the generated image (e.g., "123e4567-e89b-12d3-a456-426614174000"). If style_id is provided, style should not be provided.', + ) + size: str = Field( + ..., description='The size of the generated image (e.g., "1024x1024")' + ) + controls: Optional[Controls] = Field( + None, description='The controls for the generated image' + ) + n: int = Field(..., description='The number of images to generate', ge=1, le=4) + + +class LumaKeyframes(BaseModel): + frame0: Optional[LumaKeyframe] = None + frame1: Optional[LumaKeyframe] = None + + +class LumaGenerationRequest(BaseModel): + generation_type: Optional[GenerationType] = 'video' + prompt: str = Field(..., description='The prompt of the generation') + aspect_ratio: LumaAspectRatio + loop: Optional[bool] = Field(None, description='Whether to loop the video') + keyframes: Optional[LumaKeyframes] = None + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback URL of the generation, a POST request with Generation object will be sent to the callback URL when the generation is dreaming, completed, or failed', + ) + model: LumaVideoModel + resolution: 
LumaVideoModelOutputResolution + duration: LumaVideoModelOutputDuration + + +class LumaGeneration(BaseModel): + id: Optional[UUID] = Field(None, description='The ID of the generation') + generation_type: Optional[LumaGenerationType] = None + state: Optional[LumaState] = None + failure_reason: Optional[str] = Field( + None, description='The reason for the state of the generation' + ) + created_at: Optional[datetime] = Field( + None, description='The date and time when the generation was created' + ) + assets: Optional[LumaAssets] = None + model: Optional[str] = Field(None, description='The model used for the generation') + request: Optional[ + Union[ + LumaGenerationRequest, + LumaImageGenerationRequest, + LumaUpscaleVideoGenerationRequest, + LumaAudioGenerationRequest, + ] + ] = Field(None, description='The request of the generation') + + +class RunwayImageToVideoRequest(BaseModel): + promptImage: RunwayPromptImageObject + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + model: RunwayModelEnum = Field(..., description='Model to use for generation') + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + duration: RunwayDurationEnum = Field( + ..., description='The number of seconds of duration for the output video.' + ) + ratio: RunwayAspectRatioEnum = Field( + ..., + description='The resolution (aspect ratio) of the output video. Allowable values depend on the selected model. 
1280:768 and 768:1280 are only supported for gen3a_turbo.', + ) + + +class RunwayTaskStatusResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + status: Optional[RunwayTaskStatusEnum] = Field(None, description='Task status') + createdAt: Optional[datetime] = Field(None, description='Task creation timestamp') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + + +class PikaHTTPValidationError(BaseModel): + detail: Optional[List[PikaValidationError]] = Field(None, title='Detail') + + +class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel): + promptText: str = Field(..., title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') + duration: Optional[PikaDurationEnum] = Field(5, title='Duration') + aspectRatio: Optional[float] = Field( + 1.7777777777777777, + description='Aspect ratio (width / height)', + ge=0.4, + le=2.5, + title='Aspectratio', + ) + + +class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel): + image: Optional[StrictBytes] = Field(None, title='Image') + promptText: Optional[str] = Field(None, title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') + duration: Optional[PikaDurationEnum] = Field(5, title='Duration') + + +class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel): + keyFrames: Optional[List[StrictBytes]] = Field( + None, description='Array of keyframe images', title='Keyframes' + ) + promptText: str = Field(..., title='Prompttext') + negativePrompt: Optional[str] = Field(None, title='Negativeprompt') + seed: Optional[int] = Field(None, title='Seed') + resolution: Optional[PikaResolutionEnum] = Field('1080p', title='Resolution') + duration: 
Optional[int] = Field(None, ge=5, le=10, title='Duration') + + +class PikaVideoResponse(BaseModel): + id: str = Field(..., title='Id') + status: PikaStatusEnum = Field( + ..., description='The status of the video', title='Status' + ) + url: Optional[str] = Field(None, title='Url') + progress: Optional[int] = Field(None, title='Progress') + + +class Node(BaseModel): + id: Optional[str] = Field(None, description='The unique identifier of the node.') + name: Optional[str] = Field(None, description='The display name of the node.') + category: Optional[str] = Field(None, description='The category of the node.') + description: Optional[str] = None + author: Optional[str] = None + license: Optional[str] = Field( + None, description="The path to the LICENSE file in the node's repository." + ) + icon: Optional[str] = Field(None, description="URL to the node's icon.") + repository: Optional[str] = Field(None, description="URL to the node's repository.") + tags: Optional[List[str]] = None + latest_version: Optional[NodeVersion] = Field( + None, description='The latest version of the node.' + ) + rating: Optional[float] = Field(None, description='The average rating of the node.') + downloads: Optional[int] = Field( + None, description='The number of downloads of the node.' + ) + publisher: Optional[Publisher] = Field( + None, description='The publisher of the node.' + ) + status: Optional[NodeStatus] = Field(None, description='The status of the node.') + status_detail: Optional[str] = Field( + None, description='The status detail of the node.' 
+ ) + translations: Optional[Dict[str, Dict[str, Any]]] = None + + +class KlingVideoEffectsRequest(BaseModel): + effect_scene: Union[KlingDualCharacterEffectsScene, KlingSingleImageEffectsScene] + input: KlingVideoEffectsInput + callback_url: Optional[AnyUrl] = Field( + None, + description='The callback notification address for the result of this task.', + ) + external_task_id: Optional[str] = Field( + None, + description='Customized Task ID. Must be unique within a single user account.', + ) + + +class StripeCharge(BaseModel): + id: Optional[str] = None + object: Optional[Object2] = None + amount: Optional[int] = None + amount_captured: Optional[int] = None + amount_refunded: Optional[int] = None + application: Optional[str] = None + application_fee: Optional[str] = None + application_fee_amount: Optional[int] = None + balance_transaction: Optional[str] = None + billing_details: Optional[StripeBillingDetails] = None + calculated_statement_descriptor: Optional[str] = None + captured: Optional[bool] = None + created: Optional[int] = None + currency: Optional[str] = None + customer: Optional[str] = None + description: Optional[str] = None + destination: Optional[Any] = None + dispute: Optional[Any] = None + disputed: Optional[bool] = None + failure_balance_transaction: Optional[Any] = None + failure_code: Optional[Any] = None + failure_message: Optional[Any] = None + fraud_details: Optional[Dict[str, Any]] = None + invoice: Optional[Any] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + on_behalf_of: Optional[Any] = None + order: Optional[Any] = None + outcome: Optional[StripeOutcome] = None + paid: Optional[bool] = None + payment_intent: Optional[str] = None + payment_method: Optional[str] = None + payment_method_details: Optional[StripePaymentMethodDetails] = None + radar_options: Optional[Dict[str, Any]] = None + receipt_email: Optional[str] = None + receipt_number: Optional[str] = None + receipt_url: Optional[str] = None + 
refunded: Optional[bool] = None + refunds: Optional[StripeRefundList] = None + review: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + source_transfer: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + +class StripeChargeList(BaseModel): + object: Optional[str] = None + data: Optional[List[StripeCharge]] = None + has_more: Optional[bool] = None + total_count: Optional[int] = None + url: Optional[str] = None + + +class StripePaymentIntent(BaseModel): + id: Optional[str] = None + object: Optional[Object1] = None + amount: Optional[int] = None + amount_capturable: Optional[int] = None + amount_details: Optional[StripeAmountDetails] = None + amount_received: Optional[int] = None + application: Optional[str] = None + application_fee_amount: Optional[int] = None + automatic_payment_methods: Optional[Any] = None + canceled_at: Optional[int] = None + cancellation_reason: Optional[str] = None + capture_method: Optional[str] = None + charges: Optional[StripeChargeList] = None + client_secret: Optional[str] = None + confirmation_method: Optional[str] = None + created: Optional[int] = None + currency: Optional[str] = None + customer: Optional[str] = None + description: Optional[str] = None + invoice: Optional[str] = None + last_payment_error: Optional[Any] = None + latest_charge: Optional[str] = None + livemode: Optional[bool] = None + metadata: Optional[Dict[str, Any]] = None + next_action: Optional[Any] = None + on_behalf_of: Optional[Any] = None + payment_method: Optional[str] = None + payment_method_configuration_details: Optional[Any] = None + payment_method_options: Optional[StripePaymentMethodOptions] = None + payment_method_types: Optional[List[str]] = None + processing: Optional[Any] = None + receipt_email: Optional[str] = None + review: 
Optional[Any] = None + setup_future_usage: Optional[Any] = None + shipping: Optional[StripeShipping] = None + source: Optional[Any] = None + statement_descriptor: Optional[Any] = None + statement_descriptor_suffix: Optional[Any] = None + status: Optional[str] = None + transfer_data: Optional[Any] = None + transfer_group: Optional[Any] = None + + +class Data8(BaseModel): + object: Optional[StripePaymentIntent] = None + + +class StripeEvent(BaseModel): + id: str + object: Object + api_version: Optional[str] = None + created: Optional[int] = None + data: Data8 + livemode: Optional[bool] = None + pending_webhooks: Optional[int] = None + request: Optional[StripeRequestInfo] = None + type: Type diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py new file mode 100644 index 00000000..c189038f --- /dev/null +++ b/comfy_api_nodes/apis/bfl_api.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, confloat, conint + + +class BFLOutputFormat(str, Enum): + png = 'png' + jpeg = 'jpeg' + + +class BFLFluxExpandImageRequest(BaseModel): + prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' 
+ ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') + bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') + left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') + right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') + steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') + guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + ) + image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand') + + +class BFLFluxFillImageRequest(BaseModel): + prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' 
+ ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') + guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + ) + image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.') + mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you wish to modify.') + + + class BFLFluxCannyImageRequest(BaseModel): + prompt: str = Field(..., description='Text prompt for image generation') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + ) + canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') + canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') + guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
+ ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + ) + control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') + preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') + + +class BFLFluxDepthImageRequest(BaseModel): + prompt: str = Field(..., description='Text prompt for image generation') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') + guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + ) + control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') + preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') + + +class BFLFluxProGenerateRequest(BaseModel): + prompt: str = Field(..., description='The text prompt for image generation.') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation.' + ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') + height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + ) + image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') + # image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( + # None, description='Blend between the prompt and the image prompt.' + # ) + + +class BFLFluxProUltraGenerateRequest(BaseModel): + prompt: str = Field(..., description='The text prompt for image generation.') + prompt_upsampling: Optional[bool] = Field( + None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + ) + seed: Optional[int] = Field(None, description='The seed value for reproducibility.') + aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') + safety_tolerance: Optional[conint(ge=0, le=6)] = Field( + 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + ) + output_format: Optional[BFLOutputFormat] = Field( + BFLOutputFormat.png, description="Output format for the generated image. 
Can be 'jpeg' or 'png'.", examples=['png'] + ) + raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.') + image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') + image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( + None, description='Blend between the prompt and the image prompt.' + ) + + +class BFLFluxProGenerateResponse(BaseModel): + id: str = Field(..., description='The unique identifier for the generation task.') + polling_url: str = Field(..., description='URL to poll for the generation result.') + + +class BFLStatus(str, Enum): + task_not_found = "Task not found" + pending = "Pending" + request_moderated = "Request Moderated" + content_moderated = "Content Moderated" + ready = "Ready" + error = "Error" + + +class BFLFluxProStatusResponse(BaseModel): + id: str = Field(..., description="The unique identifier for the generation task.") + status: BFLStatus = Field(..., description="The status of the task.") + result: Optional[Dict[str, Any]] = Field( + None, description="The result of the task (null if not completed)." + ) + progress: confloat(ge=0.0, le=1.0) = Field( + ..., description="The progress of the task (0.0 to 1.0)." + ) + details: Optional[Dict[str, Any]] = Field( + None, description="Additional details about the task (null if not available)." + ) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index d3cd9ad2..929e386d 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -1,5 +1,3 @@ -import logging - """ API Client Framework for api.comfy.org. @@ -46,24 +44,71 @@ operation = ApiOperation( ) user_profile = operation.execute(client=api_client) # Returns immediately with the result + +# Example 2: Asynchronous API Operation with Polling +# ------------------------------------------------- +# For an API that starts a task and requires polling for completion: + +# 1. 
Define the endpoints (initial request and polling) +generate_image_endpoint = ApiEndpoint( + path="/v1/images/generate", + method=HttpMethod.POST, + request_model=ImageGenerationRequest, + response_model=TaskCreatedResponse, + query_params=None +) + +check_task_endpoint = ApiEndpoint( + path="/v1/tasks/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=ImageGenerationResult, + query_params=None +) + +# 2. Create the request object +request = ImageGenerationRequest( + prompt="a beautiful sunset over mountains", + width=1024, + height=1024, + num_images=1 +) + +# 3. Create and execute the polling operation +operation = PollingOperation( + initial_endpoint=generate_image_endpoint, + initial_request=request, + poll_endpoint=check_task_endpoint, + task_id_field="task_id", + status_field="status", + completed_statuses=["completed"], + failed_statuses=["failed", "error"] +) + +# This will make the initial request and then poll until completion +result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done """ -from typing import ( - Dict, - Type, - Optional, - Any, - TypeVar, - Generic, -) -from pydantic import BaseModel +from __future__ import annotations +import logging +import time +import io +from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable from enum import Enum import json import requests from urllib.parse import urljoin +from pydantic import BaseModel, Field + +from comfy.cli_args import args +from comfy import utils T = TypeVar("T", bound=BaseModel) R = TypeVar("R", bound=BaseModel) +P = TypeVar("P", bound=BaseModel) # For poll response + +PROGRESS_BAR_MAX = 100 + class EmptyRequest(BaseModel): """Base class for empty request bodies. @@ -72,6 +117,19 @@ class EmptyRequest(BaseModel): pass +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: str | None = Field( + None, + description="Mime type of the file. 
For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + class HttpMethod(str, Enum): GET = "GET" POST = "POST" @@ -89,7 +147,7 @@ class ApiClient: self, base_url: str, api_key: Optional[str] = None, - timeout: float = 30.0, + timeout: float = 3600.0, verify_ssl: bool = True, ): self.base_url = base_url @@ -97,6 +155,48 @@ class ApiClient: self.timeout = timeout self.verify_ssl = verify_ssl + def _create_json_payload_args( + self, + data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + return { + "json": data, + "headers": headers, + } + + def _create_form_data_args( + self, + data: Dict[str, Any], + files: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, + multipart_parser = None, + ) -> Dict[str, Any]: + if headers and "Content-Type" in headers: + del headers["Content-Type"] + + if multipart_parser: + data = multipart_parser(data) + + return { + "data": data, + "files": files, + "headers": headers, + } + + def _create_urlencoded_form_data_args( + self, + data: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + headers = headers or {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + + return { + "data": data, + "headers": headers, + } + def get_headers(self) -> Dict[str, str]: """Get headers for API requests, including authentication if available""" headers = {"Content-Type": "application/json", "Accept": "application/json"} @@ -111,9 +211,11 @@ class ApiClient: method: str, path: str, params: Optional[Dict[str, Any]] = None, - json: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, files: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, + content_type: str = "application/json", + multipart_parser: Callable = None, ) 
-> Dict[str, Any]: """ Make an HTTP request to the API @@ -122,9 +224,10 @@ class ApiClient: method: HTTP method (GET, POST, etc.) path: API endpoint path (will be joined with base_url) params: Query parameters - json: JSON body data + data: body data files: Files to upload headers: Additional headers + content_type: Content type of the request. Defaults to application/json. Returns: Parsed JSON response @@ -146,34 +249,26 @@ class ApiClient: logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug(f"[DEBUG] Files: {files}") logging.debug(f"[DEBUG] Params: {params}") - logging.debug(f"[DEBUG] Json: {json}") + logging.debug(f"[DEBUG] Data: {data}") + + if content_type == "application/x-www-form-urlencoded": + payload_args = self._create_urlencoded_form_data_args(data, request_headers) + elif content_type == "multipart/form-data": + payload_args = self._create_form_data_args( + data, files, request_headers, multipart_parser + ) + else: + payload_args = self._create_json_payload_args(data, request_headers) try: - # If files are present, use data parameter instead of json - if files: - form_data = {} - if json: - form_data.update(json) - response = requests.request( - method=method, - url=url, - params=params, - data=form_data, # Use data instead of json - files=files, - headers=request_headers, - timeout=self.timeout, - verify=self.verify_ssl, - ) - else: - response = requests.request( - method=method, - url=url, - params=params, - json=json, - headers=request_headers, - timeout=self.timeout, - verify=self.verify_ssl, - ) + response = requests.request( + method=method, + url=url, + params=params, + timeout=self.timeout, + verify=self.verify_ssl, + **payload_args, + ) # Raise exception for error status codes response.raise_for_status() @@ -203,7 +298,9 @@ class ApiClient: error_message = f"API Error: {error_json}" except Exception as json_error: # If we can't parse the JSON, fall back to the original error message - logging.debug(f"[DEBUG] Failed to 
parse error response: {str(json_error)}") + logging.debug( + f"[DEBUG] Failed to parse error response: {str(json_error)}" + ) logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})") if hasattr(e, "response") and e.response.content: @@ -229,6 +326,32 @@ class ApiClient: raise Exception("Unauthorized: Please login first to use this node.") return auth_token + @staticmethod + def upload_file( + upload_url: str, + file: io.BytesIO | str, + content_type: str | None = None, + ): + """Upload a file to the API. Make sure the file has a filename equal to what the url expects. + + Args: + upload_url: The URL to upload to + file: Either a file path string, BytesIO object, or tuple of (file_path, filename) + mime_type: Optional mime type to set for the upload + """ + headers = {} + if content_type: + headers["Content-Type"] = content_type + + if isinstance(file, io.BytesIO): + file.seek(0) # Ensure we're at the start of the file + data = file.read() + return requests.put(upload_url, data=data, headers=headers) + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + return requests.put(upload_url, data=data, headers=headers) + class ApiEndpoint(Generic[T, R]): """Defines an API endpoint with its request and response types""" @@ -267,27 +390,29 @@ class SynchronousOperation(Generic[T, R]): endpoint: ApiEndpoint[T, R], request: T, files: Optional[Dict[str, Any]] = None, - api_base: str = "https://api.comfy.org", + api_base: str | None = None, auth_token: Optional[str] = None, timeout: float = 604800.0, verify_ssl: bool = True, + content_type: str = "application/json", + multipart_parser: Callable = None, ): self.endpoint = endpoint self.request = request self.response = None self.error = None - self.api_base = api_base + self.api_base: str = api_base or args.comfy_api_base self.auth_token = auth_token self.timeout = timeout self.verify_ssl = verify_ssl self.files = files + self.content_type = content_type + self.multipart_parser = 
multipart_parser def execute(self, client: Optional[ApiClient] = None) -> R: """Execute the API operation using the provided client or create one""" try: # Create client if not provided if client is None: - if self.api_base is None: - raise ValueError("Either client or api_base must be provided") client = ApiClient( base_url=self.api_base, api_key=self.auth_token, @@ -296,14 +421,25 @@ class SynchronousOperation(Generic[T, R]): ) # Convert request model to dict, but use None for EmptyRequest - request_dict = None if isinstance(self.request, EmptyRequest) else self.request.model_dump(exclude_none=True) + request_dict = ( + None + if isinstance(self.request, EmptyRequest) + else self.request.model_dump(exclude_none=True) + ) if request_dict: for key, value in request_dict.items(): if isinstance(value, Enum): request_dict[key] = value.value + if request_dict: + for key, value in request_dict.items(): + if isinstance(value, Enum): + request_dict[key] = value.value + # Debug log for request - logging.debug(f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}") + logging.debug( + f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" + ) logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") @@ -311,9 +447,11 @@ class SynchronousOperation(Generic[T, R]): resp = client.request( method=self.endpoint.method.value, path=self.endpoint.path, - json=request_dict, + data=request_dict, params=self.endpoint.query_params, files=self.files, + content_type=self.content_type, + multipart_parser=self.multipart_parser ) # Debug log for response @@ -327,7 +465,7 @@ class SynchronousOperation(Generic[T, R]): return self._parse_response(resp) except Exception as e: - logging.debug(f"[DEBUG] API Exception: {str(e)}") + logging.error(f"[DEBUG] API Exception: {str(e)}") raise Exception(str(e)) def _parse_response(self, resp): @@ -339,3 +477,140 @@ class 
SynchronousOperation(Generic[T, R]): self.response = self.endpoint.response_model.model_validate(resp) logging.debug(f"[DEBUG] Parsed Response: {self.response}") return self.response + + +class TaskStatus(str, Enum): + """Enum for task status values""" + + COMPLETED = "completed" + FAILED = "failed" + PENDING = "pending" + + +class PollingOperation(Generic[T, R]): + """ + Represents an asynchronous API operation that requires polling for completion. + """ + + def __init__( + self, + poll_endpoint: ApiEndpoint[EmptyRequest, R], + completed_statuses: list, + failed_statuses: list, + status_extractor: Callable[[R], str], + progress_extractor: Callable[[R], float] = None, + request: Optional[T] = None, + api_base: str | None = None, + auth_token: Optional[str] = None, + poll_interval: float = 5.0, + ): + self.poll_endpoint = poll_endpoint + self.request = request + self.api_base: str = api_base or args.comfy_api_base + self.auth_token = auth_token + self.poll_interval = poll_interval + + # Polling configuration + self.status_extractor = status_extractor or ( + lambda x: getattr(x, "status", None) + ) + self.progress_extractor = progress_extractor + self.completed_statuses = completed_statuses + self.failed_statuses = failed_statuses + + # For storing response data + self.final_response = None + self.error = None + + def execute(self, client: Optional[ApiClient] = None) -> R: + """Execute the polling operation using the provided client. 
If failed, raise an exception.""" + try: + if client is None: + client = ApiClient( + base_url=self.api_base, + api_key=self.auth_token, + ) + return self._poll_until_complete(client) + except Exception as e: + raise Exception(f"Error during polling: {str(e)}") + + def _check_task_status(self, response: R) -> TaskStatus: + """Check task status using the status extractor function""" + try: + status = self.status_extractor(response) + if status in self.completed_statuses: + return TaskStatus.COMPLETED + elif status in self.failed_statuses: + return TaskStatus.FAILED + return TaskStatus.PENDING + except Exception as e: + logging.error(f"Error extracting status: {e}") + return TaskStatus.PENDING + + def _poll_until_complete(self, client: ApiClient) -> R: + """Poll until the task is complete""" + poll_count = 0 + if self.progress_extractor: + progress = utils.ProgressBar(PROGRESS_BAR_MAX) + + while True: + try: + poll_count += 1 + logging.debug(f"[DEBUG] Polling attempt #{poll_count}") + + request_dict = ( + self.request.model_dump(exclude_none=True) + if self.request is not None + else None + ) + + if poll_count == 1: + logging.debug( + f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}" + ) + logging.debug( + f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}" + ) + + # Query task status + resp = client.request( + method=self.poll_endpoint.method.value, + path=self.poll_endpoint.path, + params=self.poll_endpoint.query_params, + data=request_dict, + ) + + # Parse response + response_obj = self.poll_endpoint.response_model.model_validate(resp) + # Check if task is complete + status = self._check_task_status(response_obj) + logging.debug(f"[DEBUG] Task Status: {status}") + + # If progress extractor is provided, extract progress + if self.progress_extractor: + new_progress = self.progress_extractor(response_obj) + if new_progress is not None: + progress.update_absolute(new_progress, 
total=PROGRESS_BAR_MAX) + + if status == TaskStatus.COMPLETED: + logging.debug("[DEBUG] Task completed successfully") + self.final_response = response_obj + if self.progress_extractor: + progress.update(100) + return self.final_response + elif status == TaskStatus.FAILED: + message = f"Task failed: {json.dumps(resp)}" + logging.error(f"[DEBUG] {message}") + raise Exception(message) + else: + logging.debug("[DEBUG] Task still pending, continuing to poll...") + + # Wait before polling again + logging.debug( + f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" + ) + time.sleep(self.poll_interval) + + except Exception as e: + logging.error(f"[DEBUG] Polling error: {str(e)}") + raise Exception(f"Error while polling: {str(e)}") diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma_api.py new file mode 100644 index 00000000..632c4ab9 --- /dev/null +++ b/comfy_api_nodes/apis/luma_api.py @@ -0,0 +1,253 @@ +from __future__ import annotations + + +import torch + +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel, Field, confloat + + + +class LumaIO: + LUMA_REF = "LUMA_REF" + LUMA_CONCEPTS = "LUMA_CONCEPTS" + + +class LumaReference: + def __init__(self, image: torch.Tensor, weight: float): + self.image = image + self.weight = weight + + def create_api_model(self, download_url: str): + return LumaImageRef(url=download_url, weight=self.weight) + +class LumaReferenceChain: + def __init__(self, first_ref: LumaReference=None): + self.refs: list[LumaReference] = [] + if first_ref: + self.refs.append(first_ref) + + def add(self, luma_ref: LumaReference=None): + self.refs.append(luma_ref) + + def create_api_model(self, download_urls: list[str], max_refs=4): + if len(self.refs) == 0: + return None + api_refs: list[LumaImageRef] = [] + for ref, url in zip(self.refs, download_urls): + api_ref = LumaImageRef(url=url, weight=ref.weight) + api_refs.append(api_ref) + return api_refs + + def clone(self): + c = 
LumaReferenceChain() + for ref in self.refs: + c.add(ref) + return c + + +class LumaConcept: + def __init__(self, key: str): + self.key = key + + +class LumaConceptChain: + def __init__(self, str_list: list[str] = None): + self.concepts: list[LumaConcept] = [] + if str_list is not None: + for c in str_list: + if c != "None": + self.add(LumaConcept(key=c)) + + def add(self, concept: LumaConcept): + self.concepts.append(concept) + + def create_api_model(self): + if len(self.concepts) == 0: + return None + api_concepts: list[LumaConceptObject] = [] + for concept in self.concepts: + if concept.key == "None": + continue + api_concepts.append(LumaConceptObject(key=concept.key)) + if len(api_concepts) == 0: + return None + return api_concepts + + def clone(self): + c = LumaConceptChain() + for concept in self.concepts: + c.add(concept) + return c + + def clone_and_merge(self, other: LumaConceptChain): + c = self.clone() + for concept in other.concepts: + c.add(concept) + return c + + +def get_luma_concepts(include_none=False): + concepts = [] + if include_none: + concepts.append("None") + return concepts + [ + "truck_left", + "pan_right", + "pedestal_down", + "low_angle", + "pedestal_up", + "selfie", + "pan_left", + "roll_right", + "zoom_in", + "over_the_shoulder", + "orbit_right", + "orbit_left", + "static", + "tiny_planet", + "high_angle", + "bolt_cam", + "dolly_zoom", + "overhead", + "zoom_out", + "handheld", + "roll_left", + "pov", + "aerial_drone", + "push_in", + "crane_down", + "truck_right", + "tilt_down", + "elevator_doors", + "tilt_up", + "ground_level", + "pull_out", + "aerial", + "crane_up", + "eye_level" + ] + + +class LumaImageModel(str, Enum): + photon_1 = "photon-1" + photon_flash_1 = "photon-flash-1" + + +class LumaVideoModel(str, Enum): + ray_2 = "ray-2" + ray_flash_2 = "ray-flash-2" + ray_1_6 = "ray-1-6" + + +class LumaAspectRatio(str, Enum): + ratio_1_1 = "1:1" + ratio_16_9 = "16:9" + ratio_9_16 = "9:16" + ratio_4_3 = "4:3" + ratio_3_4 = "3:4" + 
ratio_21_9 = "21:9" + ratio_9_21 = "9:21" + + +class LumaVideoOutputResolution(str, Enum): + res_540p = "540p" + res_720p = "720p" + res_1080p = "1080p" + res_4k = "4k" + + +class LumaVideoModelOutputDuration(str, Enum): + dur_5s = "5s" + dur_9s = "9s" + + +class LumaGenerationType(str, Enum): + video = 'video' + image = 'image' + + +class LumaState(str, Enum): + queued = "queued" + dreaming = "dreaming" + completed = "completed" + failed = "failed" + + +class LumaAssets(BaseModel): + video: Optional[str] = Field(None, description='The URL of the video') + image: Optional[str] = Field(None, description='The URL of the image') + progress_video: Optional[str] = Field(None, description='The URL of the progress video') + + +class LumaImageRef(BaseModel): + '''Used for image gen''' + url: str = Field(..., description='The URL of the image reference') + weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + + +class LumaImageReference(BaseModel): + '''Used for video gen''' + type: Optional[str] = Field('image', description='Input type, defaults to image') + url: str = Field(..., description='The URL of the image') + + +class LumaModifyImageRef(BaseModel): + url: str = Field(..., description='The URL of the image reference') + weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + + +class LumaCharacterRef(BaseModel): + identity0: LumaImageIdentity = Field(..., description='The image identity object') + + +class LumaImageIdentity(BaseModel): + images: list[str] = Field(..., description='The URLs of the image identity') + + +class LumaGenerationReference(BaseModel): + type: str = Field('generation', description='Input type, defaults to generation') + id: str = Field(..., description='The ID of the generation') + + +class LumaKeyframes(BaseModel): + frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') + frame1: Optional[Union[LumaImageReference, 
LumaGenerationReference]] = Field(None, description='') + + +class LumaConceptObject(BaseModel): + key: str = Field(..., description='Camera Concept name') + + +class LumaImageGenerationRequest(BaseModel): + prompt: str = Field(..., description='The prompt of the generation') + model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation') + aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation') + image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects') + style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects') + character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object') + modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object') + + +class LumaGenerationRequest(BaseModel): + prompt: str = Field(..., description='The prompt of the generation') + model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation') + duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation') + aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation') + resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation') + loop: Optional[bool] = Field(None, description='Whether to loop the video') + keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation') + concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation') + + +class LumaGeneration(BaseModel): + id: str = Field(..., description='The ID of the generation') + generation_type: LumaGenerationType = Field(..., description='Generation type, image or video') + state: LumaState = 
Field(..., description='The state of the generation') + failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation') + created_at: str = Field(..., description='The date and time when the generation was created') + assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') + model: str = Field(..., description='The model used for the generation') + request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse_api.py new file mode 100644 index 00000000..9bb29c38 --- /dev/null +++ b/comfy_api_nodes/apis/pixverse_api.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +pixverse_templates = { + "Microwave": 324641385496960, + "Suit Swagger": 328545151283968, + "Anything, Robot": 313358700761536, + "Subject 3 Fever": 327828816843648, + "kiss kiss": 315446315336768, +} + + +class PixverseIO: + TEMPLATE = "PIXVERSE_TEMPLATE" + + +class PixverseStatus(int, Enum): + successful = 1 + generating = 5 + deleted = 6 + contents_moderation = 7 + failed = 8 + + +class PixverseAspectRatio(str, Enum): + ratio_16_9 = "16:9" + ratio_4_3 = "4:3" + ratio_1_1 = "1:1" + ratio_3_4 = "3:4" + ratio_9_16 = "9:16" + + +class PixverseQuality(str, Enum): + res_360p = "360p" + res_540p = "540p" + res_720p = "720p" + res_1080p = "1080p" + + +class PixverseDuration(int, Enum): + dur_5 = 5 + dur_8 = 8 + + +class PixverseMotionMode(str, Enum): + normal = "normal" + fast = "fast" + + +class PixverseStyle(str, Enum): + anime = "anime" + animation_3d = "3d_animation" + clay = "clay" + comic = "comic" + cyberpunk = "cyberpunk" + + +# NOTE: forgoing descriptions for now in return for dev speed +class PixverseTextVideoRequest(BaseModel): + aspect_ratio: PixverseAspectRatio = Field(...) 
+ quality: PixverseQuality = Field(...) + duration: PixverseDuration = Field(...) + model: Optional[str] = Field("v3.5") + motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal) + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + style: Optional[str] = Field(None) + template_id: Optional[int] = Field(None) + water_mark: Optional[bool] = Field(None) + + +class PixverseImageVideoRequest(BaseModel): + quality: PixverseQuality = Field(...) + duration: PixverseDuration = Field(...) + img_id: int = Field(...) + model: Optional[str] = Field("v3.5") + motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal) + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + style: Optional[str] = Field(None) + template_id: Optional[int] = Field(None) + water_mark: Optional[bool] = Field(None) + + +class PixverseTransitionVideoRequest(BaseModel): + quality: PixverseQuality = Field(...) + duration: PixverseDuration = Field(...) + first_frame_img: int = Field(...) + last_frame_img: int = Field(...) + model: Optional[str] = Field("v3.5") + motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal) + prompt: str = Field(...) 
+ # negative_prompt: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + # style: Optional[str] = Field(None) + # template_id: Optional[int] = Field(None) + # water_mark: Optional[bool] = Field(None) + + +class PixverseImageUploadResponse(BaseModel): + ErrCode: Optional[int] = None + ErrMsg: Optional[str] = None + Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp') + + +class PixverseImgIdResponseObject(BaseModel): + img_id: Optional[int] = None + + +class PixverseVideoResponse(BaseModel): + ErrCode: Optional[int] = Field(None) + ErrMsg: Optional[str] = Field(None) + Resp: Optional[PixverseVideoIdResponseObject] = Field(None) + + +class PixverseVideoIdResponseObject(BaseModel): + video_id: int = Field(..., description='Video_id') + + +class PixverseGenerationStatusResponse(BaseModel): + ErrCode: Optional[int] = Field(None) + ErrMsg: Optional[str] = Field(None) + Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None) + + +class PixverseGenerationStatusResponseObject(BaseModel): + create_time: Optional[str] = Field(None) + id: Optional[int] = Field(None) + modify_time: Optional[str] = Field(None) + negative_prompt: Optional[str] = Field(None) + outputHeight: Optional[int] = Field(None) + outputWidth: Optional[int] = Field(None) + prompt: Optional[str] = Field(None) + resolution_ratio: Optional[int] = Field(None) + seed: Optional[int] = Field(None) + size: Optional[int] = Field(None) + status: Optional[int] = Field(None) + style: Optional[str] = Field(None) + url: Optional[str] = Field(None) diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft_api.py new file mode 100644 index 00000000..c0ec9d0c --- /dev/null +++ b/comfy_api_nodes/apis/recraft_api.py @@ -0,0 +1,263 @@ +from __future__ import annotations + + + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field, conint, confloat + + +class RecraftColor: + def __init__(self, r: int, g: int, b: int): + 
self.color = [r, g, b] + + def create_api_model(self): + return RecraftColorObject(rgb=self.color) + + +class RecraftColorChain: + def __init__(self): + self.colors: list[RecraftColor] = [] + + def get_first(self): + if len(self.colors) > 0: + return self.colors[0] + return None + + def add(self, color: RecraftColor): + self.colors.append(color) + + def create_api_model(self): + if not self.colors: + return None + colors_api = [x.create_api_model() for x in self.colors] + return colors_api + + def clone(self): + c = RecraftColorChain() + for color in self.colors: + c.add(color) + return c + + def clone_and_merge(self, other: RecraftColorChain): + c = self.clone() + for color in other.colors: + c.add(color) + return c + + +class RecraftControls: + def __init__(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None, + artistic_level: int=None, no_text: bool=None): + self.colors = colors + self.background_color = background_color + self.artistic_level = artistic_level + self.no_text = no_text + + def create_api_model(self): + if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None: + return None + colors_api = None + background_color_api = None + if self.colors: + colors_api = self.colors.create_api_model() + if self.background_color: + first_background = self.background_color.get_first() + background_color_api = first_background.create_api_model() if first_background else None + + return RecraftControlsObject(colors=colors_api, background_color=background_color_api, + artistic_level=self.artistic_level, no_text=self.no_text) + + +class RecraftStyle: + def __init__(self, style: str=None, substyle: str=None, style_id: str=None): + self.style = style + if substyle == "None": + substyle = None + self.substyle = substyle + self.style_id = style_id + + +class RecraftIO: + STYLEV3 = "RECRAFT_V3_STYLE" + SVG = "SVG" # TODO: if acceptable, move into ComfyUI's typing class + COLOR = 
"RECRAFT_COLOR" + CONTROLS = "RECRAFT_CONTROLS" + + +class RecraftStyleV3(str, Enum): + #any = 'any' NOTE: this does not work for some reason... why? + realistic_image = 'realistic_image' + digital_illustration = 'digital_illustration' + vector_illustration = 'vector_illustration' + logo_raster = 'logo_raster' + + +def get_v3_substyles(style_v3: str, include_none=True) -> list[str]: + substyles: list[str] = [] + if include_none: + substyles.append("None") + return substyles + dict_recraft_substyles_v3.get(style_v3, []) + + +dict_recraft_substyles_v3 = { + RecraftStyleV3.realistic_image: [ + "b_and_w", + "enterprise", + "evening_light", + "faded_nostalgia", + "forest_life", + "hard_flash", + "hdr", + "motion_blur", + "mystic_naturalism", + "natural_light", + "natural_tones", + "organic_calm", + "real_life_glow", + "retro_realism", + "retro_snapshot", + "studio_portrait", + "urban_drama", + "village_realism", + "warm_folk" + ], + RecraftStyleV3.digital_illustration: [ + "2d_art_poster", + "2d_art_poster_2", + "antiquarian", + "bold_fantasy", + "child_book", + "child_books", + "cover", + "crosshatch", + "digital_engraving", + "engraving_color", + "expressionism", + "freehand_details", + "grain", + "grain_20", + "graphic_intensity", + "hand_drawn", + "hand_drawn_outline", + "handmade_3d", + "hard_comics", + "infantile_sketch", + "long_shadow", + "modern_folk", + "multicolor", + "neon_calm", + "noir", + "nostalgic_pastel", + "outline_details", + "pastel_gradient", + "pastel_sketch", + "pixel_art", + "plastic", + "pop_art", + "pop_renaissance", + "seamless", + "street_art", + "tablet_sketch", + "urban_glow", + "urban_sketching", + "vanilla_dreams", + "young_adult_book", + "young_adult_book_2" + ], + RecraftStyleV3.vector_illustration: [ + "bold_stroke", + "chemistry", + "colored_stencil", + "contour_pop_art", + "cosmics", + "cutout", + "depressive", + "editorial", + "emotional_flat", + "engraving", + "infographical", + "line_art", + "line_circuit", + "linocut", + 
"marker_outline", + "mosaic", + "naivector", + "roundish_flat", + "seamless", + "segmented_colors", + "sharp_contrast", + "thin", + "vector_photo", + "vivid_shapes" + ], + RecraftStyleV3.logo_raster: [ + "emblem_graffiti", + "emblem_pop_art", + "emblem_punk", + "emblem_stamp", + "emblem_vintage" + ], +} + + +class RecraftModel(str, Enum): + recraftv3 = 'recraftv3' + recraftv2 = 'recraftv2' + + +class RecraftImageSize(str, Enum): + res_1024x1024 = '1024x1024' + res_1365x1024 = '1365x1024' + res_1024x1365 = '1024x1365' + res_1536x1024 = '1536x1024' + res_1024x1536 = '1024x1536' + res_1820x1024 = '1820x1024' + res_1024x1820 = '1024x1820' + res_1024x2048 = '1024x2048' + res_2048x1024 = '2048x1024' + res_1434x1024 = '1434x1024' + res_1024x1434 = '1024x1434' + res_1024x1280 = '1024x1280' + res_1280x1024 = '1280x1024' + res_1024x1707 = '1024x1707' + res_1707x1024 = '1707x1024' + + +class RecraftColorObject(BaseModel): + rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model') + + +class RecraftControlsObject(BaseModel): + colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors') + background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color') + no_text: Optional[bool] = Field(None, description='Do not embed text layouts') + artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. 
The value should be in range [0..5].') + + +class RecraftImageGenerationRequest(BaseModel): + prompt: str = Field(..., description='The text prompt describing the image to generate') + size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")') + n: conint(ge=1, le=6) = Field(..., description='The number of images to generate') + negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image') + model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') + style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') + substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input') + controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process') + style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID') + strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') + random_seed: Optional[int] = Field(None, description="Seed for video generation") + # text_layout + + +class RecraftReturnedObject(BaseModel): + image_id: str = Field(..., description='Unique identifier for the generated image') + url: str = Field(..., description='URL to access the generated image') + + +class RecraftImageGenerationResponse(BaseModel): + created: int = Field(..., description='Unix timestamp when the generation was created') + credits: int = Field(..., description='Number of credits used for the generation') + data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information') + image: 
Optional[RecraftReturnedObject] = Field(None, description='Single generated image') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py new file mode 100644 index 00000000..47c87dae --- /dev/null +++ b/comfy_api_nodes/apis/stability_api.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field, confloat + + +class StabilityFormat(str, Enum): + png = 'png' + jpeg = 'jpeg' + webp = 'webp' + + +class StabilityAspectRatio(str, Enum): + ratio_1_1 = "1:1" + ratio_16_9 = "16:9" + ratio_9_16 = "9:16" + ratio_3_2 = "3:2" + ratio_2_3 = "2:3" + ratio_5_4 = "5:4" + ratio_4_5 = "4:5" + ratio_21_9 = "21:9" + ratio_9_21 = "9:21" + + +def get_stability_style_presets(include_none=True): + presets = [] + if include_none: + presets.append("None") + return presets + [x.value for x in StabilityStylePreset] + + +class StabilityStylePreset(str, Enum): + _3d_model = "3d-model" + analog_film = "analog-film" + anime = "anime" + cinematic = "cinematic" + comic_book = "comic-book" + digital_art = "digital-art" + enhance = "enhance" + fantasy_art = "fantasy-art" + isometric = "isometric" + line_art = "line-art" + low_poly = "low-poly" + modeling_compound = "modeling-compound" + neon_punk = "neon-punk" + origami = "origami" + photographic = "photographic" + pixel_art = "pixel-art" + tile_texture = "tile-texture" + + +class Stability_SD3_5_Model(str, Enum): + sd3_5_large = "sd3.5-large" + # sd3_5_large_turbo = "sd3.5-large-turbo" + sd3_5_medium = "sd3.5-medium" + + +class Stability_SD3_5_GenerationMode(str, Enum): + text_to_image = "text-to-image" + image_to_image = "image-to-image" + + +class StabilityStable3_5Request(BaseModel): + model: str = Field(...) + mode: str = Field(...) + prompt: str = Field(...) 
+ negative_prompt: Optional[str] = Field(None) + aspect_ratio: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + output_format: Optional[str] = Field(StabilityFormat.png.value) + image: Optional[str] = Field(None) + style_preset: Optional[str] = Field(None) + cfg_scale: float = Field(...) + strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None) + + +class StabilityUpscaleConservativeRequest(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + output_format: Optional[str] = Field(StabilityFormat.png.value) + image: Optional[str] = Field(None) + creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None) + + +class StabilityUpscaleCreativeRequest(BaseModel): + prompt: str = Field(...) + negative_prompt: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + output_format: Optional[str] = Field(StabilityFormat.png.value) + image: Optional[str] = Field(None) + creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None) + style_preset: Optional[str] = Field(None) + + +class StabilityStableUltraRequest(BaseModel): + prompt: str = Field(...) 
+ negative_prompt: Optional[str] = Field(None) + aspect_ratio: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + output_format: Optional[str] = Field(StabilityFormat.png.value) + image: Optional[str] = Field(None) + style_preset: Optional[str] = Field(None) + strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None) + + +class StabilityStableUltraResponse(BaseModel): + image: Optional[str] = Field(None) + finish_reason: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + + +class StabilityResultsGetResponse(BaseModel): + image: Optional[str] = Field(None) + finish_reason: Optional[str] = Field(None) + seed: Optional[int] = Field(None) + id: Optional[str] = Field(None) + name: Optional[str] = Field(None) + errors: Optional[list[str]] = Field(None) + status: Optional[str] = Field(None) + result: Optional[str] = Field(None) + + +class StabilityAsyncResponse(BaseModel): + id: Optional[str] = Field(None) diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py new file mode 100644 index 00000000..6fab8f4b --- /dev/null +++ b/comfy_api_nodes/mapper_utils.py @@ -0,0 +1,116 @@ +from enum import Enum + +from pydantic.fields import FieldInfo +from pydantic import BaseModel +from pydantic_core import PydanticUndefined + +from comfy.comfy_types.node_typing import IO, InputTypeOptions + +NodeInput = tuple[IO, InputTypeOptions] + + +def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: + config = {} + if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: + config["default"] = field_info.default + if hasattr(field_info, "description") and field_info.description is not None: + config["tooltip"] = field_info.description + return config + + +def _get_number_constraints_config(field_info: FieldInfo) -> dict: + config = {} + if hasattr(field_info, "metadata"): + metadata = field_info.metadata + for constraint in metadata: + if hasattr(constraint, "ge"): + config["min"] = constraint.ge + if 
hasattr(constraint, "le"): + config["max"] = constraint.le + if hasattr(constraint, "multiple_of"): + config["step"] = constraint.multiple_of + return config + + +def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: + return IO.IMAGE, { + **_create_base_config(field_info), + **kwargs, + } + + +def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: + return IO.STRING, { + **_create_base_config(field_info), + **kwargs, + } + + +def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: + return IO.FLOAT, { + **_create_base_config(field_info), + **_get_number_constraints_config(field_info), + **kwargs, + } + + +def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: + return IO.INT, { + **_create_base_config(field_info), + **_get_number_constraints_config(field_info), + **kwargs, + } + + +def _model_field_to_combo_input( + field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs +) -> NodeInput: + combo_config = {} + if enum_type is not None: + combo_config["options"] = [option.value for option in enum_type] + combo_config = { + **combo_config, + **_create_base_config(field_info), + **kwargs, + } + return IO.COMBO, combo_config + + +def model_field_to_node_input( + input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs +) -> NodeInput: + """ + Maps a field from a Pydantic model to a Comfy node input. + + Args: + input_type: The type of the input. + base_model: The Pydantic model to map the field from. + field_name: The name of the field to map. + **kwargs: Additional key/values to include in the input options. + + Note: + For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. 
+ + Example: + >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) + >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) + >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) + """ + field_info: FieldInfo = base_model.model_fields[field_name] + result: NodeInput + + if input_type == IO.IMAGE: + result = _model_field_to_image_input(field_info, **kwargs) + elif input_type == IO.STRING: + result = _model_field_to_string_input(field_info, **kwargs) + elif input_type == IO.FLOAT: + result = _model_field_to_float_input(field_info, **kwargs) + elif input_type == IO.INT: + result = _model_field_to_int_input(field_info, **kwargs) + elif input_type == IO.COMBO: + result = _model_field_to_combo_input(field_info, **kwargs) + else: + message = f"Invalid input type: {input_type}" + raise ValueError(message) + + return result diff --git a/comfy_api_nodes/nodes_api.py b/comfy_api_nodes/nodes_api.py deleted file mode 100644 index a977bb9b..00000000 --- a/comfy_api_nodes/nodes_api.py +++ /dev/null @@ -1,449 +0,0 @@ -import base64 -import io -import math -from inspect import cleandoc - -import numpy as np -import requests -import torch -from PIL import Image - -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from comfy.utils import common_upscale -from comfy_api_nodes.apis import ( - OpenAIImageEditRequest, - OpenAIImageGenerationRequest, - OpenAIImageGenerationResponse, -) -from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, SynchronousOperation - - -def downscale_input(image): - samples = image.movedim(-1,1) - #downscaling input images to roughly the same size as the outputs - total = int(1536 * 1024) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = 
s.movedim(1,-1) - return s - -def validate_and_cast_response(response): - # validate raw JSON response - data = response.data - if not data or len(data) == 0: - raise Exception("No images returned from API endpoint") - - # Initialize list to store image tensors - image_tensors = [] - - # Process each image in the data array - for image_data in data: - image_url = image_data.url - b64_data = image_data.b64_json - - if not image_url and not b64_data: - raise Exception("No image was generated in the response") - - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) - - elif image_url: - img_response = requests.get(image_url) - if img_response.status_code != 200: - raise Exception("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) - - img = img.convert("RGBA") - - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array) - - # Add to list of tensors - image_tensors.append(img_tensor) - - return torch.stack(image_tensors, dim=0) - -class OpenAIDalle2(ComfyNodeABC): - """ - Generates images synchronously via OpenAI's DALL·E 2 endpoint. - - Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, - so download or cache results if you need to keep them. 
- """ - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": (IO.STRING, { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }), - }, - "optional": { - "seed": (IO.INT, { - "default": 0, - "min": 0, - "max": 2**31-1, - "step": 1, - "display": "number", - "tooltip": "not implemented yet in backend", - }), - "size": (IO.COMBO, { - "options": ["256x256", "512x512", "1024x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }), - "n": (IO.INT, { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image for image editing.", - }), - "mask": (IO.MASK, { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG" - } - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call(self, prompt, seed=0, image=None, mask=None, n=1, size="1024x1024", auth_token=None): - model = "dall-e-2" - path = "/proxy/openai/images/generations" - request_class = OpenAIImageGenerationRequest - img_binary = None - - if image is not None and mask is not None: - path = "/proxy/openai/images/edits" - request_class = OpenAIImageEditRequest - - input_tensor = image.squeeze().cpu() - height, width, channels = input_tensor.shape - rgba_tensor = torch.ones(height, width, 4, device="cpu") - rgba_tensor[:, :, :channels] = input_tensor - - if mask.shape[1:] != image.shape[1:-1]: - raise Exception("Mask and Image must be the same size") - rgba_tensor[:,:,3] = (1-mask.squeeze().cpu()) - - rgba_tensor = downscale_input(rgba_tensor.unsqueeze(0)).squeeze() - - image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr 
= io.BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - img_binary = img_byte_arr#.getvalue() - img_binary.name = "image.png" - elif image is not None or mask is not None: - raise Exception("Dall-E 2 image editing requires an image AND a mask") - - # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse - ), - request=request_class( - model=model, - prompt=prompt, - n=n, - size=size, - seed=seed, - ), - files={ - "image": img_binary, - } if img_binary else None, - auth_token=auth_token - ) - - response = operation.execute() - - img_tensor = validate_and_cast_response(response) - return (img_tensor,) - -class OpenAIDalle3(ComfyNodeABC): - """ - Generates images synchronously via OpenAI's DALL·E 3 endpoint. - - Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, - so download or cache results if you need to keep them. - """ - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": (IO.STRING, { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }), - }, - "optional": { - "seed": (IO.INT, { - "default": 0, - "min": 0, - "max": 2**31-1, - "step": 1, - "display": "number", - "tooltip": "not implemented yet in backend", - }), - "quality" : (IO.COMBO, { - "options": ["standard","hd"], - "default": "standard", - "tooltip": "Image quality", - }), - "style": (IO.COMBO, { - "options": ["natural","vivid"], - "default": "natural", - "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. 
Natural causes the model to produce more natural, less hyper-real looking images.", - }), - "size": (IO.COMBO, { - "options": ["1024x1024", "1024x1792", "1792x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG" - } - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call(self, prompt, seed=0, style="natural", quality="standard", size="1024x1024", auth_token=None): - model = "dall-e-3" - - # build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/images/generations", - method=HttpMethod.POST, - request_model=OpenAIImageGenerationRequest, - response_model=OpenAIImageGenerationResponse - ), - request=OpenAIImageGenerationRequest( - model=model, - prompt=prompt, - quality=quality, - size=size, - style=style, - seed=seed, - ), - auth_token=auth_token - ) - - response = operation.execute() - - img_tensor = validate_and_cast_response(response) - return (img_tensor,) - -class OpenAIGPTImage1(ComfyNodeABC): - """ - Generates images synchronously via OpenAI's GPT Image 1 endpoint. - - Uses the proxy at /proxy/openai/images/generations. Returned URLs are short‑lived, - so download or cache results if you need to keep them. 
- """ - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": (IO.STRING, { - "multiline": True, - "default": "", - "tooltip": "Text prompt for GPT Image 1", - }), - }, - "optional": { - "seed": (IO.INT, { - "default": 0, - "min": 0, - "max": 2**31-1, - "step": 1, - "display": "number", - "tooltip": "not implemented yet in backend", - }), - "quality": (IO.COMBO, { - "options": ["low","medium","high"], - "default": "low", - "tooltip": "Image quality, affects cost and generation time.", - }), - "background": (IO.COMBO, { - "options": ["opaque","transparent"], - "default": "opaque", - "tooltip": "Return image with or without background", - }), - "size": (IO.COMBO, { - "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], - "default": "auto", - "tooltip": "Image size", - }), - "n": (IO.INT, { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image for image editing.", - }), - "mask": (IO.MASK, { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }), - "moderation": (IO.COMBO, { - "options": ["low","auto"], - "default": "low", - "tooltip": "Moderation level", - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG" - } - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - def api_call(self, prompt, seed=0, quality="low", background="opaque", image=None, mask=None, n=1, size="1024x1024", auth_token=None, moderation="low"): - model = "gpt-image-1" - path = "/proxy/openai/images/generations" - request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None - files = [] - - if image is not None: - path = "/proxy/openai/images/edits" - request_class = OpenAIImageEditRequest - - batch_size 
= image.shape[0] - - - for i in range(batch_size): - single_image = image[i:i+1] - scaled_image = downscale_input(single_image).squeeze() - - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - - img_binaries.append(img_binary) - if batch_size == 1: - files.append(("image", img_binary)) - else: - files.append(("image[]", img_binary)) - - if mask is not None: - if image.shape[0] != 1: - raise Exception("Cannot use a mask with multiple image") - if image is None: - raise Exception("Cannot use a mask without an input image") - if mask.shape[1:] != image.shape[1:-1]: - raise Exception("Mask and Image must be the same size") - batch, height, width = mask.shape - rgba_mask = torch.zeros(height, width, 4, device="cpu") - rgba_mask[:,:,3] = (1-mask.squeeze().cpu()) - - scaled_mask = downscale_input(rgba_mask.unsqueeze(0)).squeeze() - - mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) - mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = io.BytesIO() - mask_img.save(mask_img_byte_arr, format='PNG') - mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) - - - # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse - ), - request=request_class( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - moderation=moderation, - ), - files=files if files else None, - auth_token=auth_token - ) - - response = operation.execute() - - img_tensor = validate_and_cast_response(response) - return (img_tensor,) - - -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be 
globally unique -NODE_CLASS_MAPPINGS = { - "OpenAIDalle2": OpenAIDalle2, - "OpenAIDalle3": OpenAIDalle3, - "OpenAIGPTImage1": OpenAIGPTImage1, -} - -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "OpenAIDalle2": "OpenAI DALL·E 2", - "OpenAIDalle3": "OpenAI DALL·E 3", - "OpenAIGPTImage1": "OpenAI GPT Image 1", -} diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py new file mode 100644 index 00000000..122a6ddf --- /dev/null +++ b/comfy_api_nodes/nodes_bfl.py @@ -0,0 +1,906 @@ +import io +from inspect import cleandoc +from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api_nodes.apis.bfl_api import ( + BFLStatus, + BFLFluxExpandImageRequest, + BFLFluxFillImageRequest, + BFLFluxCannyImageRequest, + BFLFluxDepthImageRequest, + BFLFluxProGenerateRequest, + BFLFluxProUltraGenerateRequest, + BFLFluxProGenerateResponse, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, +) +from comfy_api_nodes.apinode_utils import ( + downscale_image_tensor, + validate_aspect_ratio, + process_image_response, + resize_mask_to_image, + validate_string, +) + +import numpy as np +from PIL import Image +import requests +import torch +import base64 +import time + + +def convert_mask_to_image(mask: torch.Tensor): + """ + Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. 
+ """ + mask = mask.unsqueeze(-1) + mask = torch.cat([mask]*3, dim=-1) + return mask + + +def handle_bfl_synchronous_operation( + operation: SynchronousOperation, timeout_bfl_calls=360 +): + response_api: BFLFluxProGenerateResponse = operation.execute() + return _poll_until_generated( + response_api.polling_url, timeout=timeout_bfl_calls + ) + +def _poll_until_generated(polling_url: str, timeout=360): + # used bfl-comfy-nodes to verify code implementation: + # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main + start_time = time.time() + retries_404 = 0 + max_retries_404 = 5 + retry_404_seconds = 2 + retry_202_seconds = 2 + retry_pending_seconds = 1 + request = requests.Request(method=HttpMethod.GET, url=polling_url) + # NOTE: should True loop be replaced with checking if workflow has been interrupted? + while True: + response = requests.Session().send(request.prepare()) + if response.status_code == 200: + result = response.json() + if result["status"] == BFLStatus.ready: + img_url = result["result"]["sample"] + img_response = requests.get(img_url) + return process_image_response(img_response) + elif result["status"] in [ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + ]: + status = result["status"] + raise Exception( + f"BFL API did not return an image due to: {status}." + ) + elif result["status"] == BFLStatus.error: + raise Exception(f"BFL API encountered an error: {result}.") + elif result["status"] == BFLStatus.pending: + time.sleep(retry_pending_seconds) + continue + elif response.status_code == 404: + if retries_404 < max_retries_404: + retries_404 += 1 + time.sleep(retry_404_seconds) + continue + raise Exception( + f"BFL API could not find task after {max_retries_404} tries." + ) + elif response.status_code == 202: + time.sleep(retry_202_seconds) + elif time.time() - start_time > timeout: + raise Exception( + f"BFL API experienced a timeout; could not return request under {timeout} seconds." 
+ ) + else: + raise Exception(f"BFL API encountered an error: {response.json()}") + +def convert_image_to_base64(image: torch.Tensor): + scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) + # remove batch dimension if present + if len(scaled_image.shape) > 3: + scaled_image = scaled_image[0] + image_np = (scaled_image.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format="PNG") + return base64.b64encode(img_byte_arr.getvalue()).decode() + + +class FluxProUltraImageNode(ComfyNodeABC): + """ + Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. + """ + + MINIMUM_RATIO = 1 / 4 + MAXIMUM_RATIO = 4 / 1 + MINIMUM_RATIO_STR = "1:4" + MAXIMUM_RATIO_STR = "4:1" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + "aspect_ratio": ( + IO.STRING, + { + "default": "16:9", + "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.", + }, + ), + "raw": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "When True, generate less processed, more natural-looking images.", + }, + ), + }, + "optional": { + "image_prompt": (IO.IMAGE,), + "image_prompt_strength": ( + IO.FLOAT, + { + "default": 0.1, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Blend between the prompt and the image prompt.", + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + @classmethod + def VALIDATE_INPUTS(cls, aspect_ratio: str): + try: + validate_aspect_ratio( + aspect_ratio, + minimum_ratio=cls.MINIMUM_RATIO, + maximum_ratio=cls.MAXIMUM_RATIO, + minimum_ratio_str=cls.MINIMUM_RATIO_STR, + maximum_ratio_str=cls.MAXIMUM_RATIO_STR, + ) + except Exception as e: + return str(e) + return True + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + prompt: str, + aspect_ratio: str, + prompt_upsampling=False, + raw=False, + seed=0, + image_prompt=None, + image_prompt_strength=0.1, + auth_token=None, + **kwargs, + ): + if image_prompt is None: + validate_string(prompt, strip_whitespace=False) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.1-ultra/generate", + method=HttpMethod.POST, + request_model=BFLFluxProUltraGenerateRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxProUltraGenerateRequest( + 
prompt=prompt, + prompt_upsampling=prompt_upsampling, + seed=seed, + aspect_ratio=validate_aspect_ratio( + aspect_ratio, + minimum_ratio=self.MINIMUM_RATIO, + maximum_ratio=self.MAXIMUM_RATIO, + minimum_ratio_str=self.MINIMUM_RATIO_STR, + maximum_ratio_str=self.MAXIMUM_RATIO_STR, + ), + raw=raw, + image_prompt=( + image_prompt + if image_prompt is None + else convert_image_to_base64(image_prompt) + ), + image_prompt_strength=( + None if image_prompt is None else round(image_prompt_strength, 2) + ), + ), + auth_token=auth_token, + ) + output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + + +class FluxProImageNode(ComfyNodeABC): + """ + Generates images synchronously based on prompt and resolution. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "width": ( + IO.INT, + { + "default": 1024, + "min": 256, + "max": 1440, + "step": 32, + }, + ), + "height": ( + IO.INT, + { + "default": 768, + "min": 256, + "max": 1440, + "step": 32, + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "optional": { + "image_prompt": (IO.IMAGE,), + # "image_prompt_strength": ( + # IO.FLOAT, + # { + # "default": 0.1, + # "min": 0.0, + # "max": 1.0, + # "step": 0.01, + # "tooltip": "Blend between the prompt and the image prompt.", + # }, + # ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + prompt: str, + prompt_upsampling, + width: int, + height: int, + seed=0, + image_prompt=None, + # image_prompt_strength=0.1, + auth_token=None, + **kwargs, + ): + image_prompt = ( + image_prompt + if image_prompt is None + else convert_image_to_base64(image_prompt) + ) + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.1/generate", + method=HttpMethod.POST, + request_model=BFLFluxProGenerateRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxProGenerateRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + width=width, + height=height, + seed=seed, + image_prompt=image_prompt, + ), + auth_token=auth_token, + ) + output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + +class FluxProExpandNode(ComfyNodeABC): + """ + Outpaints image based on prompt. 
+ """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": (IO.IMAGE,), + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "top": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2048, + "tooltip": "Number of pixels to expand at the top of the image" + }, + ), + "bottom": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2048, + "tooltip": "Number of pixels to expand at the bottom of the image" + }, + ), + "left": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2048, + "tooltip": "Number of pixels to expand at the left side of the image" + }, + ), + "right": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2048, + "tooltip": "Number of pixels to expand at the right side of the image" + }, + ), + "guidance": ( + IO.FLOAT, + { + "default": 60, + "min": 1.5, + "max": 100, + "tooltip": "Guidance strength for the image generation process" + }, + ), + "steps": ( + IO.INT, + { + "default": 50, + "min": 15, + "max": 50, + "tooltip": "Number of steps for the image generation process" + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "optional": { + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + image: torch.Tensor, + prompt: str, + prompt_upsampling: bool, + top: int, + bottom: int, + left: int, 
+ right: int, + steps: int, + guidance: float, + seed=0, + auth_token=None, + **kwargs, + ): + image = convert_image_to_base64(image) + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.0-expand/generate", + method=HttpMethod.POST, + request_model=BFLFluxExpandImageRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxExpandImageRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + top=top, + bottom=bottom, + left=left, + right=right, + steps=steps, + guidance=guidance, + seed=seed, + image=image, + ), + auth_token=auth_token, + ) + output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + + +class FluxProFillNode(ComfyNodeABC): + """ + Inpaints image based on mask and prompt. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": (IO.IMAGE,), + "mask": (IO.MASK,), + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "guidance": ( + IO.FLOAT, + { + "default": 60, + "min": 1.5, + "max": 100, + "tooltip": "Guidance strength for the image generation process" + }, + ), + "steps": ( + IO.INT, + { + "default": 50, + "min": 15, + "max": 50, + "tooltip": "Number of steps for the image generation process" + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "optional": { + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + image: torch.Tensor, + mask: torch.Tensor, + prompt: str, + prompt_upsampling: bool, + steps: int, + guidance: float, + seed=0, + auth_token=None, + **kwargs, + ): + # prepare mask + mask = resize_mask_to_image(mask, image) + mask = convert_image_to_base64(convert_mask_to_image(mask)) + # make sure image will have alpha channel removed + image = convert_image_to_base64(image[:,:,:,:3]) + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.0-fill/generate", + method=HttpMethod.POST, + request_model=BFLFluxFillImageRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxFillImageRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + steps=steps, + guidance=guidance, + seed=seed, + image=image, + mask=mask, + ), + auth_token=auth_token, + ) + output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + +class FluxProCannyNode(ComfyNodeABC): + """ + Generate image using a control image (canny). 
+ """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "control_image": (IO.IMAGE,), + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "canny_low_threshold": ( + IO.FLOAT, + { + "default": 0.1, + "min": 0.01, + "max": 0.99, + "step": 0.01, + "tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True" + }, + ), + "canny_high_threshold": ( + IO.FLOAT, + { + "default": 0.4, + "min": 0.01, + "max": 0.99, + "step": 0.01, + "tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True" + }, + ), + "skip_preprocessing": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", + }, + ), + "guidance": ( + IO.FLOAT, + { + "default": 30, + "min": 1, + "max": 100, + "tooltip": "Guidance strength for the image generation process" + }, + ), + "steps": ( + IO.INT, + { + "default": 50, + "min": 15, + "max": 50, + "tooltip": "Number of steps for the image generation process" + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "optional": { + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + control_image: torch.Tensor, + prompt: str, + prompt_upsampling: bool, + 
canny_low_threshold: float, + canny_high_threshold: float, + skip_preprocessing: bool, + steps: int, + guidance: float, + seed=0, + auth_token=None, + **kwargs, + ): + control_image = convert_image_to_base64(control_image[:,:,:,:3]) + preprocessed_image = None + + # scale canny threshold between 0-500, to match BFL's API + def scale_value(value: float, min_val=0, max_val=500): + return min_val + value * (max_val - min_val) + canny_low_threshold = int(round(scale_value(canny_low_threshold))) + canny_high_threshold = int(round(scale_value(canny_high_threshold))) + + + if skip_preprocessing: + preprocessed_image = control_image + control_image = None + canny_low_threshold = None + canny_high_threshold = None + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.0-canny/generate", + method=HttpMethod.POST, + request_model=BFLFluxCannyImageRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxCannyImageRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + steps=steps, + guidance=guidance, + seed=seed, + control_image=control_image, + canny_low_threshold=canny_low_threshold, + canny_high_threshold=canny_high_threshold, + preprocessed_image=preprocessed_image, + ), + auth_token=auth_token, + ) + output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + +class FluxProDepthNode(ComfyNodeABC): + """ + Generate image using a control image (depth). + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "control_image": (IO.IMAGE,), + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "prompt_upsampling": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + }, + ), + "skip_preprocessing": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", + }, + ), + "guidance": ( + IO.FLOAT, + { + "default": 15, + "min": 1, + "max": 100, + "tooltip": "Guidance strength for the image generation process" + }, + ), + "steps": ( + IO.INT, + { + "default": 50, + "min": 15, + "max": 50, + "tooltip": "Number of steps for the image generation process" + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "optional": { + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/BFL" + + def api_call( + self, + control_image: torch.Tensor, + prompt: str, + prompt_upsampling: bool, + skip_preprocessing: bool, + steps: int, + guidance: float, + seed=0, + auth_token=None, + **kwargs, + ): + control_image = convert_image_to_base64(control_image[:,:,:,:3]) + preprocessed_image = None + + if skip_preprocessing: + preprocessed_image = control_image + control_image = None + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/bfl/flux-pro-1.0-depth/generate", + method=HttpMethod.POST, + request_model=BFLFluxDepthImageRequest, + response_model=BFLFluxProGenerateResponse, + ), + request=BFLFluxDepthImageRequest( + prompt=prompt, + prompt_upsampling=prompt_upsampling, + steps=steps, + guidance=guidance, + seed=seed, + control_image=control_image, + preprocessed_image=preprocessed_image, + ), + auth_token=auth_token, + ) + 
output_image = handle_bfl_synchronous_operation(operation) + return (output_image,) + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "FluxProUltraImageNode": FluxProUltraImageNode, + # "FluxProImageNode": FluxProImageNode, + "FluxProExpandNode": FluxProExpandNode, + "FluxProFillNode": FluxProFillNode, + "FluxProCannyNode": FluxProCannyNode, + "FluxProDepthNode": FluxProDepthNode, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image", + # "FluxProImageNode": "Flux 1.1 [pro] Image", + "FluxProExpandNode": "Flux.1 Expand Image", + "FluxProFillNode": "Flux.1 Fill Image", + "FluxProCannyNode": "Flux.1 Canny Control Image", + "FluxProDepthNode": "Flux.1 Depth Control Image", +} diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py new file mode 100644 index 00000000..45c021f4 --- /dev/null +++ b/comfy_api_nodes/nodes_ideogram.py @@ -0,0 +1,777 @@ +from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict +from inspect import cleandoc +from PIL import Image +import numpy as np +import io +import torch +from comfy_api_nodes.apis import ( + IdeogramGenerateRequest, + IdeogramGenerateResponse, + ImageRequest, + IdeogramV3Request, + IdeogramV3EditRequest, +) + +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, +) + +from comfy_api_nodes.apinode_utils import ( + download_url_to_bytesio, + bytesio_to_image_tensor, + resize_mask_to_image, +) + +V1_V1_RES_MAP = { + "Auto":"AUTO", + "512 x 1536":"RESOLUTION_512_1536", + "576 x 1408":"RESOLUTION_576_1408", + "576 x 1472":"RESOLUTION_576_1472", + "576 x 1536":"RESOLUTION_576_1536", + "640 x 1024":"RESOLUTION_640_1024", + "640 x 1344":"RESOLUTION_640_1344", + "640 x 1408":"RESOLUTION_640_1408", + "640 x 
1472":"RESOLUTION_640_1472", + "640 x 1536":"RESOLUTION_640_1536", + "704 x 1152":"RESOLUTION_704_1152", + "704 x 1216":"RESOLUTION_704_1216", + "704 x 1280":"RESOLUTION_704_1280", + "704 x 1344":"RESOLUTION_704_1344", + "704 x 1408":"RESOLUTION_704_1408", + "704 x 1472":"RESOLUTION_704_1472", + "720 x 1280":"RESOLUTION_720_1280", + "736 x 1312":"RESOLUTION_736_1312", + "768 x 1024":"RESOLUTION_768_1024", + "768 x 1088":"RESOLUTION_768_1088", + "768 x 1152":"RESOLUTION_768_1152", + "768 x 1216":"RESOLUTION_768_1216", + "768 x 1232":"RESOLUTION_768_1232", + "768 x 1280":"RESOLUTION_768_1280", + "768 x 1344":"RESOLUTION_768_1344", + "832 x 960":"RESOLUTION_832_960", + "832 x 1024":"RESOLUTION_832_1024", + "832 x 1088":"RESOLUTION_832_1088", + "832 x 1152":"RESOLUTION_832_1152", + "832 x 1216":"RESOLUTION_832_1216", + "832 x 1248":"RESOLUTION_832_1248", + "864 x 1152":"RESOLUTION_864_1152", + "896 x 960":"RESOLUTION_896_960", + "896 x 1024":"RESOLUTION_896_1024", + "896 x 1088":"RESOLUTION_896_1088", + "896 x 1120":"RESOLUTION_896_1120", + "896 x 1152":"RESOLUTION_896_1152", + "960 x 832":"RESOLUTION_960_832", + "960 x 896":"RESOLUTION_960_896", + "960 x 1024":"RESOLUTION_960_1024", + "960 x 1088":"RESOLUTION_960_1088", + "1024 x 640":"RESOLUTION_1024_640", + "1024 x 768":"RESOLUTION_1024_768", + "1024 x 832":"RESOLUTION_1024_832", + "1024 x 896":"RESOLUTION_1024_896", + "1024 x 960":"RESOLUTION_1024_960", + "1024 x 1024":"RESOLUTION_1024_1024", + "1088 x 768":"RESOLUTION_1088_768", + "1088 x 832":"RESOLUTION_1088_832", + "1088 x 896":"RESOLUTION_1088_896", + "1088 x 960":"RESOLUTION_1088_960", + "1120 x 896":"RESOLUTION_1120_896", + "1152 x 704":"RESOLUTION_1152_704", + "1152 x 768":"RESOLUTION_1152_768", + "1152 x 832":"RESOLUTION_1152_832", + "1152 x 864":"RESOLUTION_1152_864", + "1152 x 896":"RESOLUTION_1152_896", + "1216 x 704":"RESOLUTION_1216_704", + "1216 x 768":"RESOLUTION_1216_768", + "1216 x 832":"RESOLUTION_1216_832", + "1232 x 768":"RESOLUTION_1232_768", 
+ "1248 x 832":"RESOLUTION_1248_832", + "1280 x 704":"RESOLUTION_1280_704", + "1280 x 720":"RESOLUTION_1280_720", + "1280 x 768":"RESOLUTION_1280_768", + "1280 x 800":"RESOLUTION_1280_800", + "1312 x 736":"RESOLUTION_1312_736", + "1344 x 640":"RESOLUTION_1344_640", + "1344 x 704":"RESOLUTION_1344_704", + "1344 x 768":"RESOLUTION_1344_768", + "1408 x 576":"RESOLUTION_1408_576", + "1408 x 640":"RESOLUTION_1408_640", + "1408 x 704":"RESOLUTION_1408_704", + "1472 x 576":"RESOLUTION_1472_576", + "1472 x 640":"RESOLUTION_1472_640", + "1472 x 704":"RESOLUTION_1472_704", + "1536 x 512":"RESOLUTION_1536_512", + "1536 x 576":"RESOLUTION_1536_576", + "1536 x 640":"RESOLUTION_1536_640", +} + +V1_V2_RATIO_MAP = { + "1:1":"ASPECT_1_1", + "4:3":"ASPECT_4_3", + "3:4":"ASPECT_3_4", + "16:9":"ASPECT_16_9", + "9:16":"ASPECT_9_16", + "2:1":"ASPECT_2_1", + "1:2":"ASPECT_1_2", + "3:2":"ASPECT_3_2", + "2:3":"ASPECT_2_3", + "4:5":"ASPECT_4_5", + "5:4":"ASPECT_5_4", +} + +V3_RATIO_MAP = { + "1:3":"1x3", + "3:1":"3x1", + "1:2":"1x2", + "2:1":"2x1", + "9:16":"9x16", + "16:9":"16x9", + "10:16":"10x16", + "16:10":"16x10", + "2:3":"2x3", + "3:2":"3x2", + "3:4":"3x4", + "4:3":"4x3", + "4:5":"4x5", + "5:4":"5x4", + "1:1":"1x1", +} + +V3_RESOLUTIONS= [ + "Auto", + "512x1536", + "576x1408", + "576x1472", + "576x1536", + "640x1344", + "640x1408", + "640x1472", + "640x1536", + "704x1152", + "704x1216", + "704x1280", + "704x1344", + "704x1408", + "704x1472", + "736x1312", + "768x1088", + "768x1216", + "768x1280", + "768x1344", + "800x1280", + "832x960", + "832x1024", + "832x1088", + "832x1152", + "832x1216", + "832x1248", + "864x1152", + "896x960", + "896x1024", + "896x1088", + "896x1120", + "896x1152", + "960x832", + "960x896", + "960x1024", + "960x1088", + "1024x832", + "1024x896", + "1024x960", + "1024x1024", + "1088x768", + "1088x832", + "1088x896", + "1088x960", + "1120x896", + "1152x704", + "1152x832", + "1152x864", + "1152x896", + "1216x704", + "1216x768", + "1216x832", + "1248x832", + 
"1280x704", + "1280x768", + "1280x800", + "1312x736", + "1344x640", + "1344x704", + "1344x768", + "1408x576", + "1408x640", + "1408x704", + "1472x576", + "1472x640", + "1472x704", + "1536x512", + "1536x576", + "1536x640" +] + +def download_and_process_images(image_urls): + """Helper function to download and process multiple images from URLs""" + + # Initialize list to store image tensors + image_tensors = [] + + for image_url in image_urls: + # Using functions from apinode_utils.py to handle downloading and processing + image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO + img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode + image_tensors.append(img_tensor) + + # Stack tensors to match (N, width, height, channels) + if image_tensors: + stacked_tensors = torch.cat(image_tensors, dim=0) + else: + raise Exception("No valid images were processed") + + return stacked_tensors + + +class IdeogramV1(ComfyNodeABC): + """ + Generates images synchronously using the Ideogram V1 model. + + Images links are available for a limited period of time; if you would like to keep the image, you must download it. 
+ """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "turbo": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", + } + ), + }, + "optional": { + "aspect_ratio": ( + IO.COMBO, + { + "options": list(V1_V2_RATIO_MAP.keys()), + "default": "1:1", + "tooltip": "The aspect ratio for image generation.", + }, + ), + "magic_prompt_option": ( + IO.COMBO, + { + "options": ["AUTO", "ON", "OFF"], + "default": "AUTO", + "tooltip": "Determine if MagicPrompt should be used in generation", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2147483647, + "step": 1, + "control_after_generate": True, + "display": "number", + }, + ), + "negative_prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Description of what to exclude from the image", + }, + ), + "num_images": ( + IO.INT, + {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/Ideogram/v1" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + turbo=False, + aspect_ratio="1:1", + magic_prompt_option="AUTO", + seed=0, + negative_prompt="", + num_images=1, + auth_token=None, + ): + # Determine the model based on turbo setting + aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) + model = "V_1_TURBO" if turbo else "V_1" + + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/ideogram/generate", + method=HttpMethod.POST, + request_model=IdeogramGenerateRequest, + response_model=IdeogramGenerateResponse, + ), + request=IdeogramGenerateRequest( + image_request=ImageRequest( + 
prompt=prompt, + model=model, + num_images=num_images, + seed=seed, + aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, + magic_prompt_option=( + magic_prompt_option if magic_prompt_option != "AUTO" else None + ), + negative_prompt=negative_prompt if negative_prompt else None, + ) + ), + auth_token=auth_token, + ) + + response = operation.execute() + + if not response.data or len(response.data) == 0: + raise Exception("No images were generated in the response") + + image_urls = [image_data.url for image_data in response.data if image_data.url] + + if not image_urls: + raise Exception("No image URLs were generated in the response") + + return (download_and_process_images(image_urls),) + + +class IdeogramV2(ComfyNodeABC): + """ + Generates images synchronously using the Ideogram V2 model. + + Images links are available for a limited period of time; if you would like to keep the image, you must download it. + """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation", + }, + ), + "turbo": ( + IO.BOOLEAN, + { + "default": False, + "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", + } + ), + }, + "optional": { + "aspect_ratio": ( + IO.COMBO, + { + "options": list(V1_V2_RATIO_MAP.keys()), + "default": "1:1", + "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", + }, + ), + "resolution": ( + IO.COMBO, + { + "options": list(V1_V1_RES_MAP.keys()), + "default": "Auto", + "tooltip": "The resolution for image generation. 
If not set to AUTO, this overrides the aspect_ratio setting.", + }, + ), + "magic_prompt_option": ( + IO.COMBO, + { + "options": ["AUTO", "ON", "OFF"], + "default": "AUTO", + "tooltip": "Determine if MagicPrompt should be used in generation", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2147483647, + "step": 1, + "control_after_generate": True, + "display": "number", + }, + ), + "style_type": ( + IO.COMBO, + { + "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], + "default": "NONE", + "tooltip": "Style type for generation (V2 only)", + }, + ), + "negative_prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Description of what to exclude from the image", + }, + ), + "num_images": ( + IO.INT, + {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + ), + #"color_palette": ( + # IO.STRING, + # { + # "multiline": False, + # "default": "", + # "tooltip": "Color palette preset name or hex colors with weights", + # }, + #), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/Ideogram/v2" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + turbo=False, + aspect_ratio="1:1", + resolution="Auto", + magic_prompt_option="AUTO", + seed=0, + style_type="NONE", + negative_prompt="", + num_images=1, + color_palette="", + auth_token=None, + ): + aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) + resolution = V1_V1_RES_MAP.get(resolution, None) + # Determine the model based on turbo setting + model = "V_2_TURBO" if turbo else "V_2" + + # Handle resolution vs aspect_ratio logic + # If resolution is not AUTO, it overrides aspect_ratio + final_resolution = None + final_aspect_ratio = None + + if resolution != "AUTO": + final_resolution = resolution + else: + final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None + + operation = 
SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/ideogram/generate", + method=HttpMethod.POST, + request_model=IdeogramGenerateRequest, + response_model=IdeogramGenerateResponse, + ), + request=IdeogramGenerateRequest( + image_request=ImageRequest( + prompt=prompt, + model=model, + num_images=num_images, + seed=seed, + aspect_ratio=final_aspect_ratio, + resolution=final_resolution, + magic_prompt_option=( + magic_prompt_option if magic_prompt_option != "AUTO" else None + ), + style_type=style_type if style_type != "NONE" else None, + negative_prompt=negative_prompt if negative_prompt else None, + color_palette=color_palette if color_palette else None, + ) + ), + auth_token=auth_token, + ) + + response = operation.execute() + + if not response.data or len(response.data) == 0: + raise Exception("No images were generated in the response") + + image_urls = [image_data.url for image_data in response.data if image_data.url] + + if not image_urls: + raise Exception("No image URLs were generated in the response") + + return (download_and_process_images(image_urls),) + +class IdeogramV3(ComfyNodeABC): + """ + Generates images synchronously using the Ideogram V3 model. + + Supports both regular image generation from text prompts and image editing with mask. + Images links are available for a limited period of time; if you would like to keep the image, you must download it. 
+ """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation or editing", + }, + ), + }, + "optional": { + "image": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional reference image for image editing.", + }, + ), + "mask": ( + IO.MASK, + { + "default": None, + "tooltip": "Optional mask for inpainting (white areas will be replaced)", + }, + ), + "aspect_ratio": ( + IO.COMBO, + { + "options": list(V3_RATIO_MAP.keys()), + "default": "1:1", + "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.", + }, + ), + "resolution": ( + IO.COMBO, + { + "options": V3_RESOLUTIONS, + "default": "Auto", + "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.", + }, + ), + "magic_prompt_option": ( + IO.COMBO, + { + "options": ["AUTO", "ON", "OFF"], + "default": "AUTO", + "tooltip": "Determine if MagicPrompt should be used in generation", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2147483647, + "step": 1, + "control_after_generate": True, + "display": "number", + }, + ), + "num_images": ( + IO.INT, + {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + ), + "rendering_speed": ( + IO.COMBO, + { + "options": ["BALANCED", "TURBO", "QUALITY"], + "default": "BALANCED", + "tooltip": "Controls the trade-off between generation speed and quality", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/Ideogram/v3" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + image=None, + mask=None, + resolution="Auto", + aspect_ratio="1:1", + magic_prompt_option="AUTO", + seed=0, + num_images=1, + rendering_speed="BALANCED", + 
auth_token=None, + ): + # Check if both image and mask are provided for editing mode + if image is not None and mask is not None: + # Edit mode + path = "/proxy/ideogram/ideogram-v3/edit" + + # Process image and mask + input_tensor = image.squeeze().cpu() + # Resize mask to match image dimension + mask = resize_mask_to_image(mask, image, allow_gradient=False) + # Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention). + mask = 1.0 - mask + + # Validate mask dimensions match image + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + + # Process image + img_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + img_binary = img_byte_arr + img_binary.name = "image.png" + + # Process mask - white areas will be replaced + mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_byte_arr = io.BytesIO() + mask_img.save(mask_byte_arr, format="PNG") + mask_byte_arr.seek(0) + mask_binary = mask_byte_arr + mask_binary.name = "mask.png" + + # Create edit request + edit_request = IdeogramV3EditRequest( + prompt=prompt, + rendering_speed=rendering_speed, + ) + + # Add optional parameters + if magic_prompt_option != "AUTO": + edit_request.magic_prompt = magic_prompt_option + if seed != 0: + edit_request.seed = seed + if num_images > 1: + edit_request.num_images = num_images + + # Execute the operation for edit mode + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=IdeogramV3EditRequest, + response_model=IdeogramGenerateResponse, + ), + request=edit_request, + files={ + "image": img_binary, + "mask": mask_binary, + }, + content_type="multipart/form-data", + auth_token=auth_token, + ) + + elif image is not None or mask is not None: + # If only one of image 
or mask is provided, raise an error + raise Exception("Ideogram V3 image editing requires both an image AND a mask") + else: + # Generation mode + path = "/proxy/ideogram/ideogram-v3/generate" + + # Create generation request + gen_request = IdeogramV3Request( + prompt=prompt, + rendering_speed=rendering_speed, + ) + + # Handle resolution vs aspect ratio + if resolution != "Auto": + gen_request.resolution = resolution + elif aspect_ratio != "1:1": + v3_aspect = V3_RATIO_MAP.get(aspect_ratio) + if v3_aspect: + gen_request.aspect_ratio = v3_aspect + + # Add optional parameters + if magic_prompt_option != "AUTO": + gen_request.magic_prompt = magic_prompt_option + if seed != 0: + gen_request.seed = seed + if num_images > 1: + gen_request.num_images = num_images + + # Execute the operation for generation mode + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=IdeogramV3Request, + response_model=IdeogramGenerateResponse, + ), + request=gen_request, + auth_token=auth_token, + ) + + # Execute the operation and process response + response = operation.execute() + + if not response.data or len(response.data) == 0: + raise Exception("No images were generated in the response") + + image_urls = [image_data.url for image_data in response.data if image_data.url] + + if not image_urls: + raise Exception("No image URLs were generated in the response") + + return (download_and_process_images(image_urls),) + + +NODE_CLASS_MAPPINGS = { + "IdeogramV1": IdeogramV1, + "IdeogramV2": IdeogramV2, + "IdeogramV3": IdeogramV3, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "IdeogramV1": "Ideogram V1", + "IdeogramV2": "Ideogram V2", + "IdeogramV3": "Ideogram V3", +} + diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py new file mode 100644 index 00000000..9aa8df58 --- /dev/null +++ b/comfy_api_nodes/nodes_kling.py @@ -0,0 +1,1563 @@ +"""Kling API Nodes + +For source of truth on the allowed permutations of request 
fields, please reference: +- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) +""" + +from __future__ import annotations +from typing import Optional, TypeVar, Any +import math +import logging + +import torch + +from comfy_api_nodes.apis import ( + KlingTaskStatus, + KlingCameraControl, + KlingCameraConfig, + KlingCameraControlType, + KlingVideoGenDuration, + KlingVideoGenMode, + KlingVideoGenAspectRatio, + KlingVideoGenModelName, + KlingText2VideoRequest, + KlingText2VideoResponse, + KlingImage2VideoRequest, + KlingImage2VideoResponse, + KlingVideoExtendRequest, + KlingVideoExtendResponse, + KlingLipSyncVoiceLanguage, + KlingLipSyncInputObject, + KlingLipSyncRequest, + KlingLipSyncResponse, + KlingVirtualTryOnModelName, + KlingVirtualTryOnRequest, + KlingVirtualTryOnResponse, + KlingVideoResult, + KlingImageResult, + KlingImageGenerationsRequest, + KlingImageGenerationsResponse, + KlingImageGenImageReferenceType, + KlingImageGenModelName, + KlingImageGenAspectRatio, + KlingVideoEffectsRequest, + KlingVideoEffectsResponse, + KlingDualCharacterEffectsScene, + KlingSingleImageEffectsScene, + KlingDualCharacterEffectInput, + KlingSingleImageEffectInput, + KlingCharacterEffectModelName, + KlingSingleImageEffectModelName, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + tensor_to_base64_string, + download_url_to_video_output, + upload_video_to_comfyapi, + upload_audio_to_comfyapi, + download_url_to_image_tensor, +) +from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api.input.basic_types import AudioInput +from comfy_api.input.video_types import VideoInput +from comfy_api.input_impl import VideoFromFile +from comfy.comfy_types.node_typing import IO, InputTypeOptions, ComfyNodeABC + +KLING_API_VERSION = "v1" +PATH_TEXT_TO_VIDEO = 
def is_valid_task_creation_response(response: KlingText2VideoResponse) -> bool:
    """Check that a task-creation response carries a task ID.

    Args:
        response: The response returned by the initial (task creation) request.

    Returns:
        True when ``response.data`` is present and holds a non-empty
        ``task_id``; False otherwise.
    """
    # A response without a ``data`` payload must be reported as invalid
    # instead of raising AttributeError, so that the caller can raise its
    # own error built from the response's code and message.
    data = response.data
    return data is not None and bool(data.task_id)
def validate_prompts(
    prompt: str,
    negative_prompt: str,
    max_length: int,
    max_negative_length: Optional[int] = None,
) -> bool:
    """Verify that the positive prompt is not empty and that neither prompt is too long.

    Args:
        prompt: The positive prompt. Must be non-empty.
        negative_prompt: The negative prompt. May be empty or None, in which
            case its length is not checked.
        max_length: Maximum number of characters allowed in ``prompt``.
        max_negative_length: Maximum number of characters allowed in
            ``negative_prompt``. Defaults to ``max_length`` when omitted, which
            keeps the behaviour of callers that pass a single limit.

    Returns:
        True when both prompts pass validation.

    Raises:
        ValueError: If the positive prompt is empty, or either prompt exceeds
            its limit.
    """
    if not prompt:
        raise ValueError("Positive prompt is empty")
    if len(prompt) > max_length:
        raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")

    negative_limit = max_length if max_negative_length is None else max_negative_length
    if negative_prompt and len(negative_prompt) > negative_limit:
        raise ValueError(
            f"Negative prompt is too long: {len(negative_prompt)} characters"
        )
    return True
def validate_image_result_response(response) -> None:
    """Ensure a finished Kling task response contains at least one image.

    Args:
        response: The final (polled) task response.

    Raises:
        KlingApiError: If ``is_valid_image_response`` rejects the response.
    """
    if is_valid_image_response(response):
        return

    # ``response.data`` may itself be None (that is one of the cases the
    # validity check rejects), so read the task id defensively instead of
    # failing with AttributeError while building the error message.
    task_id = getattr(response.data, "task_id", None)
    error_msg = f"Kling task {task_id} succeeded but no image data found in response."
    logging.error(f"Error: {error_msg}.\nResponse: {response}")
    raise KlingApiError(error_msg)
+ """ + if len(images) == 1: + return download_url_to_image_tensor(images[0].url) + else: + return torch.cat([download_url_to_image_tensor(image.url) for image in images]) + + +class KlingNodeBase(ComfyNodeABC): + """Base class for Kling nodes.""" + + FUNCTION = "api_call" + CATEGORY = "api node/video/Kling" + API_NODE = True + + +class KlingCameraControls(KlingNodeBase): + """Kling Camera Controls Node""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "camera_control_type": model_field_to_node_input( + IO.COMBO, + KlingCameraControl, + "type", + enum_type=KlingCameraControlType, + ), + "horizontal_movement": get_camera_control_input_config( + "Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right" + ), + "vertical_movement": get_camera_control_input_config( + "Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward." + ), + "pan": get_camera_control_input_config( + "Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.", + default=0.5, + ), + "tilt": get_camera_control_input_config( + "Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.", + ), + "roll": get_camera_control_input_config( + "Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.", + ), + "zoom": get_camera_control_input_config( + "Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.", + ), + } + } + + DESCRIPTION = "Allows specifying configuration options for Kling Camera Controls and motion control effects." 
class KlingTextToVideoNode(KlingNodeBase):
    """Kling Text to Video Node.

    Creates a text-to-video task on the Kling proxy endpoint, polls it until
    it finishes and returns the resulting video together with its id and
    duration.
    """

    @staticmethod
    def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]:
        """
        Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
        These are the config combos this node offers in its `mode` input.

        See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
        """
        return {
            "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
            "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
            "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
            "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
            "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
            "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
            "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
            "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
            "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
            "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
        }

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs; `mode` bundles mode, duration and model name."""
        modes = list(KlingTextToVideoNode.get_mode_string_mapping().keys())
        # Name the default explicitly rather than by list position, so that
        # reordering or extending the mapping cannot silently change it.
        default_mode = "standard mode / 5s duration / kling-v1-6"
        return {
            "required": {
                "prompt": model_field_to_node_input(
                    IO.STRING, KlingText2VideoRequest, "prompt", multiline=True
                ),
                "negative_prompt": model_field_to_node_input(
                    IO.STRING, KlingText2VideoRequest, "negative_prompt", multiline=True
                ),
                "cfg_scale": model_field_to_node_input(
                    IO.FLOAT,
                    KlingText2VideoRequest,
                    "cfg_scale",
                    default=1.0,
                    min=0.0,
                    max=1.0,
                ),
                "aspect_ratio": model_field_to_node_input(
                    IO.COMBO,
                    KlingText2VideoRequest,
                    "aspect_ratio",
                    enum_type=KlingVideoGenAspectRatio,
                ),
                "mode": (
                    modes,
                    {
                        "default": default_mode,
                        "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.",
                    },
                ),
            },
            "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
        }

    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")
    DESCRIPTION = "Kling Text to Video Node"

    def get_response(self, task_id: str, auth_token: str) -> KlingText2VideoResponse:
        """Poll the text-to-video task `task_id` until it finishes and return the final response."""
        return poll_until_finished(
            auth_token,
            ApiEndpoint(
                path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingText2VideoResponse,
            ),
        )

    def api_call(
        self,
        prompt: str,
        negative_prompt: str,
        cfg_scale: float,
        mode: str,
        aspect_ratio: str,
        camera_control: Optional[KlingCameraControl] = None,
        model_name: Optional[str] = None,
        duration: Optional[str] = None,
        auth_token: Optional[str] = None,
    ) -> tuple[VideoFromFile, str, str]:
        """Create a text-to-video task, wait for it and return the video.

        Args:
            prompt: Positive prompt; must be non-empty.
            negative_prompt: Negative prompt; sent as None when empty.
            cfg_scale: Value forwarded as the request's `cfg_scale`.
            mode: Either a key of `get_mode_string_mapping()` (when
                `model_name` is None) or the request mode itself.
            aspect_ratio: Value convertible to `KlingVideoGenAspectRatio`.
            camera_control: Optional camera control forwarded unchanged.
            model_name: When given, `mode` and `duration` are used as passed
                instead of being looked up from the mode string.
            duration: Request duration; only used as passed when `model_name`
                is given.
            auth_token: Token forwarded to the API operations.

        Returns:
            The node output tuple (video, video id, duration).

        Raises:
            ValueError: If the prompts fail validation.
            KlingApiError: If task creation fails or the result has no video.
        """
        validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
        if model_name is None:
            # `mode` is a combined label here; split it into its three parts.
            mode, duration, model_name = self.get_mode_string_mapping()[mode]
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=PATH_TEXT_TO_VIDEO,
                method=HttpMethod.POST,
                request_model=KlingText2VideoRequest,
                response_model=KlingText2VideoResponse,
            ),
            request=KlingText2VideoRequest(
                prompt=prompt if prompt else None,
                negative_prompt=negative_prompt if negative_prompt else None,
                duration=KlingVideoGenDuration(duration),
                mode=KlingVideoGenMode(mode),
                model_name=KlingVideoGenModelName(model_name),
                cfg_scale=cfg_scale,
                aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
                camera_control=camera_control,
            ),
            auth_token=auth_token,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)

        task_id = task_creation_response.data.task_id
        final_response = self.get_response(task_id, auth_token)
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
        return video_result_to_node_output(video)
+ Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": model_field_to_node_input( + IO.STRING, KlingText2VideoRequest, "prompt", multiline=True + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + KlingText2VideoRequest, + "negative_prompt", + multiline=True, + ), + "cfg_scale": model_field_to_node_input( + IO.FLOAT, + KlingText2VideoRequest, + "cfg_scale", + default=0.75, + min=0.0, + max=1.0, + ), + "aspect_ratio": model_field_to_node_input( + IO.COMBO, + KlingText2VideoRequest, + "aspect_ratio", + enum_type=KlingVideoGenAspectRatio, + ), + "camera_control": ( + "CAMERA_CONTROL", + { + "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." 
class KlingImage2VideoNode(KlingNodeBase):
    """Kling Image to Video Node.

    Creates an image-to-video task on the Kling proxy endpoint, polls it
    until it finishes and returns the resulting video together with its id
    and duration.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs, derived from `KlingImage2VideoRequest` fields."""
        return {
            "required": {
                "start_frame": model_field_to_node_input(
                    IO.IMAGE, KlingImage2VideoRequest, "image"
                ),
                "prompt": model_field_to_node_input(
                    IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
                ),
                "negative_prompt": model_field_to_node_input(
                    IO.STRING,
                    KlingImage2VideoRequest,
                    "negative_prompt",
                    multiline=True,
                ),
                "model_name": model_field_to_node_input(
                    IO.COMBO,
                    KlingImage2VideoRequest,
                    "model_name",
                    enum_type=KlingVideoGenModelName,
                ),
                "cfg_scale": model_field_to_node_input(
                    IO.FLOAT,
                    KlingImage2VideoRequest,
                    "cfg_scale",
                    default=0.8,
                    min=0.0,
                    max=1.0,
                ),
                "mode": model_field_to_node_input(
                    IO.COMBO,
                    KlingImage2VideoRequest,
                    "mode",
                    enum_type=KlingVideoGenMode,
                ),
                "aspect_ratio": model_field_to_node_input(
                    IO.COMBO,
                    KlingImage2VideoRequest,
                    "aspect_ratio",
                    enum_type=KlingVideoGenAspectRatio,
                ),
                "duration": model_field_to_node_input(
                    IO.COMBO,
                    KlingImage2VideoRequest,
                    "duration",
                    enum_type=KlingVideoGenDuration,
                ),
            },
            "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
        }

    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")
    DESCRIPTION = "Kling Image to Video Node"

    def get_response(self, task_id: str, auth_token: str) -> KlingImage2VideoResponse:
        """Poll the image-to-video task `task_id` until it finishes and return the final response."""
        return poll_until_finished(
            auth_token,
            ApiEndpoint(
                path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
                method=HttpMethod.GET,
                # The polling GET carries no body; use EmptyRequest like the
                # other Kling polling endpoints in this module.
                request_model=EmptyRequest,
                response_model=KlingImage2VideoResponse,
            ),
        )

    def api_call(
        self,
        start_frame: torch.Tensor,
        prompt: str,
        negative_prompt: str,
        model_name: str,
        cfg_scale: float,
        mode: str,
        aspect_ratio: str,
        duration: str,
        camera_control: Optional[KlingCameraControl] = None,
        end_frame: Optional[torch.Tensor] = None,
        auth_token: Optional[str] = None,
    ) -> tuple[VideoFromFile, str, str]:
        """Create an image-to-video task, wait for it and return the video.

        Args:
            start_frame: Image tensor sent (base64 encoded) as the request's `image`.
            prompt: Positive prompt; must be non-empty.
            negative_prompt: Negative prompt; sent as None when empty.
            model_name: Value convertible to `KlingVideoGenModelName`.
            cfg_scale: Value forwarded as the request's `cfg_scale`.
            mode: Value convertible to `KlingVideoGenMode`.
            aspect_ratio: Value convertible to `KlingVideoGenAspectRatio`.
            duration: Value convertible to `KlingVideoGenDuration`.
            camera_control: Optional camera control; its type is forced to
                `simple` on a copy before sending.
            end_frame: Optional image tensor sent as the request's `image_tail`.
            auth_token: Token forwarded to the API operations.

        Returns:
            The node output tuple (video, video id, duration).

        Raises:
            ValueError: If the prompts fail validation.
            KlingApiError: If task creation fails or the result has no video.
        """
        validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)

        if camera_control is not None:
            # Camera control type for image 2 video is always simple.
            # Override it on a copy: the object is an input produced by
            # another node, and mutating it in place would leak the change to
            # anything else that uses the same object.
            camera_control = camera_control.model_copy(
                update={"type": KlingCameraControlType.simple}
            )

        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=PATH_IMAGE_TO_VIDEO,
                method=HttpMethod.POST,
                request_model=KlingImage2VideoRequest,
                response_model=KlingImage2VideoResponse,
            ),
            request=KlingImage2VideoRequest(
                model_name=KlingVideoGenModelName(model_name),
                image=tensor_to_base64_string(start_frame),
                image_tail=(
                    tensor_to_base64_string(end_frame)
                    if end_frame is not None
                    else None
                ),
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                cfg_scale=cfg_scale,
                mode=KlingVideoGenMode(mode),
                aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
                duration=KlingVideoGenDuration(duration),
                camera_control=camera_control,
            ),
            auth_token=auth_token,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

        final_response = self.get_response(task_id, auth_token)
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
        return video_result_to_node_output(video)
This node is a image to video node, but it supports controlling the camera. + Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_frame": model_field_to_node_input( + IO.IMAGE, KlingImage2VideoRequest, "image" + ), + "prompt": model_field_to_node_input( + IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + KlingImage2VideoRequest, + "negative_prompt", + multiline=True, + ), + "cfg_scale": model_field_to_node_input( + IO.FLOAT, + KlingImage2VideoRequest, + "cfg_scale", + default=0.75, + min=0.0, + max=1.0, + ), + "aspect_ratio": model_field_to_node_input( + IO.COMBO, + KlingImage2VideoRequest, + "aspect_ratio", + enum_type=KlingVideoGenAspectRatio, + ), + "camera_control": ( + "CAMERA_CONTROL", + { + "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." 
class KlingStartEndFrameNode(KlingImage2VideoNode):
    """
    Kling start/end frame node: builds a video from a first and a last frame.

    It goes through the regular image to video call, but exposes only the
    mode / duration / model combinations listed in `get_mode_string_mapping`.
    """

    DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."

    @staticmethod
    def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]:
        """
        Map each user-facing mode label to its (mode, duration, model_name) tuple.
        Only combos that support the `image_tail` request field are listed.

        See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
        """
        mode_labels = {"std": "standard", "pro": "pro"}
        combos = (
            ("std", "5", "kling-v1"),
            ("pro", "5", "kling-v1"),
            ("pro", "5", "kling-v1-5"),
            ("pro", "10", "kling-v1-5"),
            ("pro", "5", "kling-v1-6"),
            ("pro", "10", "kling-v1-6"),
        )
        # Labels follow the "<mode> mode / <n>s duration / <model>" pattern.
        return {
            f"{mode_labels[mode]} mode / {seconds}s duration / {model}": (
                mode,
                seconds,
                model,
            )
            for mode, seconds, model in combos
        }

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: both frames, prompts, cfg scale, aspect ratio and mode."""
        mode_choices = [*KlingStartEndFrameNode.get_mode_string_mapping()]

        required = {}
        required["start_frame"] = model_field_to_node_input(
            IO.IMAGE, KlingImage2VideoRequest, "image"
        )
        required["end_frame"] = model_field_to_node_input(
            IO.IMAGE, KlingImage2VideoRequest, "image_tail"
        )
        required["prompt"] = model_field_to_node_input(
            IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
        )
        required["negative_prompt"] = model_field_to_node_input(
            IO.STRING, KlingImage2VideoRequest, "negative_prompt", multiline=True
        )
        required["cfg_scale"] = model_field_to_node_input(
            IO.FLOAT, KlingImage2VideoRequest, "cfg_scale", default=0.5, min=0.0, max=1.0
        )
        required["aspect_ratio"] = model_field_to_node_input(
            IO.COMBO,
            KlingImage2VideoRequest,
            "aspect_ratio",
            enum_type=KlingVideoGenAspectRatio,
        )
        mode_options = {
            "default": mode_choices[2],
            "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.",
        }
        required["mode"] = (mode_choices, mode_options)

        return {
            "required": required,
            "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
        }

    def api_call(
        self,
        start_frame: torch.Tensor,
        end_frame: torch.Tensor,
        prompt: str,
        negative_prompt: str,
        cfg_scale: float,
        aspect_ratio: str,
        mode: str,
        auth_token: Optional[str] = None,
    ):
        """Resolve the mode label into its parts and delegate to the image to video call."""
        combos = KlingStartEndFrameNode.get_mode_string_mapping()
        request_mode, request_duration, request_model = combos[mode]
        result = super().api_call(
            prompt=prompt,
            negative_prompt=negative_prompt,
            model_name=request_model,
            start_frame=start_frame,
            cfg_scale=cfg_scale,
            mode=request_mode,
            aspect_ratio=aspect_ratio,
            duration=request_duration,
            end_frame=end_frame,
            auth_token=auth_token,
        )
        return result
class KlingVideoEffectsBase(KlingNodeBase):
    """Kling Video Effects Base.

    Shared implementation for the single-image and dual-character video
    effect nodes: builds the effect request, creates the task, polls it and
    returns the resulting video.
    """

    RETURN_TYPES = ("VIDEO", "STRING", "STRING")
    RETURN_NAMES = ("VIDEO", "video_id", "duration")

    def get_response(self, task_id: str, auth_token: str) -> KlingVideoEffectsResponse:
        """Poll the video effects task `task_id` until it finishes and return the final response."""
        return poll_until_finished(
            auth_token,
            ApiEndpoint(
                path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=KlingVideoEffectsResponse,
            ),
        )

    def api_call(
        self,
        dual_character: bool,
        effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
        model_name: str,
        duration: KlingVideoGenDuration,
        image_1: torch.Tensor,
        image_2: Optional[torch.Tensor] = None,
        mode: Optional[KlingVideoGenMode] = None,
        auth_token: Optional[str] = None,
    ):
        """Create a video effects task, wait for it and return the video.

        Args:
            dual_character: Selects the dual-character input (two images plus
                `mode`) instead of the single-image input.
            effect_scene: The effect scene sent in the request.
            model_name: Model name placed in the request input.
            duration: Duration placed in the request input.
            image_1: First (or only) image tensor.
            image_2: Second image tensor; required when `dual_character` is True.
            mode: Generation mode; only sent for dual-character requests.
            auth_token: Token forwarded to the API operations.

        Returns:
            The node output tuple (video, video id, duration).

        Raises:
            ValueError: If `dual_character` is True but `image_2` is missing.
            KlingApiError: If task creation fails or the result has no video.
        """
        if dual_character:
            # Fail with a clear message rather than passing None on to the
            # tensor encoder.
            if image_2 is None:
                raise ValueError(
                    "Dual character video effects require a second image (image_2)"
                )
            request_input_field = KlingDualCharacterEffectInput(
                model_name=model_name,
                mode=mode,
                images=[
                    tensor_to_base64_string(image_1),
                    tensor_to_base64_string(image_2),
                ],
                duration=duration,
            )
        else:
            request_input_field = KlingSingleImageEffectInput(
                model_name=model_name,
                image=tensor_to_base64_string(image_1),
                duration=duration,
            )

        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path=PATH_VIDEO_EFFECTS,
                method=HttpMethod.POST,
                request_model=KlingVideoEffectsRequest,
                response_model=KlingVideoEffectsResponse,
            ),
            request=KlingVideoEffectsRequest(
                effect_scene=effect_scene,
                input=request_input_field,
            ),
            auth_token=auth_token,
        )

        task_creation_response = initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

        final_response = self.get_response(task_id, auth_token)
        validate_video_result_response(final_response)

        video = get_video_from_response(final_response)
        return video_result_to_node_output(video)
class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
    """Video effect node that works from one reference image."""

    DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the image, effect scene, model name and duration."""
        image_options = {
            "tooltip": " Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"
        }

        required = {"image": (IO.IMAGE, image_options)}
        required["effect_scene"] = model_field_to_node_input(
            IO.COMBO,
            KlingVideoEffectsRequest,
            "effect_scene",
            enum_type=KlingSingleImageEffectsScene,
        )
        required["model_name"] = model_field_to_node_input(
            IO.COMBO,
            KlingSingleImageEffectInput,
            "model_name",
            enum_type=KlingSingleImageEffectModelName,
        )
        required["duration"] = model_field_to_node_input(
            IO.COMBO,
            KlingSingleImageEffectInput,
            "duration",
            enum_type=KlingVideoGenDuration,
        )

        return {
            "required": required,
            "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"},
        }

    def api_call(
        self,
        image: torch.Tensor,
        effect_scene: KlingSingleImageEffectsScene,
        model_name: KlingSingleImageEffectModelName,
        duration: KlingVideoGenDuration,
        auth_token: Optional[str] = None,
    ):
        """Run the shared effects call in single-image mode with `image` as its only picture."""
        effect_kwargs = {
            "dual_character": False,
            "effect_scene": effect_scene,
            "model_name": model_name,
            "duration": duration,
            "image_1": image,
            "auth_token": auth_token,
        }
        return super().api_call(**effect_kwargs)
URL: %s", audio_url) + else: + audio_url = None + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_LIP_SYNC, + method=HttpMethod.POST, + request_model=KlingLipSyncRequest, + response_model=KlingLipSyncResponse, + ), + request=KlingLipSyncRequest( + input=KlingLipSyncInputObject( + video_url=video_url, + mode=mode, + text=text, + voice_language=voice_language, + voice_speed=voice_speed, + audio_type="url", + audio_url=audio_url, + voice_id=voice_id, + ), + ), + auth_token=auth_token, + ) + + task_creation_response = initial_operation.execute() + validate_task_creation_response(task_creation_response) + task_id = task_creation_response.data.task_id + + final_response = self.get_response(task_id, auth_token) + validate_video_result_response(final_response) + + video = get_video_from_response(final_response) + return video_result_to_node_output(video) + + +class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): + """Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "video": (IO.VIDEO, {}), + "audio": (IO.AUDIO, {}), + "voice_language": model_field_to_node_input( + IO.COMBO, + KlingLipSyncInputObject, + "voice_language", + enum_type=KlingLipSyncVoiceLanguage, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file." + + def api_call( + self, + video: VideoInput, + audio: AudioInput, + voice_language: str, + auth_token: Optional[str] = None, + ): + return super().api_call( + video=video, + audio=audio, + voice_language=voice_language, + mode="audio2video", + auth_token=auth_token, + ) + + +class KlingLipSyncTextToVideoNode(KlingLipSyncBase): + """Kling Lip Sync Text to Video Node. 
Syncs mouth movements in a video file to a text prompt.""" + + @staticmethod + def get_voice_config() -> dict[str, tuple[str, str]]: + return { + # English voices + "Melody": ("girlfriend_4_speech02", "en"), + "Sunny": ("genshin_vindi2", "en"), + "Sage": ("zhinen_xuesheng", "en"), + "Ace": ("AOT", "en"), + "Blossom": ("ai_shatang", "en"), + "Peppy": ("genshin_klee2", "en"), + "Dove": ("genshin_kirara", "en"), + "Shine": ("ai_kaiya", "en"), + "Anchor": ("oversea_male1", "en"), + "Lyric": ("ai_chenjiahao_712", "en"), + "Tender": ("chat1_female_new-3", "en"), + "Siren": ("chat_0407_5-1", "en"), + "Zippy": ("cartoon-boy-07", "en"), + "Bud": ("uk_boy1", "en"), + "Sprite": ("cartoon-girl-01", "en"), + "Candy": ("PeppaPig_platform", "en"), + "Beacon": ("ai_huangzhong_712", "en"), + "Rock": ("ai_huangyaoshi_712", "en"), + "Titan": ("ai_laoguowang_712", "en"), + "Grace": ("chengshu_jiejie", "en"), + "Helen": ("you_pingjing", "en"), + "Lore": ("calm_story1", "en"), + "Crag": ("uk_man2", "en"), + "Prattle": ("laopopo_speech02", "en"), + "Hearth": ("heainainai_speech02", "en"), + "The Reader": ("reader_en_m-v1", "en"), + "Commercial Lady": ("commercial_lady_en_f-v1", "en"), + # Chinese voices + "阳光少年": ("genshin_vindi2", "zh"), + "懂事小弟": ("zhinen_xuesheng", "zh"), + "运动少年": ("tiyuxi_xuedi", "zh"), + "青春少女": ("ai_shatang", "zh"), + "温柔小妹": ("genshin_klee2", "zh"), + "元气少女": ("genshin_kirara", "zh"), + "阳光男生": ("ai_kaiya", "zh"), + "幽默小哥": ("tiexin_nanyou", "zh"), + "文艺小哥": ("ai_chenjiahao_712", "zh"), + "甜美邻家": ("girlfriend_1_speech02", "zh"), + "温柔姐姐": ("chat1_female_new-3", "zh"), + "职场女青": ("girlfriend_2_speech02", "zh"), + "活泼男童": ("cartoon-boy-07", "zh"), + "俏皮女童": ("cartoon-girl-01", "zh"), + "稳重老爸": ("ai_huangyaoshi_712", "zh"), + "温柔妈妈": ("you_pingjing", "zh"), + "严肃上司": ("ai_laoguowang_712", "zh"), + "优雅贵妇": ("chengshu_jiejie", "zh"), + "慈祥爷爷": ("zhuxi_speech02", "zh"), + "唠叨爷爷": ("uk_oldman3", "zh"), + "唠叨奶奶": ("laopopo_speech02", "zh"), + "和蔼奶奶": 
("heainainai_speech02", "zh"), + "东北老铁": ("dongbeilaotie_speech02", "zh"), + "重庆小伙": ("chongqingxiaohuo_speech02", "zh"), + "四川妹子": ("chuanmeizi_speech02", "zh"), + "潮汕大叔": ("chaoshandashu_speech02", "zh"), + "台湾男生": ("ai_taiwan_man2_speech02", "zh"), + "西安掌柜": ("xianzhanggui_speech02", "zh"), + "天津姐姐": ("tianjinjiejie_speech02", "zh"), + "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"), + "译制片男": ("yizhipiannan-v1", "zh"), + "撒娇女友": ("tianmeixuemei-v1", "zh"), + "刀片烟嗓": ("daopianyansang-v1", "zh"), + "乖巧正太": ("mengwa-v1", "zh"), + } + + @classmethod + def INPUT_TYPES(s): + voice_options = list(s.get_voice_config().keys()) + return { + "required": { + "video": (IO.VIDEO, {}), + "text": model_field_to_node_input( + IO.STRING, KlingLipSyncInputObject, "text", multiline=True + ), + "voice": (voice_options, {"default": voice_options[0]}), + "voice_speed": model_field_to_node_input( + IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt." 
class KlingImageGenerationBase(KlingNodeBase):
    """Kling Image Generation Base Node."""

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "api node/image/Kling"

    def validate_prompt(self, prompt: str, negative_prompt: Optional[str] = None):
        """Validate the prompt and the optional negative prompt.

        Args:
            prompt: Required text; must be non-empty and at most
                ``MAX_PROMPT_LENGTH_IMAGE_GEN`` characters long.
            negative_prompt: Optional text; only its length is checked, and
                only when it is non-empty.

        Raises:
            ValueError: if ``prompt`` is empty, or either text exceeds
                ``MAX_PROMPT_LENGTH_IMAGE_GEN`` characters.
        """
        # An empty prompt used to be reported with the "too long" message,
        # which pointed the user at the wrong problem.
        if not prompt:
            raise ValueError("Prompt is required")
        if len(prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN:
            raise ValueError(
                f"Prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters"
            )
        if negative_prompt and len(negative_prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN:
            raise ValueError(
                f"Negative prompt must be less than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters"
            )
def api_call(
    self,
    human_image: torch.Tensor,
    cloth_image: torch.Tensor,
    model_name: KlingVirtualTryOnModelName,
    auth_token: Optional[str] = None,
):
    """Submit a virtual try-on task, wait for it and return the images.

    Both images are sent base64-encoded in the creation request. The task id
    from the validated creation response is then polled via ``get_response``
    and the validated result is converted into a one-element output tuple.
    """
    create_endpoint = ApiEndpoint(
        path=PATH_VIRTUAL_TRY_ON,
        method=HttpMethod.POST,
        request_model=KlingVirtualTryOnRequest,
        response_model=KlingVirtualTryOnResponse,
    )
    try_on_request = KlingVirtualTryOnRequest(
        human_image=tensor_to_base64_string(human_image),
        cloth_image=tensor_to_base64_string(cloth_image),
        model_name=model_name,
    )
    creation = SynchronousOperation(
        endpoint=create_endpoint,
        request=try_on_request,
        auth_token=auth_token,
    ).execute()
    validate_task_creation_response(creation)

    # Poll with the id issued by the creation call until the task finishes.
    finished = self.get_response(creation.data.task_id, auth_token)
    validate_image_result_response(finished)

    result_images = get_images_from_response(finished)
    return (image_result_to_node_output(result_images),)
@classmethod
def INPUT_TYPES(s):
    """Describe the node inputs, derived from ``KlingImageGenerationsRequest``.

    Returns a dict with the ``required``, ``optional`` and ``hidden`` input
    groups; every required entry is built by ``model_field_to_node_input``.
    """
    request_model = KlingImageGenerationsRequest

    def combo(field_name, enum_type):
        # Drop-down backed by an enum-typed request field.
        return model_field_to_node_input(
            IO.COMBO, request_model, field_name, enum_type=enum_type
        )

    def fidelity_slider(field_name):
        # Float slider with a 0.01 step.
        return model_field_to_node_input(
            IO.FLOAT, request_model, field_name, slider=True, step=0.01
        )

    required_inputs = {
        "prompt": model_field_to_node_input(
            IO.STRING,
            request_model,
            "prompt",
            multiline=True,
            max_length=MAX_PROMPT_LENGTH_IMAGE_GEN,
        ),
        "negative_prompt": model_field_to_node_input(
            IO.STRING, request_model, "negative_prompt", multiline=True
        ),
        # Exposed as "image_type" but read from the "image_reference" field.
        "image_type": combo("image_reference", KlingImageGenImageReferenceType),
        "image_fidelity": fidelity_slider("image_fidelity"),
        "human_fidelity": fidelity_slider("human_fidelity"),
        "model_name": combo("model_name", KlingImageGenModelName),
        "aspect_ratio": combo("aspect_ratio", KlingImageGenAspectRatio),
        "n": model_field_to_node_input(IO.INT, request_model, "n"),
    }
    optional_inputs = {"image": (IO.IMAGE, {})}
    hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
    return {
        "required": required_inputs,
        "optional": optional_inputs,
        "hidden": hidden_inputs,
    }
def api_call(
    self,
    model_name: KlingImageGenModelName,
    prompt: str,
    negative_prompt: str,
    image_type: KlingImageGenImageReferenceType,
    image_fidelity: float,
    human_fidelity: float,
    n: int,
    aspect_ratio: KlingImageGenAspectRatio,
    image: Optional[torch.Tensor] = None,
    auth_token: Optional[str] = None,
):
    """Create an image generation task, wait for it and return the images.

    The prompts are validated first. A reference image, when given, is sent
    base64-encoded; otherwise ``None`` is sent in its place. The result is a
    one-element tuple holding the converted images.
    """
    self.validate_prompt(prompt, negative_prompt)

    encoded_image = None if image is None else tensor_to_base64_string(image)

    create_endpoint = ApiEndpoint(
        path=PATH_IMAGE_GENERATIONS,
        method=HttpMethod.POST,
        request_model=KlingImageGenerationsRequest,
        response_model=KlingImageGenerationsResponse,
    )
    generation_request = KlingImageGenerationsRequest(
        model_name=model_name,
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=encoded_image,
        image_reference=image_type,
        image_fidelity=image_fidelity,
        human_fidelity=human_fidelity,
        n=n,
        aspect_ratio=aspect_ratio,
    )
    creation = SynchronousOperation(
        endpoint=create_endpoint,
        request=generation_request,
        auth_token=auth_token,
    ).execute()
    validate_task_creation_response(creation)

    # Poll with the id issued by the creation call until the task finishes.
    finished = self.get_response(creation.data.task_id, auth_token)
    validate_image_result_response(finished)

    result_images = get_images_from_response(finished)
    return (image_result_to_node_output(result_images),)
# Human-readable titles for the Kling nodes, keyed by node identifier.
NODE_DISPLAY_NAME_MAPPINGS = {
    "KlingCameraControls": "Kling Camera Controls",
    "KlingTextToVideoNode": "Kling Text to Video",
    "KlingImage2VideoNode": "Kling Image to Video",
    "KlingCameraControlI2VNode": "Kling Image to Video (Camera Control)",
    "KlingCameraControlT2VNode": "Kling Text to Video (Camera Control)",
    "KlingStartEndFrameNode": "Kling Start-End Frame to Video",
    "KlingVideoExtendNode": "Kling Video Extend",
    "KlingLipSyncAudioToVideoNode": "Kling Lip Sync Video with Audio",
    "KlingLipSyncTextToVideoNode": "Kling Lip Sync Video with Text",
    "KlingVirtualTryOnNode": "Kling Virtual Try On",
    "KlingImageGenerationNode": "Kling Image Generation",
    # NOTE(review): the single-image node is titled just "Kling Video Effects".
    "KlingSingleImageVideoEffectNode": "Kling Video Effects",
    "KlingDualCharacterVideoEffectNode": "Kling Dual Character Video Effects",
}
class LumaReferenceNode(ComfyNodeABC):
    """
    Holds an image and weight for use with Luma Generate Image node.
    """

    RETURN_TYPES = (LumaIO.LUMA_REF,)
    RETURN_NAMES = ("luma_ref",)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "create_luma_reference"
    CATEGORY = "api node/image/Luma"

    @classmethod
    def INPUT_TYPES(s):
        # Required: the reference image and its weight in [0.0, 1.0].
        image_input = (
            IO.IMAGE,
            {
                "tooltip": "Image to use as reference.",
            },
        )
        weight_input = (
            IO.FLOAT,
            {
                "default": 1.0,
                "min": 0.0,
                "max": 1.0,
                "step": 0.01,
                "tooltip": "Weight of image reference.",
            },
        )
        return {
            "required": {"image": image_input, "weight": weight_input},
            # Optional: an existing chain to extend.
            "optional": {"luma_ref": (LumaIO.LUMA_REF,)},
        }

    def create_luma_reference(
        self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
    ):
        # Work on a clone of the incoming chain (or a fresh one) so the
        # caller's chain object is never the one being appended to.
        chain = LumaReferenceChain() if luma_ref is None else luma_ref.clone()
        # The weight is rounded to two decimals before being stored.
        chain.add(LumaReference(image=image, weight=round(weight, 2)))
        return (chain,)
def create_concepts(
    self,
    concept1: str,
    concept2: str,
    concept3: str,
    concept4: str,
    luma_concepts: LumaConceptChain = None,
):
    """Build a concept chain from the four selections.

    When an upstream ``luma_concepts`` chain is supplied, the result is that
    chain's ``clone_and_merge`` with the newly built one; otherwise the new
    chain is returned as is. The chain is wrapped in a one-element tuple.
    """
    selected = LumaConceptChain(
        str_list=[concept1, concept2, concept3, concept4]
    )
    if luma_concepts is None:
        return (selected,)
    return (luma_concepts.clone_and_merge(selected),)
@classmethod
def INPUT_TYPES(s):
    """Describe the inputs of the Luma text-to-image node.

    Returns a dict with the ``required``, ``optional`` and ``hidden`` input
    groups. Combo options are the enum ``.value`` strings, and the
    ``aspect_ratio`` default uses ``.value`` too, so the default is always
    one of the listed options rather than an enum member.
    """
    return {
        "required": {
            "prompt": (
                IO.STRING,
                {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Prompt for the image generation",
                },
            ),
            "model": ([model.value for model in LumaImageModel],),
            "aspect_ratio": (
                [ratio.value for ratio in LumaAspectRatio],
                {
                    # Same representation as the option list above.
                    "default": LumaAspectRatio.ratio_16_9.value,
                },
            ),
            "seed": (
                IO.INT,
                {
                    "default": 0,
                    "min": 0,
                    "max": 0xFFFFFFFFFFFFFFFF,
                    "control_after_generate": True,
                    "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
                },
            ),
            "style_image_weight": (
                IO.FLOAT,
                {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "tooltip": "Weight of style image. Ignored if no style_image provided.",
                },
            ),
        },
        "optional": {
            "image_luma_ref": (
                LumaIO.LUMA_REF,
                {
                    "tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
                },
            ),
            "style_image": (
                IO.IMAGE,
                {"tooltip": "Style reference image; only 1 image will be used."},
            ),
            "character_image": (
                IO.IMAGE,
                {
                    "tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
                },
            ),
        },
        "hidden": {
            "auth_token": "AUTH_TOKEN_COMFY_ORG",
        },
    }
def api_call(
    self,
    prompt: str,
    model: str,
    image: torch.Tensor,
    image_weight: float,
    seed,
    auth_token=None,
    **kwargs,
):
    """Upload the image, request a Luma modification and return the result.

    The weight sent to the API is ``1.0 - image_weight`` clamped to
    ``[0.0, 0.98]`` and rounded to two decimals. ``seed`` and ``kwargs`` are
    accepted but not used in the request.
    """
    # Upload first; the request references the image by download URL.
    uploaded_urls = upload_images_to_comfyapi(
        image, max_images=1, auth_token=auth_token
    )
    source_url = uploaded_urls[0]

    create_endpoint = ApiEndpoint(
        path="/proxy/luma/generations/image",
        method=HttpMethod.POST,
        request_model=LumaImageGenerationRequest,
        response_model=LumaGeneration,
    )
    inverted_weight = 1.0 - image_weight
    clamped_weight = max(min(inverted_weight, 0.98), 0.0)
    modify_request = LumaImageGenerationRequest(
        prompt=prompt,
        model=model,
        modify_image_ref=LumaModifyImageRef(
            url=source_url, weight=round(clamped_weight, 2)
        ),
    )
    created: LumaGeneration = SynchronousOperation(
        endpoint=create_endpoint,
        request=modify_request,
        auth_token=auth_token,
    ).execute()

    # Poll the generation until it reports completed or failed.
    poll_endpoint = ApiEndpoint(
        path=f"/proxy/luma/generations/{created.id}",
        method=HttpMethod.GET,
        request_model=EmptyRequest,
        response_model=LumaGeneration,
    )
    finished = PollingOperation(
        poll_endpoint=poll_endpoint,
        completed_statuses=[LumaState.completed],
        failed_statuses=[LumaState.failed],
        status_extractor=lambda x: x.state,
        auth_token=auth_token,
    ).execute()

    image_response = requests.get(finished.assets.image)
    return (process_image_response(image_response),)
+ """ + + RETURN_TYPES = (IO.VIDEO,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/video/Luma" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the video generation", + }, + ), + "model": ([model.value for model in LumaVideoModel],), + "aspect_ratio": ( + [ratio.value for ratio in LumaAspectRatio], + { + "default": LumaAspectRatio.ratio_16_9, + }, + ), + "resolution": ( + [resolution.value for resolution in LumaVideoOutputResolution], + { + "default": LumaVideoOutputResolution.res_540p, + }, + ), + "duration": ([dur.value for dur in LumaVideoModelOutputDuration],), + "loop": ( + IO.BOOLEAN, + { + "default": False, + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + }, + ), + }, + "optional": { + "luma_concepts": ( + LumaIO.LUMA_CONCEPTS, + { + "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node." 
def api_call(
    self,
    prompt: str,
    model: str,
    aspect_ratio: str,
    resolution: str,
    duration: str,
    loop: bool,
    seed,
    luma_concepts: LumaConceptChain = None,
    auth_token=None,
    **kwargs,
):
    """Generate a video from a text prompt via the Luma proxy endpoints.

    Args:
        prompt: Text prompt; validated with a minimum length of 3.
        model: Luma video model name.
        aspect_ratio: Aspect ratio forwarded to the request.
        resolution: Output resolution; sent as ``None`` for the ray_1_6 model.
        duration: Output duration; sent as ``None`` for the ray_1_6 model.
        loop: Forwarded unchanged to the request.
        seed: Accepted but not used in the request.
        luma_concepts: Optional concept chain converted with
            ``create_api_model``.
        auth_token: Token passed to every API operation.

    Returns:
        One-element tuple with a ``VideoFromFile`` built from the downloaded
        bytes.

    Raises:
        requests.HTTPError: if downloading the generated video fails.
    """
    validate_string(prompt, strip_whitespace=False, min_length=3)
    duration = duration if model != LumaVideoModel.ray_1_6 else None
    resolution = resolution if model != LumaVideoModel.ray_1_6 else None

    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/luma/generations",
            method=HttpMethod.POST,
            request_model=LumaGenerationRequest,
            response_model=LumaGeneration,
        ),
        request=LumaGenerationRequest(
            prompt=prompt,
            model=model,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
            duration=duration,
            loop=loop,
            concepts=luma_concepts.create_api_model() if luma_concepts else None,
        ),
        auth_token=auth_token,
    )
    response_api: LumaGeneration = operation.execute()

    operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/luma/generations/{response_api.id}",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=LumaGeneration,
        ),
        completed_statuses=[LumaState.completed],
        failed_statuses=[LumaState.failed],
        status_extractor=lambda x: x.state,
        auth_token=auth_token,
    )
    response_poll = operation.execute()

    vid_response = requests.get(response_poll.assets.video)
    # Fail loudly on an HTTP error instead of wrapping an error page's bytes
    # as if they were video data.
    vid_response.raise_for_status()
    return (VideoFromFile(BytesIO(vid_response.content)),)
def api_call(
    self,
    prompt: str,
    model: str,
    resolution: str,
    duration: str,
    loop: bool,
    seed,
    first_image: torch.Tensor = None,
    last_image: torch.Tensor = None,
    luma_concepts: LumaConceptChain = None,
    auth_token=None,
    **kwargs,
):
    """Generate a video from first/last frame images via the Luma proxy.

    Args:
        prompt: Text prompt forwarded to the request.
        model: Luma video model name.
        resolution: Output resolution; sent as ``None`` for the ray_1_6 model.
        duration: Output duration; sent as ``None`` for the ray_1_6 model.
        loop: Forwarded unchanged to the request.
        seed: Accepted but not used in the request.
        first_image: Optional first frame.
        last_image: Optional last frame.
        luma_concepts: Optional concept chain converted with
            ``create_api_model``.
        auth_token: Token passed to uploads and API operations.

    Returns:
        One-element tuple with a ``VideoFromFile`` built from the downloaded
        bytes.

    Raises:
        Exception: if neither ``first_image`` nor ``last_image`` is given.
        requests.HTTPError: if downloading the generated video fails.
    """
    if first_image is None and last_image is None:
        raise Exception(
            "At least one of first_image and last_image requires an input."
        )
    keyframes = self._convert_to_keyframes(first_image, last_image, auth_token)
    duration = duration if model != LumaVideoModel.ray_1_6 else None
    resolution = resolution if model != LumaVideoModel.ray_1_6 else None

    operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/luma/generations",
            method=HttpMethod.POST,
            request_model=LumaGenerationRequest,
            response_model=LumaGeneration,
        ),
        request=LumaGenerationRequest(
            prompt=prompt,
            model=model,
            aspect_ratio=LumaAspectRatio.ratio_16_9,  # ignored, but still needed by the API for some reason
            resolution=resolution,
            duration=duration,
            loop=loop,
            keyframes=keyframes,
            concepts=luma_concepts.create_api_model() if luma_concepts else None,
        ),
        auth_token=auth_token,
    )
    response_api: LumaGeneration = operation.execute()

    operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/luma/generations/{response_api.id}",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=LumaGeneration,
        ),
        completed_statuses=[LumaState.completed],
        failed_statuses=[LumaState.failed],
        status_extractor=lambda x: x.state,
        auth_token=auth_token,
    )
    response_poll = operation.execute()

    vid_response = requests.get(response_poll.assets.video)
    # Fail loudly on an HTTP error instead of wrapping an error page's bytes
    # as if they were video data.
    vid_response.raise_for_status()
    return (VideoFromFile(BytesIO(vid_response.content)),)
# Maps each exported node identifier to the class that implements it.
# NOTE: identifiers should be globally unique.
NODE_CLASS_MAPPINGS = {
    "LumaImageNode": LumaImageGenerationNode,
    "LumaImageModifyNode": LumaImageModifyNode,
    "LumaVideoNode": LumaTextToVideoGenerationNode,
    "LumaImageToVideoNode": LumaImageToVideoGenerationNode,
    "LumaReferenceNode": LumaReferenceNode,
    "LumaConceptsNode": LumaConceptsNode,
}

# Maps each node identifier to its friendly, human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LumaImageNode": "Luma Text to Image",
    "LumaImageModifyNode": "Luma Image to Image",
    "LumaVideoNode": "Luma Text to Video",
    "LumaImageToVideoNode": "Luma Image to Video",
    "LumaReferenceNode": "Luma Reference",
    "LumaConceptsNode": "Luma Concepts",
}
def generate_video(
    self,
    prompt_text,
    seed=0,
    model="T2V-01",
    image: torch.Tensor=None, # used for ImageToVideo
    subject: torch.Tensor=None, # used for SubjectToVideo
    auth_token=None,
):
    """Shared MiniMax entry point covering T2V, I2V and S2V.

    The prompt is validated only when no ``image`` is given. The flow is:
    create the generation task, poll it until it succeeds or fails, resolve
    the resulting file id to a download URL, then download the video.
    Returns a one-element tuple with the ``VideoFromFile``.
    """
    if image is None:
        validate_string(prompt_text, field_name="prompt_text")
        first_frame_url = None
    else:
        # upload image, if passed in
        first_frame_url = upload_images_to_comfyapi(
            image, max_images=1, auth_token=auth_token
        )[0]

    # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
    subject_reference = None
    if subject is not None:
        subject_url = upload_images_to_comfyapi(
            subject, max_images=1, auth_token=auth_token
        )[0]
        subject_reference = [SubjectReferenceItem(image=subject_url)]

    create_operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/minimax/video_generation",
            method=HttpMethod.POST,
            request_model=MinimaxVideoGenerationRequest,
            response_model=MinimaxVideoGenerationResponse,
        ),
        request=MinimaxVideoGenerationRequest(
            model=Model(model),
            prompt=prompt_text,
            callback_url=None,
            first_frame_image=first_frame_url,
            subject_reference=subject_reference,
            prompt_optimizer=None,
        ),
        auth_token=auth_token,
    )
    creation = create_operation.execute()

    task_id = creation.task_id
    if not task_id:
        raise Exception(f"MiniMax generation failed: {creation.base_resp}")

    poll_operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path="/proxy/minimax/query/video_generation",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=MinimaxTaskResultResponse,
            query_params={"task_id": task_id},
        ),
        completed_statuses=["Success"],
        failed_statuses=["Fail"],
        status_extractor=lambda x: x.status.value,
        auth_token=auth_token,
    )
    task_result = poll_operation.execute()

    file_id = task_result.file_id
    if file_id is None:
        raise Exception("Request was not successful. Missing file ID.")

    # The task result only carries a file id; look up its download URL.
    retrieve_operation = SynchronousOperation(
        endpoint=ApiEndpoint(
            path="/proxy/minimax/files/retrieve",
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=MinimaxFileRetrieveResponse,
            query_params={"file_id": int(file_id)},
        ),
        request=EmptyRequest(),
        auth_token=auth_token,
    )
    file_result = retrieve_operation.execute()

    file_url = file_result.file.download_url
    if file_url is None:
        raise Exception(
            f"No video was found in the response. Full response: {file_result.model_dump()}"
        )
    logging.info(f"Generated video URL: {file_url}")

    video_io = download_url_to_bytesio(file_url)
    if video_io is not None:
        return (VideoFromFile(video_io),)

    error_msg = f"Failed to download video from {file_url}"
    logging.error(error_msg)
    raise Exception(error_msg)
+ """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ( + IO.IMAGE, + { + "tooltip": "Image to use as first frame of video generation" + }, + ), + "prompt_text": ( + "STRING", + { + "multiline": True, + "default": "", + "tooltip": "Text prompt to guide the video generation", + }, + ), + "model": ( + [ + "I2V-01-Director", + "I2V-01", + "I2V-01-live", + ], + { + "default": "I2V-01", + "tooltip": "Model to use for video generation", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = ("VIDEO",) + DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" + FUNCTION = "generate_video" + CATEGORY = "api node/video/MiniMax" + API_NODE = True + OUTPUT_NODE = True + + +class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): + """ + Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
+ """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "subject": ( + IO.IMAGE, + { + "tooltip": "Image of subject to reference video generation" + }, + ), + "prompt_text": ( + "STRING", + { + "multiline": True, + "default": "", + "tooltip": "Text prompt to guide the video generation", + }, + ), + "model": ( + [ + "S2V-01", + ], + { + "default": "S2V-01", + "tooltip": "Model to use for video generation", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = ("VIDEO",) + DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API" + FUNCTION = "generate_video" + CATEGORY = "api node/video/MiniMax" + API_NODE = True + OUTPUT_NODE = True + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "MinimaxTextToVideoNode": MinimaxTextToVideoNode, + "MinimaxImageToVideoNode": MinimaxImageToVideoNode, + # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "MinimaxTextToVideoNode": "MiniMax Text to Video", + "MinimaxImageToVideoNode": "MiniMax Image to Video", + "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", +} diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py new file mode 100644 index 00000000..c18c65d7 --- /dev/null +++ b/comfy_api_nodes/nodes_openai.py @@ -0,0 +1,487 @@ +import io +from inspect import cleandoc +import numpy as np +import torch +from PIL import Image + +from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict + + +from comfy_api_nodes.apis import ( + OpenAIImageGenerationRequest, + OpenAIImageEditRequest, 
+ OpenAIImageGenerationResponse, +) + +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, +) + +from comfy_api_nodes.apinode_utils import ( + downscale_image_tensor, + validate_and_cast_response, + validate_string, +) + +class OpenAIDalle2(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's DALL·E 2 endpoint. + """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text prompt for DALL·E", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2**31 - 1, + "step": 1, + "display": "number", + "control_after_generate": True, + "tooltip": "not implemented yet in backend", + }, + ), + "size": ( + IO.COMBO, + { + "options": ["256x256", "512x512", "1024x1024"], + "default": "1024x1024", + "tooltip": "Image size", + }, + ), + "n": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "How many images to generate", + }, + ), + "image": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional reference image for image editing.", + }, + ), + "mask": ( + IO.MASK, + { + "default": None, + "tooltip": "Optional mask for inpainting (white areas will be replaced)", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/OpenAI" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + seed=0, + image=None, + mask=None, + n=1, + size="1024x1024", + auth_token=None, + ): + validate_string(prompt, strip_whitespace=False) + model = "dall-e-2" + path = "/proxy/openai/images/generations" + content_type = "application/json" + request_class = OpenAIImageGenerationRequest + img_binary = None + + if image is not None and mask is not None: + path = 
"/proxy/openai/images/edits" + content_type = "multipart/form-data" + request_class = OpenAIImageEditRequest + + input_tensor = image.squeeze().cpu() + height, width, channels = input_tensor.shape + rgba_tensor = torch.ones(height, width, 4, device="cpu") + rgba_tensor[:, :, :channels] = input_tensor + + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu() + + rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze() + + image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + img_binary = img_byte_arr # .getvalue() + img_binary.name = "image.png" + elif image is not None or mask is not None: + raise Exception("Dall-E 2 image editing requires an image AND a mask") + + # Build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=request_class, + response_model=OpenAIImageGenerationResponse, + ), + request=request_class( + model=model, + prompt=prompt, + n=n, + size=size, + seed=seed, + ), + files=( + { + "image": img_binary, + } + if img_binary + else None + ), + content_type=content_type, + auth_token=auth_token, + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + + +class OpenAIDalle3(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's DALL·E 3 endpoint. 
+ """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text prompt for DALL·E", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2**31 - 1, + "step": 1, + "display": "number", + "control_after_generate": True, + "tooltip": "not implemented yet in backend", + }, + ), + "quality": ( + IO.COMBO, + { + "options": ["standard", "hd"], + "default": "standard", + "tooltip": "Image quality", + }, + ), + "style": ( + IO.COMBO, + { + "options": ["natural", "vivid"], + "default": "natural", + "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", + }, + ), + "size": ( + IO.COMBO, + { + "options": ["1024x1024", "1024x1792", "1792x1024"], + "default": "1024x1024", + "tooltip": "Image size", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/OpenAI" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + seed=0, + style="natural", + quality="standard", + size="1024x1024", + auth_token=None, + ): + validate_string(prompt, strip_whitespace=False) + model = "dall-e-3" + + # build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/openai/images/generations", + method=HttpMethod.POST, + request_model=OpenAIImageGenerationRequest, + response_model=OpenAIImageGenerationResponse, + ), + request=OpenAIImageGenerationRequest( + model=model, + prompt=prompt, + quality=quality, + size=size, + style=style, + seed=seed, + ), + auth_token=auth_token, + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + + +class 
OpenAIGPTImage1(ComfyNodeABC): + """ + Generates images synchronously via OpenAI's GPT Image 1 endpoint. + """ + + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text prompt for GPT Image 1", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 2**31 - 1, + "step": 1, + "display": "number", + "control_after_generate": True, + "tooltip": "not implemented yet in backend", + }, + ), + "quality": ( + IO.COMBO, + { + "options": ["low", "medium", "high"], + "default": "low", + "tooltip": "Image quality, affects cost and generation time.", + }, + ), + "background": ( + IO.COMBO, + { + "options": ["opaque", "transparent"], + "default": "opaque", + "tooltip": "Return image with or without background", + }, + ), + "size": ( + IO.COMBO, + { + "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], + "default": "auto", + "tooltip": "Image size", + }, + ), + "n": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "How many images to generate", + }, + ), + "image": ( + IO.IMAGE, + { + "default": None, + "tooltip": "Optional reference image for image editing.", + }, + ), + "mask": ( + IO.MASK, + { + "default": None, + "tooltip": "Optional mask for inpainting (white areas will be replaced)", + }, + ), + }, + "hidden": {"auth_token": "AUTH_TOKEN_COMFY_ORG"}, + } + + RETURN_TYPES = (IO.IMAGE,) + FUNCTION = "api_call" + CATEGORY = "api node/image/OpenAI" + DESCRIPTION = cleandoc(__doc__ or "") + API_NODE = True + + def api_call( + self, + prompt, + seed=0, + quality="low", + background="opaque", + image=None, + mask=None, + n=1, + size="1024x1024", + auth_token=None, + ): + validate_string(prompt, strip_whitespace=False) + model = "gpt-image-1" + path = "/proxy/openai/images/generations" + content_type="application/json" + request_class = 
OpenAIImageGenerationRequest + img_binaries = [] + mask_binary = None + files = [] + + if image is not None: + path = "/proxy/openai/images/edits" + request_class = OpenAIImageEditRequest + content_type ="multipart/form-data" + + batch_size = image.shape[0] + + for i in range(batch_size): + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image).squeeze() + + image_np = (scaled_image.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + img_binary = img_byte_arr + img_binary.name = f"image_{i}.png" + + img_binaries.append(img_binary) + if batch_size == 1: + files.append(("image", img_binary)) + else: + files.append(("image[]", img_binary)) + + if mask is not None: + if image is None: + raise Exception("Cannot use a mask without an input image") + if image.shape[0] != 1: + raise Exception("Cannot use a mask with multiple image") + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + batch, height, width = mask.shape + rgba_mask = torch.zeros(height, width, 4, device="cpu") + rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() + + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze() + + mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_img_byte_arr = io.BytesIO() + mask_img.save(mask_img_byte_arr, format="PNG") + mask_img_byte_arr.seek(0) + mask_binary = mask_img_byte_arr + mask_binary.name = "mask.png" + files.append(("mask", mask_binary)) + + # Build the operation + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=path, + method=HttpMethod.POST, + request_model=request_class, + response_model=OpenAIImageGenerationResponse, + ), + request=request_class( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + ), + files=files if files else None, + 
content_type=content_type, + auth_token=auth_token, + ) + + response = operation.execute() + + img_tensor = validate_and_cast_response(response) + return (img_tensor,) + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "OpenAIDalle2": OpenAIDalle2, + "OpenAIDalle3": OpenAIDalle3, + "OpenAIGPTImage1": OpenAIGPTImage1, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "OpenAIDalle2": "OpenAI DALL·E 2", + "OpenAIDalle3": "OpenAI DALL·E 3", + "OpenAIGPTImage1": "OpenAI GPT Image 1", +} diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py new file mode 100644 index 00000000..ba4e8457 --- /dev/null +++ b/comfy_api_nodes/nodes_pika.py @@ -0,0 +1,749 @@ +""" +Pika x ComfyUI API Nodes + +Pika API docs: https://pika-827374fb.mintlify.app/api-reference +""" + +import io +from typing import Optional, TypeVar +import logging +import torch +import numpy as np +from comfy_api_nodes.apis import ( + PikaBodyGenerate22T2vGenerate22T2vPost, + PikaGenerateResponse, + PikaBodyGenerate22I2vGenerate22I2vPost, + PikaVideoResponse, + PikaBodyGenerate22C2vGenerate22PikascenesPost, + IngredientsMode, + PikaDurationEnum, + PikaResolutionEnum, + PikaBodyGeneratePikaffectsGeneratePikaffectsPost, + PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + PikaBodyGenerate22KeyframeGenerate22PikaframesPost, + Pikaffect, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + tensor_to_bytesio, + download_url_to_video_output, +) +from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec +from comfy_api.input_impl import VideoFromFile +from 
comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions + +R = TypeVar("R") + +PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions" +PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" +PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects" + +PIKA_API_VERSION = "2.2" +PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v" +PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v" +PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes" +PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes" + +PATH_VIDEO_GET = "/proxy/pika/videos" + + +class PikaApiError(Exception): + """Exception for Pika API errors.""" + + pass + + +def is_valid_video_response(response: PikaVideoResponse) -> bool: + """Check if the video response is valid.""" + return hasattr(response, "url") and response.url is not None + + +def is_valid_initial_response(response: PikaGenerateResponse) -> bool: + """Check if the initial response is valid.""" + return hasattr(response, "video_id") and response.video_id is not None + + +class PikaNodeBase(ComfyNodeABC): + """Base class for Pika nodes.""" + + @classmethod + def get_base_inputs_types( + cls, request_model + ) -> dict[str, tuple[IO, InputTypeOptions]]: + """Get the base required inputs types common to all Pika nodes.""" + return { + "prompt_text": model_field_to_node_input( + IO.STRING, + request_model, + "promptText", + multiline=True, + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + request_model, + "negativePrompt", + multiline=True, + ), + "seed": model_field_to_node_input( + IO.INT, + request_model, + "seed", + min=0, + max=0xFFFFFFFF, + control_after_generate=True, + ), + "resolution": model_field_to_node_input( + IO.COMBO, + request_model, + "resolution", + enum_type=PikaResolutionEnum, + ), + "duration": model_field_to_node_input( + IO.COMBO, + request_model, + "duration", + enum_type=PikaDurationEnum, + ), + } + + CATEGORY = "api node/video/Pika" + 
API_NODE = True + FUNCTION = "api_call" + RETURN_TYPES = ("VIDEO",) + + def poll_for_task_status( + self, task_id: str, auth_token: str + ) -> PikaVideoResponse: # same type as the poll endpoint's response_model below + polling_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"{PATH_VIDEO_GET}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=PikaVideoResponse, + ), + completed_statuses=[ + "finished", + ], + failed_statuses=["failed", "cancelled"], + status_extractor=lambda response: ( + response.status.value if response.status else None + ), + progress_extractor=lambda response: ( + response.progress if hasattr(response, "progress") else None + ), + auth_token=auth_token, + ) + return polling_operation.execute() + + def execute_task( + self, + initial_operation: SynchronousOperation[R, PikaGenerateResponse], + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + """Executes the initial operation then polls for the task status until it is completed. + + Args: + initial_operation: The initial operation to execute. + auth_token: The authentication token to use for the API call. + + Returns: + A tuple containing the video file as a VIDEO output. + """ + initial_response = initial_operation.execute() + if not is_valid_initial_response(initial_response): + error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" + logging.error(error_msg) + raise PikaApiError(error_msg) + + task_id = initial_response.video_id + final_response = self.poll_for_task_status(task_id, auth_token) + if not is_valid_video_response(final_response): + error_msg = ( + f"Pika task {task_id} succeeded but no video data found in response." + ) + logging.error(error_msg) + raise PikaApiError(error_msg) + + video_url = str(final_response.url) + logging.info("Pika task %s succeeded.
Video URL: %s", task_id, video_url) + + return (download_url_to_video_output(video_url),) + + +class PikaImageToVideoV2_2(PikaNodeBase): + """Pika 2.2 Image to Video Node.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ( + IO.IMAGE, + {"tooltip": "The image to convert to video"}, + ), + **cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." + + def api_call( + self, + image: torch.Tensor, + prompt_text: str, + negative_prompt: str, + seed: int, + resolution: str, + duration: int, + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + # Convert image to BytesIO + image_bytes_io = tensor_to_bytesio(image) + image_bytes_io.seek(0) + + pika_files = {"image": ("image.png", image_bytes_io, "image/png")} + + # Prepare non-file data + pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost( + promptText=prompt_text, + negativePrompt=negative_prompt, + seed=seed, + resolution=resolution, + duration=duration, + ) + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_IMAGE_TO_VIDEO, + method=HttpMethod.POST, + request_model=PikaBodyGenerate22I2vGenerate22I2vPost, + response_model=PikaGenerateResponse, + ), + request=pika_request_data, + files=pika_files, + content_type="multipart/form-data", + auth_token=auth_token, + ) + + return self.execute_task(initial_operation, auth_token) + + +class PikaTextToVideoNodeV2_2(PikaNodeBase): + """Pika Text2Video v2.2 Node.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + **cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost), + "aspect_ratio": model_field_to_node_input( + IO.FLOAT, + PikaBodyGenerate22T2vGenerate22T2vPost, + "aspectRatio", + step=0.001, + min=0.4, + max=2.5, + default=1.7777777777777777, + ), + }, + "hidden": { + "auth_token": 
"AUTH_TOKEN_COMFY_ORG", + }, + } + + DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." + + def api_call( + self, + prompt_text: str, + negative_prompt: str, + seed: int, + resolution: str, + duration: int, + aspect_ratio: float, + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_TEXT_TO_VIDEO, + method=HttpMethod.POST, + request_model=PikaBodyGenerate22T2vGenerate22T2vPost, + response_model=PikaGenerateResponse, + ), + request=PikaBodyGenerate22T2vGenerate22T2vPost( + promptText=prompt_text, + negativePrompt=negative_prompt, + seed=seed, + resolution=resolution, + duration=duration, + aspectRatio=aspect_ratio, + ), + auth_token=auth_token, + content_type="application/x-www-form-urlencoded", + ) + + return self.execute_task(initial_operation, auth_token) + + +class PikaScenesV2_2(PikaNodeBase): + """PikaScenes v2.2 Node.""" + + @classmethod + def INPUT_TYPES(cls): + image_ingredient_input = ( + IO.IMAGE, + {"tooltip": "Image that will be used as ingredient to create a video."}, + ) + return { + "required": { + **cls.get_base_inputs_types( + PikaBodyGenerate22C2vGenerate22PikascenesPost, + ), + "ingredients_mode": model_field_to_node_input( + IO.COMBO, + PikaBodyGenerate22C2vGenerate22PikascenesPost, + "ingredientsMode", + enum_type=IngredientsMode, + default="creative", + ), + "aspect_ratio": model_field_to_node_input( + IO.FLOAT, + PikaBodyGenerate22C2vGenerate22PikascenesPost, + "aspectRatio", + step=0.001, + min=0.4, + max=2.5, + default=1.7777777777777777, + ), + }, + "optional": { + "image_ingredient_1": image_ingredient_input, + "image_ingredient_2": image_ingredient_input, + "image_ingredient_3": image_ingredient_input, + "image_ingredient_4": image_ingredient_input, + "image_ingredient_5": image_ingredient_input, + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + DESCRIPTION = "Combine your images to create a video 
with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." + + def api_call( + self, + prompt_text: str, + negative_prompt: str, + seed: int, + resolution: str, + duration: int, + ingredients_mode: str, + aspect_ratio: float, + image_ingredient_1: Optional[torch.Tensor] = None, + image_ingredient_2: Optional[torch.Tensor] = None, + image_ingredient_3: Optional[torch.Tensor] = None, + image_ingredient_4: Optional[torch.Tensor] = None, + image_ingredient_5: Optional[torch.Tensor] = None, + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + # Convert all passed images to BytesIO + all_image_bytes_io = [] + for image in [ + image_ingredient_1, + image_ingredient_2, + image_ingredient_3, + image_ingredient_4, + image_ingredient_5, + ]: + if image is not None: + image_bytes_io = tensor_to_bytesio(image) + image_bytes_io.seek(0) + all_image_bytes_io.append(image_bytes_io) + + pika_files = [ + ("images", (f"image_{i}.png", image_bytes_io, "image/png")) + for i, image_bytes_io in enumerate(all_image_bytes_io) + ] + + pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost( + ingredientsMode=ingredients_mode, + promptText=prompt_text, + negativePrompt=negative_prompt, + seed=seed, + resolution=resolution, + duration=duration, + aspectRatio=aspect_ratio, + ) + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_PIKASCENES, + method=HttpMethod.POST, + request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost, + response_model=PikaGenerateResponse, + ), + request=pika_request_data, + files=pika_files, + content_type="multipart/form-data", + auth_token=auth_token, + ) + + return self.execute_task(initial_operation, auth_token) + + +class PikAdditionsNode(PikaNodeBase): + """Pika Pikadditions Node. 
Add an image into a video.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": (IO.VIDEO, {"tooltip": "The video to add an image to."}), + "image": (IO.IMAGE, {"tooltip": "The image to add to the video."}), + "prompt_text": model_field_to_node_input( + IO.STRING, + PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + "promptText", + multiline=True, + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + "negativePrompt", + multiline=True, + ), + "seed": model_field_to_node_input( + IO.INT, + PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + "seed", + min=0, + max=0xFFFFFFFF, + control_after_generate=True, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result." + + def api_call( + self, + video: VideoInput, + image: torch.Tensor, + prompt_text: str, + negative_prompt: str, + seed: int, + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + # Convert video to BytesIO + video_bytes_io = io.BytesIO() + video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) + video_bytes_io.seek(0) + + # Convert image to BytesIO + image_bytes_io = tensor_to_bytesio(image) + image_bytes_io.seek(0) + + pika_files = [ + ("video", ("video.mp4", video_bytes_io, "video/mp4")), + ("image", ("image.png", image_bytes_io, "image/png")), + ] + + # Prepare non-file data + pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( + promptText=prompt_text, + negativePrompt=negative_prompt, + seed=seed, + ) + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_PIKADDITIONS, + method=HttpMethod.POST, + request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, + response_model=PikaGenerateResponse, + ), + request=pika_request_data, + 
files=pika_files, + content_type="multipart/form-data", + auth_token=auth_token, + ) + + return self.execute_task(initial_operation, auth_token) + + +class PikaSwapsNode(PikaNodeBase): + """Pika Pikaswaps Node.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}), + "image": ( + IO.IMAGE, + { + "tooltip": "The image used to replace the masked object in the video." + }, + ), + "mask": ( + IO.MASK, + {"tooltip": "Use the mask to define areas in the video to replace"}, + ), + "prompt_text": model_field_to_node_input( + IO.STRING, + PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + "promptText", + multiline=True, + ), + "negative_prompt": model_field_to_node_input( + IO.STRING, + PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + "negativePrompt", + multiline=True, + ), + "seed": model_field_to_node_input( + IO.INT, + PikaBodyGeneratePikaswapsGeneratePikaswapsPost, + "seed", + min=0, + max=0xFFFFFFFF, + control_after_generate=True, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." 
+ RETURN_TYPES = ("VIDEO",) + + def api_call( + self, + video: VideoInput, + image: torch.Tensor, + mask: torch.Tensor, + prompt_text: str, + negative_prompt: str, + seed: int, + auth_token: Optional[str] = None, + ) -> tuple[VideoFromFile]: + # Convert video to BytesIO + video_bytes_io = io.BytesIO() + video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) + video_bytes_io.seek(0) + + # Convert mask to binary mask with three channels + mask = torch.round(mask) + mask = mask.repeat(1, 3, 1, 1) + + # Convert 3-channel binary mask to BytesIO + mask_bytes_io = io.BytesIO() + mask_bytes_io.write(mask.numpy().astype(np.uint8)) # NOTE(review): writes the raw uint8 array buffer (no PNG encoding) although the part is sent as "mask.png"/"image/png" - confirm the API accepts this + mask_bytes_io.seek(0) + + # Convert image to BytesIO + image_bytes_io = tensor_to_bytesio(image) + image_bytes_io.seek(0) + + pika_files = [ + ("video", ("video.mp4", video_bytes_io, "video/mp4")), + ("image", ("image.png", image_bytes_io, "image/png")), + ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), + ] + + # Prepare non-file data + pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( + promptText=prompt_text, + negativePrompt=negative_prompt, + seed=seed, + ) + + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_PIKASWAPS, # Pikaswaps endpoint, not Pikadditions + method=HttpMethod.POST, + request_model=PikaBodyGeneratePikaswapsGeneratePikaswapsPost, # same model as pika_request_data + response_model=PikaGenerateResponse, + ), + request=pika_request_data, + files=pika_files, + content_type="multipart/form-data", + auth_token=auth_token, + ) + + return self.execute_task(initial_operation, auth_token) + + +class PikaffectsNode(PikaNodeBase): + """Pika Pikaffects Node.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ( + IO.IMAGE, + {"tooltip": "The reference image to apply the Pikaffect to."}, + ), + "pikaffect": model_field_to_node_input( + IO.COMBO, + PikaBodyGeneratePikaffectsGeneratePikaffectsPost, + "pikaffect", + enum_type=Pikaffect, + default="Cake-ify", + ), + "prompt_text":
class PikaStartEndFrameNode2_2(PikaNodeBase):
    """PikaFrames v2.2 Node."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input spec: two key-frame images followed by the shared base inputs."""
        required_inputs = {
            "image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
            "image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
        }
        # Merged last so a clashing key is overridden in place, as with ``**``.
        required_inputs.update(
            cls.get_base_inputs_types(
                PikaBodyGenerate22KeyframeGenerate22PikaframesPost
            )
        )
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {"required": required_inputs, "hidden": hidden_inputs}

    DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."

    def api_call(
        self,
        image_start: torch.Tensor,
        image_end: torch.Tensor,
        prompt_text: str,
        negative_prompt: str,
        seed: int,
        resolution: str,
        duration: int,
        auth_token: Optional[str] = None,
    ) -> tuple[VideoFromFile]:
        """Send both key frames and the generation settings, then run the task.

        Both images are uploaded under the same multipart field name
        ("keyFrames"), start frame first.
        """
        keyframe_uploads = (
            ("image_start.png", image_start),
            ("image_end.png", image_end),
        )
        pika_files = [
            ("keyFrames", (upload_name, tensor_to_bytesio(frame), "image/png"))
            for upload_name, frame in keyframe_uploads
        ]

        endpoint = ApiEndpoint(
            path=PATH_PIKAFRAMES,
            method=HttpMethod.POST,
            request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
            response_model=PikaGenerateResponse,
        )
        payload = PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
            promptText=prompt_text,
            negativePrompt=negative_prompt,
            seed=seed,
            resolution=resolution,
            duration=duration,
        )
        initial_operation = SynchronousOperation(
            endpoint=endpoint,
            request=payload,
            files=pika_files,
            content_type="multipart/form-data",
            auth_token=auth_token,
        )

        return self.execute_task(initial_operation, auth_token)
def upload_image_to_pixverse(image: torch.Tensor, auth_token=None):
    """Upload *image* to PixVerse and return the image id from the response.

    The id is what the generation requests later reference. Raises
    ``Exception`` carrying the response's ``ErrMsg`` when the upload response
    has no ``Resp`` payload.
    """
    encoded_image = tensor_to_bytesio(image)
    upload_endpoint = ApiEndpoint(
        path="/proxy/pixverse/image/upload",
        method=HttpMethod.POST,
        request_model=EmptyRequest,
        response_model=PixverseImageUploadResponse,
    )
    upload_operation = SynchronousOperation(
        endpoint=upload_endpoint,
        request=EmptyRequest(),
        files={"image": encoded_image},
        content_type="multipart/form-data",
        auth_token=auth_token,
    )
    response_upload: PixverseImageUploadResponse = upload_operation.execute()

    payload = response_upload.Resp
    if payload is None:
        raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
    return payload.img_id
class PixverseTextToVideoNode(ComfyNodeABC):
    """
    Generates videos synchronously based on prompt and output_size.
    """

    # NOTE: the class docstring doubles as the node DESCRIPTION below, so
    # implementation notes live in comments and method docstrings instead.
    RETURN_TYPES = (IO.VIDEO,)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/video/PixVerse"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec; combo options are taken from the PixVerse enums."""
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Prompt for the video generation",
                    },
                ),
                "aspect_ratio": (
                    [ratio.value for ratio in PixverseAspectRatio],
                ),
                "quality": (
                    [resolution.value for resolution in PixverseQuality],
                    {
                        "default": PixverseQuality.res_540p,
                    },
                ),
                "duration_seconds": ([dur.value for dur in PixverseDuration],),
                "motion_mode": ([mode.value for mode in PixverseMotionMode],),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2147483647,
                        "control_after_generate": True,
                        "tooltip": "Seed for video generation.",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    IO.STRING,
                    {
                        "default": "",
                        "forceInput": True,
                        "tooltip": "An optional text description of undesired elements on an image.",
                    },
                ),
                "pixverse_template": (
                    PixverseIO.TEMPLATE,
                    {
                        "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
                    }
                )
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
            },
        }

    def api_call(
        self,
        prompt: str,
        aspect_ratio: str,
        quality: str,
        duration_seconds: int,
        motion_mode: str,
        seed,
        negative_prompt: str=None,
        pixverse_template: int=None,
        auth_token=None,
        **kwargs,
    ):
        """Request a text-to-video generation, poll until it finishes, download the result.

        Returns:
            A 1-tuple holding a ``VideoFromFile`` built from the downloaded bytes.

        Raises:
            Exception: when the generation response carries no ``Resp`` payload.
            requests.HTTPError: when the result URL answers with an error status.
            requests.Timeout: when the download stalls past the timeout.
        """
        validate_string(prompt, strip_whitespace=False)
        # 1080p is limited to 5 seconds duration
        # only normal motion_mode supported for 1080p or for non-5 second duration
        if quality == PixverseQuality.res_1080p:
            motion_mode = PixverseMotionMode.normal
            duration_seconds = PixverseDuration.dur_5
        elif duration_seconds != PixverseDuration.dur_5:
            motion_mode = PixverseMotionMode.normal

        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/pixverse/video/text/generate",
                method=HttpMethod.POST,
                request_model=PixverseTextVideoRequest,
                response_model=PixverseVideoResponse,
            ),
            request=PixverseTextVideoRequest(
                prompt=prompt,
                aspect_ratio=aspect_ratio,
                quality=quality,
                duration=duration_seconds,
                motion_mode=motion_mode,
                # An empty negative prompt is sent as None.
                negative_prompt=negative_prompt if negative_prompt else None,
                template_id=pixverse_template,
                seed=seed,
            ),
            auth_token=auth_token,
        )
        response_api = operation.execute()

        if response_api.Resp is None:
            raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

        operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=PixverseGenerationStatusResponse,
            ),
            completed_statuses=[PixverseStatus.successful],
            failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
            status_extractor=lambda x: x.Resp.status,
            auth_token=auth_token,
        )
        response_poll = operation.execute()

        # Bound the download so a stalled connection cannot hang the node, and
        # fail loudly on an HTTP error instead of wrapping an error page as video.
        download_timeout_s = 1024
        vid_response = requests.get(response_poll.Resp.url, timeout=download_timeout_s)
        vid_response.raise_for_status()
        return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseImageToVideoNode(ComfyNodeABC):
    """
    Generates videos synchronously based on prompt and output_size.
    """

    # NOTE: the class docstring doubles as the node DESCRIPTION below, so
    # implementation notes live in comments and method docstrings instead.
    RETURN_TYPES = (IO.VIDEO,)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/video/PixVerse"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec; combo options are taken from the PixVerse enums."""
        return {
            "required": {
                "image": (
                    IO.IMAGE,
                ),
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Prompt for the video generation",
                    },
                ),
                "quality": (
                    [resolution.value for resolution in PixverseQuality],
                    {
                        "default": PixverseQuality.res_540p,
                    },
                ),
                "duration_seconds": ([dur.value for dur in PixverseDuration],),
                "motion_mode": ([mode.value for mode in PixverseMotionMode],),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2147483647,
                        "control_after_generate": True,
                        "tooltip": "Seed for video generation.",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    IO.STRING,
                    {
                        "default": "",
                        "forceInput": True,
                        "tooltip": "An optional text description of undesired elements on an image.",
                    },
                ),
                "pixverse_template": (
                    PixverseIO.TEMPLATE,
                    {
                        "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
                    }
                )
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
            },
        }

    def api_call(
        self,
        image: torch.Tensor,
        prompt: str,
        quality: str,
        duration_seconds: int,
        motion_mode: str,
        seed,
        negative_prompt: str=None,
        pixverse_template: int=None,
        auth_token=None,
        **kwargs,
    ):
        """Upload the image, request an image-to-video generation, poll, download the result.

        Returns:
            A 1-tuple holding a ``VideoFromFile`` built from the downloaded bytes.

        Raises:
            Exception: when the generation response carries no ``Resp`` payload.
            requests.HTTPError: when the result URL answers with an error status.
            requests.Timeout: when the download stalls past the timeout.
        """
        validate_string(prompt, strip_whitespace=False)
        # The generation request references the uploaded image by id.
        img_id = upload_image_to_pixverse(image, auth_token=auth_token)

        # 1080p is limited to 5 seconds duration
        # only normal motion_mode supported for 1080p or for non-5 second duration
        if quality == PixverseQuality.res_1080p:
            motion_mode = PixverseMotionMode.normal
            duration_seconds = PixverseDuration.dur_5
        elif duration_seconds != PixverseDuration.dur_5:
            motion_mode = PixverseMotionMode.normal

        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/pixverse/video/img/generate",
                method=HttpMethod.POST,
                request_model=PixverseImageVideoRequest,
                response_model=PixverseVideoResponse,
            ),
            request=PixverseImageVideoRequest(
                img_id=img_id,
                prompt=prompt,
                quality=quality,
                duration=duration_seconds,
                motion_mode=motion_mode,
                # An empty negative prompt is sent as None.
                negative_prompt=negative_prompt if negative_prompt else None,
                template_id=pixverse_template,
                seed=seed,
            ),
            auth_token=auth_token,
        )
        response_api = operation.execute()

        if response_api.Resp is None:
            raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

        operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=PixverseGenerationStatusResponse,
            ),
            completed_statuses=[PixverseStatus.successful],
            failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
            status_extractor=lambda x: x.Resp.status,
            auth_token=auth_token,
        )
        response_poll = operation.execute()

        # Bound the download so a stalled connection cannot hang the node, and
        # fail loudly on an HTTP error instead of wrapping an error page as video.
        download_timeout_s = 1024
        vid_response = requests.get(response_poll.Resp.url, timeout=download_timeout_s)
        vid_response.raise_for_status()
        return (VideoFromFile(BytesIO(vid_response.content)),)
# Maps each PixVerse node identifier to the class in this module that implements it.
# Presumably read by the host application's node loader -- confirm against the loader.
NODE_CLASS_MAPPINGS = {
    "PixverseTextToVideoNode": PixverseTextToVideoNode,
    "PixverseImageToVideoNode": PixverseImageToVideoNode,
    "PixverseTransitionVideoNode": PixverseTransitionVideoNode,
    "PixverseTemplateNode": PixverseTemplateNode,
}

# Human-readable title for each identifier above; the key sets of the two
# dictionaries are identical.
NODE_DISPLAY_NAME_MAPPINGS = {
    "PixverseTextToVideoNode": "PixVerse Text to Video",
    "PixverseImageToVideoNode": "PixVerse Image to Video",
    "PixverseTransitionVideoNode": "PixVerse Transition Video",
    "PixverseTemplateNode": "PixVerse Template",
}
100644 index 00000000..994f377d --- /dev/null +++ b/comfy_api_nodes/nodes_recraft.py @@ -0,0 +1,1217 @@ +from __future__ import annotations +from inspect import cleandoc +from comfy.utils import ProgressBar +from comfy.comfy_types.node_typing import IO +from comfy_api_nodes.apis.recraft_api import ( + RecraftImageGenerationRequest, + RecraftImageGenerationResponse, + RecraftImageSize, + RecraftModel, + RecraftStyle, + RecraftStyleV3, + RecraftColor, + RecraftColorChain, + RecraftControls, + RecraftIO, + get_v3_substyles, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + bytesio_to_image_tensor, + download_url_to_bytesio, + tensor_to_bytesio, + resize_mask_to_image, + validate_string, +) +import folder_paths +import json +import os +import torch +from io import BytesIO +from PIL import UnidentifiedImageError + + +def handle_recraft_file_request( + image: torch.Tensor, + path: str, + mask: torch.Tensor=None, + total_pixels=4096*4096, + timeout=1024, + request=None, + auth_token=None + ) -> list[BytesIO]: + """ + Handle sending common Recraft file-only request to get back file bytes. 
def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:
    """
    Formats data such that multipart/form-data will work with requests library
    when both files and data are present.

    The OpenAI client that Recraft uses has a bizarre way of serializing lists:

    It does NOT keep track of indices of each list, so for background_color, that must be serialized as:
        'background_color[rgb][]' = [0, 0, 255]
    where the array is assigned to a key that has '[]' at the end, to signal it's an array.

    This has the consequence of nested lists having the exact same key, forcing arrays to merge; all colors inputs fall under the same key:
        if 1 color  -> 'controls[colors][][rgb][]' = [0, 0, 255]
        if 2 colors -> 'controls[colors][][rgb][]' = [0, 0, 255, 255, 0, 0]
        if 3 colors -> 'controls[colors][][rgb][]' = [0, 0, 255, 255, 0, 0, 0, 255, 0]
        etc.
    Whoever made this serialization up at OpenAI added the constraint that lists must be of uniform length on objects of same 'type'.

    Args:
        data: Dict to flatten at the top level; recursive calls also pass
            nested dicts and scalar list items.
        parent_key: Flattened key prefix built so far (None at the top level).
        formatter: Applied to every leaf value; defaults to the identity.
        converted_to_check: Accumulators of enclosing calls, searched for an
            existing list under the same key so values can be merged into it.
        is_list: True when ``data`` is an item of a list, so a fresh key must
            hold a list rather than a bare value.

    Returns:
        Flat dict mapping bracketed keys to formatted values or merged lists.
    """
    # Modification of a function that handled a different type of multipart parsing, big ups:
    # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b

    if converted_to_check is None:
        converted_to_check = []

    if formatter is None:
        def formatter(value):
            # Multipart representation of value: unchanged by default.
            return value

    def append_to_existing_list(value, key, lists_to_check: list[list]) -> bool:
        """Append ``value`` to the first list already stored under ``key``, if any."""
        for check_list in lists_to_check:
            for existing_key, existing_value in check_list:
                if existing_key == key and isinstance(existing_value, list):
                    existing_value.append(formatter(value))
                    return True
        return False

    if not isinstance(data, dict):
        # A list is already registered under this key: merge into it and
        # contribute nothing new to the caller.
        if append_to_existing_list(data, parent_key, converted_to_check):
            return {}
        # First item of a list: start the list.
        if is_list:
            return {parent_key: [formatter(data)]}
        # Plain scalar under its own key.
        return {parent_key: formatter(data)}

    converted = []
    # Our own accumulator is searched first, then those of the enclosing calls,
    # so same-key lists merge across sibling objects.
    next_check = [converted, *converted_to_check]

    for key, value in data.items():
        current_key = key if parent_key is None else f"{parent_key}[{key}]"
        if isinstance(value, dict):
            converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
        elif isinstance(value, list):
            # Every item shares one key: indices are deliberately not encoded.
            iter_key = f"{current_key}[]"
            for list_value in value:
                converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
        else:
            converted.append((current_key, formatter(value)))

    return dict(converted)
class SaveSVGNode:
    """
    Save SVG files on disk.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    # NOTE: the class docstring doubles as the node DESCRIPTION below, so
    # implementation notes live in comments and method docstrings instead.
    RETURN_TYPES = ()
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "save_svg"
    CATEGORY = "api node/image/Recraft"
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: the SVG payload, a filename prefix and hidden workflow data."""
        return {
            "required": {
                "svg": (RecraftIO.SVG,),
                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
            },
            "hidden": {
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO"
            }
        }

    @staticmethod
    def _inject_metadata(svg_content: str, metadata_json: str) -> str:
        """Return *svg_content* with a <metadata> CDATA block after the opening <svg> tag.

        The input is returned unchanged when it contains no ``<svg ...>`` tag.
        """
        import re

        # "]]>" inside the JSON would close the CDATA section early; split it
        # across two adjacent sections so the text round-trips.
        safe_json = metadata_json.replace("]]>", "]]]]><![CDATA[>")
        metadata_element = (
            "  <metadata>\n"
            "    <![CDATA[\n"
            f"{safe_json}\n"
            "    ]]>\n"
            "  </metadata>\n"
        )
        # A callable replacement keeps backslashes in the JSON literal; in a
        # template string they would be parsed as escapes or group references.
        # count=1: only the outermost (first) <svg> tag gets the metadata.
        return re.sub(
            r"(<svg[^>]*>)",
            lambda match: f"{match.group(1)}\n{metadata_element}",
            svg_content,
            count=1,
        )

    def save_svg(self, svg: "SVG", filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
        """Write every SVG in ``svg.data`` to the output folder, embedding workflow metadata.

        Args:
            svg: Object whose ``data`` is a sequence of seekable byte streams
                holding UTF-8 SVG text.
            filename_prefix: Prefix handed to ``folder_paths.get_save_image_path``.
            prompt: Stored under the ``"prompt"`` metadata key when not None.
            extra_pnginfo: Mapping merged into the metadata when not None.

        Returns:
            ``{"ui": {"images": [...]}}`` with one filename/subfolder/type entry
            per written file.
        """
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        results = []

        # Prepare metadata JSON
        metadata_dict = {}
        if prompt is not None:
            metadata_dict["prompt"] = prompt
        if extra_pnginfo is not None:
            metadata_dict.update(extra_pnginfo)

        # Convert metadata to JSON string (None when there is nothing to embed)
        metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None

        for batch_number, svg_bytes in enumerate(svg.data):
            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.svg"

            # Read SVG content
            svg_bytes.seek(0)
            svg_content = svg_bytes.read().decode('utf-8')

            # Inject metadata if available
            if metadata_json:
                svg_content = self._inject_metadata(svg_content, metadata_json)

            # Write the modified SVG to file
            with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
                svg_file.write(svg_content.encode('utf-8'))

            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })
            counter += 1
        return { "ui": { "images": results } }
class RecraftControlsNode:
    """
    Create Recraft Controls for customizing Recraft generation.
    """

    # The class docstring above is also what DESCRIPTION is computed from.
    RETURN_TYPES = (RecraftIO.CONTROLS,)
    RETURN_NAMES = ("recraft_controls",)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "create_controls"
    CATEGORY = "api node/image/Recraft"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input spec: no required inputs, two optional color inputs."""
        optional_inputs = {
            input_name: (RecraftIO.COLOR,)
            for input_name in ("colors", "background_color")
        }
        return {"required": {}, "optional": optional_inputs}

    def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None):
        """Wrap the given color chains in a ``RecraftControls`` and return it as a 1-tuple."""
        controls = RecraftControls(colors=colors, background_color=background_color)
        return (controls,)
class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
    """
    Select digital_illustration style and optional substyle.
    """

    # Recomputed per subclass: the inherited DESCRIPTION was built from the
    # base class docstring and would describe the realistic_image style.
    DESCRIPTION = cleandoc(__doc__ or "")
    RECRAFT_STYLE = RecraftStyleV3.digital_illustration


class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
    """
    Select vector_illustration style and optional substyle.
    """

    # Recomputed per subclass: the inherited DESCRIPTION was built from the
    # base class docstring and would describe the realistic_image style.
    DESCRIPTION = cleandoc(__doc__ or "")
    RECRAFT_STYLE = RecraftStyleV3.vector_illustration


class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
    """
    Select logo_raster style and substyle.
    """

    # Recomputed per subclass: the inherited DESCRIPTION was built from the
    # base class docstring and would describe the realistic_image style.
    DESCRIPTION = cleandoc(__doc__ or "")

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec; unlike the base class, "None" is not offered as a substyle."""
        return {
            "required": {
                "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),),
            }
        }

    RECRAFT_STYLE = RecraftStyleV3.logo_raster
class RecraftTextToImageNode:
    """
    Generates images synchronously based on prompt and resolution.
    """

    # The class docstring above is also what DESCRIPTION is computed from.
    RETURN_TYPES = (IO.IMAGE,)
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/image/Recraft"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input spec: prompt, size, count and seed, plus optional style inputs."""
        required_inputs = {
            "prompt": (
                IO.STRING,
                {
                    "multiline": True,
                    "default": "",
                    "tooltip": "Prompt for the image generation.",
                },
            ),
            "size": (
                [res.value for res in RecraftImageSize],
                {
                    "default": RecraftImageSize.res_1024x1024,
                    "tooltip": "The size of the generated image.",
                },
            ),
            "n": (
                IO.INT,
                {
                    "default": 1,
                    "min": 1,
                    "max": 6,
                    "tooltip": "The number of images to generate.",
                },
            ),
            "seed": (
                IO.INT,
                {
                    "default": 0,
                    "min": 0,
                    "max": 0xFFFFFFFFFFFFFFFF,
                    "control_after_generate": True,
                    "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
                },
            ),
        }
        optional_inputs = {
            "recraft_style": (RecraftIO.STYLEV3,),
            "negative_prompt": (
                IO.STRING,
                {
                    "default": "",
                    "forceInput": True,
                    "tooltip": "An optional text description of undesired elements on an image.",
                },
            ),
            "recraft_controls": (
                RecraftIO.CONTROLS,
                {
                    "tooltip": "Optional additional controls over the generation via the Recraft Controls node."
                },
            ),
        }
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        prompt: str,
        size: str,
        n: int,
        seed,
        recraft_style: RecraftStyle = None,
        negative_prompt: str = None,
        recraft_controls: RecraftControls = None,
        auth_token=None,
        **kwargs,
    ):
        """Request ``n`` images for the prompt and return them as one batched tensor.

        Each returned URL is downloaded and decoded; a result with fewer than
        four dimensions gets a leading batch dimension before concatenation.
        ``seed`` is not part of the request that is sent.
        """
        validate_string(prompt, strip_whitespace=False, max_length=1000)

        # Fall back to the realistic_image style when none is connected.
        fallback_style = RecraftStyle(RecraftStyleV3.realistic_image)
        style = fallback_style if recraft_style is None else recraft_style

        controls_api = recraft_controls.create_api_model() if recraft_controls else None

        # An empty negative prompt is sent as None.
        negative = negative_prompt if negative_prompt else None

        endpoint = ApiEndpoint(
            path="/proxy/recraft/image_generation",
            method=HttpMethod.POST,
            request_model=RecraftImageGenerationRequest,
            response_model=RecraftImageGenerationResponse,
        )
        generation_request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative,
            model=RecraftModel.recraftv3,
            size=size,
            n=n,
            style=style.style,
            substyle=style.substyle,
            style_id=style.style_id,
            controls=controls_api,
        )
        operation = SynchronousOperation(
            endpoint=endpoint,
            request=generation_request,
            auth_token=auth_token,
        )
        response: RecraftImageGenerationResponse = operation.execute()

        batches = []
        for item in response.data:
            # The context manager turns an undecodable (likely SVG) payload
            # into a descriptive exception.
            with handle_recraft_image_output():
                downloaded = download_url_to_bytesio(item.url, timeout=1024)
                decoded = bytesio_to_image_tensor(downloaded)
            if len(decoded.shape) < 4:
                decoded = decoded.unsqueeze(0)
            batches.append(decoded)

        return (torch.cat(batches, dim=0),)
+ """ + + RETURN_TYPES = (IO.IMAGE,) + DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value + FUNCTION = "api_call" + API_NODE = True + CATEGORY = "api node/image/Recraft" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": (IO.IMAGE, ), + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Prompt for the image generation.", + }, + ), + "n": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 6, + "tooltip": "The number of images to generate.", + }, + ), + "strength": ( + IO.FLOAT, + { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." + } + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", + }, + ), + }, + "optional": { + "recraft_style": (RecraftIO.STYLEV3,), + "negative_prompt": ( + IO.STRING, + { + "default": "", + "forceInput": True, + "tooltip": "An optional text description of undesired elements on an image.", + }, + ), + "recraft_controls": ( + RecraftIO.CONTROLS, + { + "tooltip": "Optional additional controls over the generation via the Recraft Controls node." 
class RecraftImageInpaintingNode:
    """
    Modify image based on prompt and mask.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        prompt_spec = (
            IO.STRING,
            {
                "multiline": True,
                "default": "",
                "tooltip": "Prompt for the image generation.",
            },
        )
        count_spec = (
            IO.INT,
            {
                "default": 1,
                "min": 1,
                "max": 6,
                "tooltip": "The number of images to generate.",
            },
        )
        seed_spec = (
            IO.INT,
            {
                "default": 0,
                "min": 0,
                "max": 0xFFFFFFFFFFFFFFFF,
                "control_after_generate": True,
                "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
            },
        )
        negative_prompt_spec = (
            IO.STRING,
            {
                "default": "",
                "forceInput": True,
                "tooltip": "An optional text description of undesired elements on an image.",
            },
        )

        required_inputs = {
            "image": (IO.IMAGE,),
            "mask": (IO.MASK,),
            "prompt": prompt_spec,
            "n": count_spec,
            "seed": seed_spec,
        }
        optional_inputs = {
            "recraft_style": (RecraftIO.STYLEV3,),
            "negative_prompt": negative_prompt_spec,
        }
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        prompt: str,
        n: int,
        seed,
        auth_token=None,
        recraft_style: RecraftStyle = None,
        negative_prompt: str = None,
        **kwargs,
    ):
        """Run Recraft inpainting on every entry of ``image`` with its matching mask slice.

        Args:
            image: Input batch; iterated along its first dimension, one request per entry.
            mask: Passed through ``resize_mask_to_image`` against ``image`` before use;
                slice ``mask[i:i+1]`` accompanies ``image[i]``.
            prompt: Generation prompt; checked by ``validate_string`` with ``max_length=1000``.
            n: Number of images requested per input entry.
            seed: Not referenced in the body (the input tooltip says it only controls re-runs).
            auth_token: Forwarded to the request helper.
            recraft_style: Style to apply; a realistic-image style is used when ``None``.
            negative_prompt: Sent as ``None`` when empty or missing.

        Returns:
            A 1-tuple holding all decoded images concatenated along dimension 0.
        """
        validate_string(prompt, strip_whitespace=False, max_length=1000)

        fallback_style = RecraftStyle(RecraftStyleV3.realistic_image)
        style = fallback_style if recraft_style is None else recraft_style

        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,  # empty string -> omitted
            model=RecraftModel.recraftv3,
            n=n,
            style=style.style,
            substyle=style.substyle,
            style_id=style.style_id,
        )

        # prepare mask tensor
        mask = resize_mask_to_image(mask, image, allow_gradient=False, add_channel_dim=True)

        batch_size = image.shape[0]
        progress = ProgressBar(batch_size)
        results = []
        for idx, frame in enumerate(image):
            payloads = handle_recraft_file_request(
                image=frame,
                mask=mask[idx:idx + 1],  # keep the leading dimension on the mask slice
                path="/proxy/recraft/images/inpaint",
                request=request,
                auth_token=auth_token,
            )
            with handle_recraft_image_output():
                decoded = [bytesio_to_image_tensor(payload) for payload in payloads]
                results.append(torch.cat(decoded, dim=0))
            progress.update(1)

        return (torch.cat(results, dim=0),)
class RecraftTextToVectorNode:
    """
    Generates SVG synchronously based on prompt and resolution.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (RecraftIO.SVG,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        prompt_spec = (
            IO.STRING,
            {
                "multiline": True,
                "default": "",
                "tooltip": "Prompt for the image generation.",
            },
        )
        size_spec = (
            [res.value for res in RecraftImageSize],
            {
                "default": RecraftImageSize.res_1024x1024,
                "tooltip": "The size of the generated image.",
            },
        )
        count_spec = (
            IO.INT,
            {
                "default": 1,
                "min": 1,
                "max": 6,
                "tooltip": "The number of images to generate.",
            },
        )
        seed_spec = (
            IO.INT,
            {
                "default": 0,
                "min": 0,
                "max": 0xFFFFFFFFFFFFFFFF,
                "control_after_generate": True,
                "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
            },
        )
        negative_prompt_spec = (
            IO.STRING,
            {
                "default": "",
                "forceInput": True,
                "tooltip": "An optional text description of undesired elements on an image.",
            },
        )
        controls_spec = (
            RecraftIO.CONTROLS,
            {
                "tooltip": "Optional additional controls over the generation via the Recraft Controls node."
            },
        )

        required_inputs = {
            "prompt": prompt_spec,
            "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),),
            "size": size_spec,
            "n": count_spec,
            "seed": seed_spec,
        }
        optional_inputs = {
            "negative_prompt": negative_prompt_spec,
            "recraft_controls": controls_spec,
        }
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        prompt: str,
        substyle: str,
        size: str,
        n: int,
        seed,
        negative_prompt: str = None,
        recraft_controls: RecraftControls = None,
        auth_token=None,
        **kwargs,
    ):
        """Request vector-illustration images from Recraft and download each result.

        Args:
            prompt: Generation prompt; checked by ``validate_string`` with ``max_length=1000``.
            substyle: Substyle name, wrapped in a ``RecraftStyle`` before being sent.
            size: Requested image size value.
            n: Number of images to generate.
            seed: Not referenced in the body (the input tooltip says it only controls re-runs).
            negative_prompt: Sent as ``None`` when empty or missing.
            recraft_controls: When truthy, converted with ``create_api_model()`` and sent.
            auth_token: Forwarded to the API operation.

        Returns:
            A 1-tuple holding an ``SVG`` built from the downloaded byte streams.
        """
        validate_string(prompt, strip_whitespace=False, max_length=1000)

        # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None)
        vector_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle)
        api_controls = recraft_controls.create_api_model() if recraft_controls else None

        endpoint = ApiEndpoint(
            path="/proxy/recraft/image_generation",
            method=HttpMethod.POST,
            request_model=RecraftImageGenerationRequest,
            response_model=RecraftImageGenerationResponse,
        )
        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,  # empty string -> omitted
            model=RecraftModel.recraftv3,
            size=size,
            n=n,
            style=vector_style.style,
            substyle=vector_style.substyle,
            controls=api_controls,
        )
        response: RecraftImageGenerationResponse = SynchronousOperation(
            endpoint=endpoint,
            request=request,
            auth_token=auth_token,
        ).execute()

        downloads = [
            download_url_to_bytesio(item.url, timeout=1024) for item in response.data
        ]
        return (SVG(downloads),)
class RecraftVectorizeImageNode:
    """
    Generates SVG synchronously from an input image.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (RecraftIO.SVG,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required image plus the hidden auth token."""
        required_inputs = {"image": (IO.IMAGE,)}
        optional_inputs = {}
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        image: torch.Tensor,
        auth_token=None,
        **kwargs,
    ):
        """Vectorize every entry of ``image`` through Recraft and merge the SVG results.

        Args:
            image: Input batch; iterated along its first dimension, one request per entry.
            auth_token: Forwarded to the request helper.

        Returns:
            A 1-tuple holding the result of ``SVG.combine_all`` over the per-entry SVGs.
        """
        batch_size = image.shape[0]
        progress = ProgressBar(batch_size)
        per_image_svgs = []
        for frame in image:
            payloads = handle_recraft_file_request(
                image=frame,
                path="/proxy/recraft/images/vectorize",
                auth_token=auth_token,
            )
            per_image_svgs.append(SVG(payloads))
            progress.update(1)

        return (SVG.combine_all(per_image_svgs),)
class RecraftReplaceBackgroundNode:
    """
    Replace background on image, based on provided prompt.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        prompt_spec = (
            IO.STRING,
            {
                "multiline": True,
                "default": "",
                "tooltip": "Prompt for the image generation.",
            },
        )
        count_spec = (
            IO.INT,
            {
                "default": 1,
                "min": 1,
                "max": 6,
                "tooltip": "The number of images to generate.",
            },
        )
        seed_spec = (
            IO.INT,
            {
                "default": 0,
                "min": 0,
                "max": 0xFFFFFFFFFFFFFFFF,
                "control_after_generate": True,
                "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
            },
        )
        negative_prompt_spec = (
            IO.STRING,
            {
                "default": "",
                "forceInput": True,
                "tooltip": "An optional text description of undesired elements on an image.",
            },
        )

        required_inputs = {
            "image": (IO.IMAGE,),
            "prompt": prompt_spec,
            "n": count_spec,
            "seed": seed_spec,
        }
        optional_inputs = {
            "recraft_style": (RecraftIO.STYLEV3,),
            "negative_prompt": negative_prompt_spec,
        }
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        image: torch.Tensor,
        prompt: str,
        n: int,
        seed,
        auth_token=None,
        recraft_style: RecraftStyle = None,
        negative_prompt: str = None,
        **kwargs,
    ):
        """Replace the background of every entry of ``image`` and concatenate the results.

        Args:
            image: Input batch; iterated along its first dimension, one request per entry.
            prompt: Generation prompt, sent as given.
            n: Number of images requested per input entry.
            seed: Not referenced in the body (the input tooltip says it only controls re-runs).
            auth_token: Forwarded to the request helper.
            recraft_style: Style to apply; a realistic-image style is used when ``None``.
            negative_prompt: Sent as ``None`` when empty or missing.

        Returns:
            A 1-tuple holding all decoded images concatenated along dimension 0.
        """
        # NOTE(review): unlike the image-to-image and inpainting nodes in this file,
        # the prompt is not passed through validate_string() and decoding is not
        # wrapped in handle_recraft_image_output() - confirm whether that is intended.
        fallback_style = RecraftStyle(RecraftStyleV3.realistic_image)
        style = fallback_style if recraft_style is None else recraft_style

        request = RecraftImageGenerationRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,  # empty string -> omitted
            model=RecraftModel.recraftv3,
            n=n,
            style=style.style,
            substyle=style.substyle,
            style_id=style.style_id,
        )

        batch_size = image.shape[0]
        progress = ProgressBar(batch_size)
        results = []
        for frame in image:
            payloads = handle_recraft_file_request(
                image=frame,
                path="/proxy/recraft/images/replaceBackground",
                request=request,
                auth_token=auth_token,
            )
            decoded = [bytesio_to_image_tensor(payload) for payload in payloads]
            results.append(torch.cat(decoded, dim=0))
            progress.update(1)

        return (torch.cat(results, dim=0),)
class RecraftRemoveBackgroundNode:
    """
    Remove background from image, and return processed image and mask.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE, IO.MASK)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required image plus the hidden auth token."""
        required_inputs = {"image": (IO.IMAGE,)}
        optional_inputs = {}
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        image: torch.Tensor,
        auth_token=None,
        **kwargs,
    ):
        """Remove the background of every entry of ``image``.

        Args:
            image: Input batch; iterated along its first dimension, one request per entry.
            auth_token: Forwarded to the request helper.

        Returns:
            A 2-tuple ``(images, masks)``: the decoded images concatenated along
            dimension 0, and their last channel with the channel axis dropped.
        """
        batch_size = image.shape[0]
        progress = ProgressBar(batch_size)
        results = []
        for frame in image:
            payloads = handle_recraft_file_request(
                image=frame,
                path="/proxy/recraft/images/removeBackground",
                auth_token=auth_token,
            )
            decoded = [bytesio_to_image_tensor(payload) for payload in payloads]
            results.append(torch.cat(decoded, dim=0))
            progress.update(1)

        images_tensor = torch.cat(results, dim=0)
        # use alpha channel as masks, in B,H,W format
        alpha = images_tensor[:, :, :, -1:]
        masks_tensor = alpha.squeeze(-1)
        return (images_tensor, masks_tensor)
class RecraftCrispUpscaleNode:
    """
    Upscale image synchronously.
    Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Recraft"
    API_NODE = True

    # Endpoint used by api_call; read through ``self`` so a subclass can override it.
    RECRAFT_PATH = "/proxy/recraft/images/crispUpscale"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required image plus the hidden auth token."""
        required_inputs = {"image": (IO.IMAGE,)}
        optional_inputs = {}
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(
        self,
        image: torch.Tensor,
        auth_token=None,
        **kwargs,
    ):
        """Send every entry of ``image`` to ``self.RECRAFT_PATH`` and concatenate the results.

        Args:
            image: Input batch; iterated along its first dimension, one request per entry.
            auth_token: Forwarded to the request helper.

        Returns:
            A 1-tuple holding all decoded images concatenated along dimension 0.
        """
        batch_size = image.shape[0]
        progress = ProgressBar(batch_size)
        results = []
        for frame in image:
            payloads = handle_recraft_file_request(
                image=frame,
                path=self.RECRAFT_PATH,
                auth_token=auth_token,
            )
            decoded = [bytesio_to_image_tensor(payload) for payload in payloads]
            results.append(torch.cat(decoded, dim=0))
            progress.update(1)

        return (torch.cat(results, dim=0),)
class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
    """
    Upscale image synchronously.
    Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces.
    """

    # Only the endpoint and the description differ from the crisp-upscale parent;
    # the inherited api_call reads RECRAFT_PATH through ``self``.
    RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale"

    # Recomputed here so the description comes from this class's own docstring.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value

    # Node metadata, restated explicitly with the same values as the parent.
    CATEGORY = "api node/image/Recraft"
    API_NODE = True
    FUNCTION = "api_call"
    RETURN_TYPES = (IO.IMAGE,)
# One row per exported node: (registration name, node class, human-readable title).
# NOTE: names should be globally unique
_RECRAFT_NODE_TABLE = (
    ("RecraftTextToImageNode", RecraftTextToImageNode, "Recraft Text to Image"),
    ("RecraftImageToImageNode", RecraftImageToImageNode, "Recraft Image to Image"),
    ("RecraftImageInpaintingNode", RecraftImageInpaintingNode, "Recraft Image Inpainting"),
    ("RecraftTextToVectorNode", RecraftTextToVectorNode, "Recraft Text to Vector"),
    ("RecraftVectorizeImageNode", RecraftVectorizeImageNode, "Recraft Vectorize Image"),
    ("RecraftRemoveBackgroundNode", RecraftRemoveBackgroundNode, "Recraft Remove Background"),
    ("RecraftReplaceBackgroundNode", RecraftReplaceBackgroundNode, "Recraft Replace Background"),
    ("RecraftCrispUpscaleNode", RecraftCrispUpscaleNode, "Recraft Crisp Upscale Image"),
    ("RecraftCreativeUpscaleNode", RecraftCreativeUpscaleNode, "Recraft Creative Upscale Image"),
    ("RecraftStyleV3RealisticImage", RecraftStyleV3RealisticImageNode, "Recraft Style - Realistic Image"),
    ("RecraftStyleV3DigitalIllustration", RecraftStyleV3DigitalIllustrationNode, "Recraft Style - Digital Illustration"),
    ("RecraftStyleV3LogoRaster", RecraftStyleV3LogoRasterNode, "Recraft Style - Logo Raster"),
    ("RecraftStyleV3InfiniteStyleLibrary", RecraftStyleInfiniteStyleLibrary, "Recraft Style - Infinite Style Library"),
    ("RecraftColorRGB", RecraftColorRGBNode, "Recraft Color RGB"),
    ("RecraftControls", RecraftControlsNode, "Recraft Controls"),
    ("SaveSVG", SaveSVGNode, "Save SVG"),
)

# A dictionary that contains all nodes you want to export with their names
NODE_CLASS_MAPPINGS = {
    node_name: node_class for node_name, node_class, _title in _RECRAFT_NODE_TABLE
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    node_name: title for node_name, _node_class, title in _RECRAFT_NODE_TABLE
}
- Realistic Image", + "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", + "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", + "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", + "RecraftColorRGB": "Recraft Color RGB", + "RecraftControls": "Recraft Controls", + "SaveSVG": "Save SVG", +} diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py new file mode 100644 index 00000000..52fe2417 --- /dev/null +++ b/comfy_api_nodes/nodes_stability.py @@ -0,0 +1,609 @@ +from inspect import cleandoc +from comfy.comfy_types.node_typing import IO +from comfy_api_nodes.apis.stability_api import ( + StabilityUpscaleConservativeRequest, + StabilityUpscaleCreativeRequest, + StabilityAsyncResponse, + StabilityResultsGetResponse, + StabilityStable3_5Request, + StabilityStableUltraRequest, + StabilityStableUltraResponse, + StabilityAspectRatio, + Stability_SD3_5_Model, + Stability_SD3_5_GenerationMode, + get_stability_style_presets, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import ( + bytesio_to_image_tensor, + tensor_to_bytesio, + validate_string, +) + +import torch +import base64 +from io import BytesIO +from enum import Enum + + +class StabilityPollStatus(str, Enum): + finished = "finished" + in_progress = "in_progress" + failed = "failed" + + +def get_async_dummy_status(x: StabilityResultsGetResponse): + if x.name is not None or x.errors is not None: + return StabilityPollStatus.failed + elif x.finish_reason is not None: + return StabilityPollStatus.finished + return StabilityPollStatus.in_progress + + +class StabilityStableImageUltraNode: + """ + Generates images synchronously based on prompt and resolution. 
class StabilityStableImageUltraNode:
    """
    Generates images synchronously based on prompt and resolution.
    """

    RETURN_TYPES = (IO.IMAGE,)
    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/image/Stability AI"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        # Adjacent literals are joined by the parser, so every fragment
                        # except the last must end with a space. The previous "+" chain
                        # repeated the first sentence and ran words together.
                        "tooltip": (
                            "What you wish to see in the output image. A strong, descriptive prompt that clearly defines "
                            "elements, colors, and subjects will lead to better results. "
                            "To control the weight of a given word use the format `(word:weight)`, "
                            "where `word` is the word you'd like to control the weight of and `weight` "
                            "is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)` "
                            "would convey a sky that was blue and green, but more green than blue."
                        ),
                    },
                ),
                "aspect_ratio": ([x.value for x in StabilityAspectRatio],
                    {
                        "default": StabilityAspectRatio.ratio_1_1,
                        "tooltip": "Aspect ratio of generated image.",
                    },
                ),
                "style_preset": (get_stability_style_presets(),
                    {
                        "tooltip": "Optional desired style of generated image.",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4294967294,
                        "control_after_generate": True,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
            },
            "optional": {
                "image": (IO.IMAGE,),
                "negative_prompt": (
                    IO.STRING,
                    {
                        "default": "",
                        "forceInput": True,
                        "tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
                    },
                ),
                "image_denoise": (
                    IO.FLOAT,
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
            },
        }

    def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
                 negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
                 auth_token=None):
        """Generate one image through the Stable Image Ultra endpoint.

        Args:
            prompt: Generation prompt; checked by ``validate_string``.
            aspect_ratio: Aspect ratio value forwarded to the API.
            style_preset: Style name; the literal string ``"None"`` is sent as ``None``.
            seed: Seed forwarded to the API.
            negative_prompt: Sent as ``None`` when empty or missing.
            image: Optional input image, uploaded as the ``image`` file part.
            image_denoise: Sent as ``strength``; dropped when no image is given.
            auth_token: Forwarded to the API operation.

        Returns:
            A 1-tuple holding the decoded result image tensor.

        Raises:
            Exception: If the response's ``finish_reason`` is not ``"SUCCESS"``.
        """
        validate_string(prompt, strip_whitespace=False)
        # prepare image binary if image present
        image_binary = None
        if image is not None:
            image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
        else:
            # strength only makes sense together with an input image
            image_denoise = None

        if not negative_prompt:
            negative_prompt = None
        if style_preset == "None":
            style_preset = None

        files = {
            "image": image_binary
        }

        operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/stability/v2beta/stable-image/generate/ultra",
                method=HttpMethod.POST,
                request_model=StabilityStableUltraRequest,
                response_model=StabilityStableUltraResponse,
            ),
            request=StabilityStableUltraRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                aspect_ratio=aspect_ratio,
                seed=seed,
                strength=image_denoise,
                style_preset=style_preset,
            ),
            files=files,
            content_type="multipart/form-data",
            auth_token=auth_token,
        )
        response_api = operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")

        # the image arrives base64-encoded in the response body
        image_data = base64.b64decode(response_api.image)
        returned_image = bytesio_to_image_tensor(BytesIO(image_data))

        return (returned_image,)
class StabilityStableImageSD_3_5Node:
    """
    Generates images synchronously based on prompt and resolution.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Stability AI"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        prompt_spec = (
            IO.STRING,
            {
                "multiline": True,
                "default": "",
                "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
            },
        )
        aspect_ratio_spec = (
            [x.value for x in StabilityAspectRatio],
            {
                "default": StabilityAspectRatio.ratio_1_1,
                "tooltip": "Aspect ratio of generated image.",
            },
        )
        style_preset_spec = (
            get_stability_style_presets(),
            {
                "tooltip": "Optional desired style of generated image.",
            },
        )
        cfg_scale_spec = (
            IO.FLOAT,
            {
                "default": 4.0,
                "min": 1.0,
                "max": 10.0,
                "step": 0.1,
                "tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
            },
        )
        seed_spec = (
            IO.INT,
            {
                "default": 0,
                "min": 0,
                "max": 4294967294,
                "control_after_generate": True,
                "tooltip": "The random seed used for creating the noise.",
            },
        )
        negative_prompt_spec = (
            IO.STRING,
            {
                "default": "",
                "forceInput": True,
                "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
            },
        )
        image_denoise_spec = (
            IO.FLOAT,
            {
                "default": 0.5,
                "min": 0.0,
                "max": 1.0,
                "step": 0.01,
                "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
            },
        )

        required_inputs = {
            "prompt": prompt_spec,
            "model": ([x.value for x in Stability_SD3_5_Model],),
            "aspect_ratio": aspect_ratio_spec,
            "style_preset": style_preset_spec,
            "cfg_scale": cfg_scale_spec,
            "seed": seed_spec,
        }
        optional_inputs = {
            "image": (IO.IMAGE,),
            "negative_prompt": negative_prompt_spec,
            "image_denoise": image_denoise_spec,
        }
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
                 negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
                 auth_token=None):
        """Generate one image through the Stable Diffusion 3.5 endpoint.

        Without ``image`` the request uses text-to-image mode and drops
        ``image_denoise``; with ``image`` it uses image-to-image mode and drops
        ``aspect_ratio``.

        Args:
            model: Model value forwarded to the API.
            prompt: Generation prompt; checked by ``validate_string``.
            aspect_ratio: Aspect ratio value (text-to-image only).
            style_preset: Style name; the literal string ``"None"`` is sent as ``None``.
            seed: Seed forwarded to the API.
            cfg_scale: Guidance scale forwarded to the API.
            negative_prompt: Sent as ``None`` when empty or missing.
            image: Optional input image, uploaded as the ``image`` file part.
            image_denoise: Sent as ``strength`` (image-to-image only).
            auth_token: Forwarded to the API operation.

        Returns:
            A 1-tuple holding the decoded result image tensor.

        Raises:
            Exception: If the response's ``finish_reason`` is not ``"SUCCESS"``.
        """
        validate_string(prompt, strip_whitespace=False)

        # prepare image binary if image present
        if image is None:
            upload = None
            mode = Stability_SD3_5_GenerationMode.text_to_image
            image_denoise = None
        else:
            upload = tensor_to_bytesio(image, total_pixels=1504 * 1504).read()
            mode = Stability_SD3_5_GenerationMode.image_to_image
            aspect_ratio = None

        negative_prompt = negative_prompt or None
        style_preset = None if style_preset == "None" else style_preset

        endpoint = ApiEndpoint(
            path="/proxy/stability/v2beta/stable-image/generate/sd3",
            method=HttpMethod.POST,
            request_model=StabilityStable3_5Request,
            response_model=StabilityStableUltraResponse,
        )
        request = StabilityStable3_5Request(
            prompt=prompt,
            negative_prompt=negative_prompt,
            aspect_ratio=aspect_ratio,
            seed=seed,
            strength=image_denoise,
            style_preset=style_preset,
            cfg_scale=cfg_scale,
            model=model,
            mode=mode,
        )
        result = SynchronousOperation(
            endpoint=endpoint,
            request=request,
            files={"image": upload},
            content_type="multipart/form-data",
            auth_token=auth_token,
        ).execute()

        if result.finish_reason != "SUCCESS":
            raise Exception(f"Stable Diffusion 3.5 Image generation failed: {result.finish_reason}.")

        # the image arrives base64-encoded in the response body
        decoded_bytes = base64.b64decode(result.image)
        return (bytesio_to_image_tensor(BytesIO(decoded_bytes)),)
class StabilityUpscaleConservativeNode:
    """
    Upscale image with minimal alterations to 4K resolution.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Stability AI"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        prompt_spec = (
            IO.STRING,
            {
                "multiline": True,
                "default": "",
                "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
            },
        )
        creativity_spec = (
            IO.FLOAT,
            {
                "default": 0.35,
                "min": 0.2,
                "max": 0.5,
                "step": 0.01,
                "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
            },
        )
        seed_spec = (
            IO.INT,
            {
                "default": 0,
                "min": 0,
                "max": 4294967294,
                "control_after_generate": True,
                "tooltip": "The random seed used for creating the noise.",
            },
        )
        negative_prompt_spec = (
            IO.STRING,
            {
                "default": "",
                "forceInput": True,
                "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
            },
        )

        required_inputs = {
            "image": (IO.IMAGE,),
            "prompt": prompt_spec,
            "creativity": creativity_spec,
            "seed": seed_spec,
        }
        optional_inputs = {"negative_prompt": negative_prompt_spec}
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
                 auth_token=None):
        """Upscale ``image`` through the conservative upscale endpoint.

        Args:
            image: Input image, uploaded as the ``image`` file part.
            prompt: Guidance prompt; checked by ``validate_string``.
            creativity: Rounded to two decimals before being sent.
            seed: Seed forwarded to the API.
            negative_prompt: Sent as ``None`` when empty or missing.
            auth_token: Forwarded to the API operation.

        Returns:
            A 1-tuple holding the decoded result image tensor.

        Raises:
            Exception: If the response's ``finish_reason`` is not ``"SUCCESS"``.
        """
        validate_string(prompt, strip_whitespace=False)
        upload = tensor_to_bytesio(image, total_pixels=1024 * 1024).read()

        endpoint = ApiEndpoint(
            path="/proxy/stability/v2beta/stable-image/upscale/conservative",
            method=HttpMethod.POST,
            request_model=StabilityUpscaleConservativeRequest,
            response_model=StabilityStableUltraResponse,
        )
        request = StabilityUpscaleConservativeRequest(
            prompt=prompt,
            negative_prompt=negative_prompt or None,  # empty string -> omitted
            creativity=round(creativity, 2),
            seed=seed,
        )
        result = SynchronousOperation(
            endpoint=endpoint,
            request=request,
            files={"image": upload},
            content_type="multipart/form-data",
            auth_token=auth_token,
        ).execute()

        if result.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Conservative generation failed: {result.finish_reason}.")

        # the image arrives base64-encoded in the response body
        decoded_bytes = base64.b64decode(result.image)
        return (bytesio_to_image_tensor(BytesIO(decoded_bytes)),)
class StabilityUpscaleCreativeNode:
    """
    Upscale image to 4K resolution, creatively re-imagining details (amount controlled by creativity).
    """

    RETURN_TYPES = (IO.IMAGE,)
    # The class docstring above is reused verbatim as the node description, so it
    # must describe the creative upscaler rather than the conservative one.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    FUNCTION = "api_call"
    API_NODE = True
    CATEGORY = "api node/image/Stability AI"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required / optional / hidden input specification of the node."""
        return {
            "required": {
                "image": (IO.IMAGE,),
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
                    },
                ),
                "creativity": (
                    IO.FLOAT,
                    {
                        "default": 0.3,
                        "min": 0.1,
                        "max": 0.5,
                        "step": 0.01,
                        "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
                    },
                ),
                "style_preset": (get_stability_style_presets(),
                    {
                        "tooltip": "Optional desired style of generated image.",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4294967294,
                        "control_after_generate": True,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    IO.STRING,
                    {
                        "default": "",
                        "forceInput": True,
                        "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
            },
        }

    def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
                 auth_token=None):
        """Start a creative upscale job, poll until it finishes, and decode the result.

        Args:
            image: Input image, uploaded as the ``image`` file part.
            prompt: Guidance prompt; checked by ``validate_string``.
            creativity: Rounded to two decimals before being sent.
            style_preset: Style name; the literal string ``"None"`` is sent as ``None``.
            seed: Seed forwarded to the API.
            negative_prompt: Sent as ``None`` when empty or missing.
            auth_token: Forwarded to both the submit and the polling operation.

        Returns:
            A 1-tuple holding the decoded result image tensor.

        Raises:
            Exception: If the polled result's ``finish_reason`` is not ``"SUCCESS"``.
        """
        validate_string(prompt, strip_whitespace=False)
        image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()

        if not negative_prompt:
            negative_prompt = None
        if style_preset == "None":
            style_preset = None

        files = {
            "image": image_binary
        }

        # Step 1: submit the job; the response carries the id used for polling.
        submit_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/stability/v2beta/stable-image/upscale/creative",
                method=HttpMethod.POST,
                request_model=StabilityUpscaleCreativeRequest,
                response_model=StabilityAsyncResponse,
            ),
            request=StabilityUpscaleCreativeRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                creativity=round(creativity,2),
                style_preset=style_preset,
                seed=seed,
            ),
            files=files,
            content_type="multipart/form-data",
            auth_token=auth_token,
        )
        response_api = submit_operation.execute()

        # Step 2: poll the results endpoint until the job finishes or fails.
        poll_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path=f"/proxy/stability/v2beta/results/{response_api.id}",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=StabilityResultsGetResponse,
            ),
            poll_interval=3,
            completed_statuses=[StabilityPollStatus.finished],
            failed_statuses=[StabilityPollStatus.failed],
            # the function is passed directly; the lambda wrapper added nothing
            status_extractor=get_async_dummy_status,
            auth_token=auth_token,
        )
        response_poll: StabilityResultsGetResponse = poll_operation.execute()

        if response_poll.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")

        # the image arrives base64-encoded in the ``result`` field
        image_data = base64.b64decode(response_poll.result)
        returned_image = bytesio_to_image_tensor(BytesIO(image_data))

        return (returned_image,)
class StabilityUpscaleFastNode:
    """
    Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
    """

    # The class docstring above is reused verbatim as the node description.
    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"
    CATEGORY = "api node/image/Stability AI"
    API_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required image plus the hidden auth token."""
        required_inputs = {"image": (IO.IMAGE,)}
        optional_inputs = {}
        hidden_inputs = {"auth_token": "AUTH_TOKEN_COMFY_ORG"}
        return {
            "required": required_inputs,
            "optional": optional_inputs,
            "hidden": hidden_inputs,
        }

    def api_call(self, image: torch.Tensor,
                 auth_token=None):
        """Upscale ``image`` through the fast upscale endpoint.

        Args:
            image: Input image, uploaded as the ``image`` file part.
            auth_token: Forwarded to the API operation.

        Returns:
            A 1-tuple holding the decoded result image tensor.

        Raises:
            Exception: If the response's ``finish_reason`` is not ``"SUCCESS"``.
        """
        upload = tensor_to_bytesio(image, total_pixels=4096 * 4096).read()

        endpoint = ApiEndpoint(
            path="/proxy/stability/v2beta/stable-image/upscale/fast",
            method=HttpMethod.POST,
            request_model=EmptyRequest,
            response_model=StabilityStableUltraResponse,
        )
        result = SynchronousOperation(
            endpoint=endpoint,
            request=EmptyRequest(),  # the endpoint takes no parameters besides the file
            files={"image": upload},
            content_type="multipart/form-data",
            auth_token=auth_token,
        ).execute()

        if result.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Fast failed: {result.finish_reason}.")

        # the image arrives base64-encoded in the response body
        decoded_bytes = base64.b64decode(result.image)
        return (bytesio_to_image_tensor(BytesIO(decoded_bytes)),)
+# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "StabilityStableImageUltraNode": StabilityStableImageUltraNode, + "StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node, + "StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode, + "StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode, + "StabilityUpscaleFastNode": StabilityUpscaleFastNode, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "StabilityStableImageUltraNode": "Stability AI Stable Image Ultra", + "StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image", + "StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative", + "StabilityUpscaleCreativeNode": "Stability AI Upscale Creative", + "StabilityUpscaleFastNode": "Stability AI Upscale Fast", +} diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py new file mode 100644 index 00000000..9233944b --- /dev/null +++ b/comfy_api_nodes/nodes_veo2.py @@ -0,0 +1,283 @@ +import io +import logging +import base64 +import requests +import torch + +from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.input_impl.video_types import VideoFromFile +from comfy_api_nodes.apis import ( + Veo2GenVidRequest, + Veo2GenVidResponse, + Veo2GenVidPollRequest, + Veo2GenVidPollResponse +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, +) + +from comfy_api_nodes.apinode_utils import ( + downscale_image_tensor, + tensor_to_base64_string +) + +def convert_image_to_base64(image: torch.Tensor): + if image is None: + return None + + scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) + return tensor_to_base64_string(scaled_image) + +class VeoVideoGenerationNode(ComfyNodeABC): + """ + Generates videos from text prompts using Google's Veo API. 
+ + This node can create videos from text descriptions and optional image inputs, + with control over parameters like aspect ratio, duration, and more. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Text description of the video", + }, + ), + "aspect_ratio": ( + IO.COMBO, + { + "options": ["16:9", "9:16"], + "default": "16:9", + "tooltip": "Aspect ratio of the output video", + }, + ), + }, + "optional": { + "negative_prompt": ( + IO.STRING, + { + "multiline": True, + "default": "", + "tooltip": "Negative text prompt to guide what to avoid in the video", + }, + ), + "duration_seconds": ( + IO.INT, + { + "default": 5, + "min": 5, + "max": 8, + "step": 1, + "display": "number", + "tooltip": "Duration of the output video in seconds", + }, + ), + "enhance_prompt": ( + IO.BOOLEAN, + { + "default": True, + "tooltip": "Whether to enhance the prompt with AI assistance", + } + ), + "person_generation": ( + IO.COMBO, + { + "options": ["ALLOW", "BLOCK"], + "default": "ALLOW", + "tooltip": "Whether to allow generating people in the video", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFF, + "step": 1, + "display": "number", + "control_after_generate": True, + "tooltip": "Seed for video generation (0 for random)", + }, + ), + "image": (IO.IMAGE, { + "default": None, + "tooltip": "Optional reference image to guide video generation", + }), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + }, + } + + RETURN_TYPES = (IO.VIDEO,) + FUNCTION = "generate_video" + CATEGORY = "api node/video/Veo" + DESCRIPTION = "Generates videos from text prompts using Google's Veo API" + API_NODE = True + + def generate_video( + self, + prompt, + aspect_ratio="16:9", + negative_prompt="", + duration_seconds=5, + enhance_prompt=True, + person_generation="ALLOW", + seed=0, + image=None, + auth_token=None, + ): + # Prepare the instances for the request + 
instances = [] + + instance = { + "prompt": prompt + } + + # Add image if provided + if image is not None: + image_base64 = convert_image_to_base64(image) + if image_base64: + instance["image"] = { + "bytesBase64Encoded": image_base64, + "mimeType": "image/png" + } + + instances.append(instance) + + # Create parameters dictionary + parameters = { + "aspectRatio": aspect_ratio, + "personGeneration": person_generation, + "durationSeconds": duration_seconds, + "enhancePrompt": enhance_prompt, + } + + # Add optional parameters if provided + if negative_prompt: + parameters["negativePrompt"] = negative_prompt + if seed > 0: + parameters["seed"] = seed + + # Initial request to start video generation + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/veo/generate", + method=HttpMethod.POST, + request_model=Veo2GenVidRequest, + response_model=Veo2GenVidResponse + ), + request=Veo2GenVidRequest( + instances=instances, + parameters=parameters + ), + auth_token=auth_token + ) + + initial_response = initial_operation.execute() + operation_name = initial_response.name + + logging.info(f"Veo generation started with operation name: {operation_name}") + + # Define status extractor function + def status_extractor(response): + # Only return "completed" if the operation is done, regardless of success or failure + # We'll check for errors after polling completes + return "completed" if response.done else "pending" + + # Define progress extractor function + def progress_extractor(response): + # Could be enhanced if the API provides progress information + return None + + # Define the polling operation + poll_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/veo/poll", + method=HttpMethod.POST, + request_model=Veo2GenVidPollRequest, + response_model=Veo2GenVidPollResponse + ), + completed_statuses=["completed"], + failed_statuses=[], # No failed statuses, we'll handle errors after polling + status_extractor=status_extractor, + 
progress_extractor=progress_extractor, + request=Veo2GenVidPollRequest( + operationName=operation_name + ), + auth_token=auth_token, + poll_interval=5.0 + ) + + # Execute the polling operation + poll_response = poll_operation.execute() + + # Now check for errors in the final response + # Check for error in poll response + if hasattr(poll_response, 'error') and poll_response.error: + error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" + logging.error(error_message) + raise Exception(error_message) + + # Check for RAI filtered content + if (hasattr(poll_response.response, 'raiMediaFilteredCount') and + poll_response.response.raiMediaFilteredCount > 0): + + # Extract reason message if available + if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and + poll_response.response.raiMediaFilteredReasons): + reason = poll_response.response.raiMediaFilteredReasons[0] + error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" + else: + error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" + + logging.error(error_message) + raise Exception(error_message) + + # Extract video data + video_data = None + if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + video = poll_response.response.videos[0] + + # Check if video is provided as base64 or URL + if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: + # Decode base64 string to bytes + video_data = base64.b64decode(video.bytesBase64Encoded) + elif hasattr(video, 'gcsUri') and video.gcsUri: + # Download from URL + video_url = video.gcsUri + video_response = requests.get(video_url) + video_data = video_response.content + else: + raise Exception("Video returned but no data or URL was 
provided") + else: + raise Exception("Video generation completed but no video was returned") + + if not video_data: + raise Exception("No video data was returned") + + logging.info("Video generation completed successfully") + + # Convert video data to BytesIO object + video_io = io.BytesIO(video_data) + + # Return VideoFromFile object + return (VideoFromFile(video_io),) + + +# Register the node +NODE_CLASS_MAPPINGS = { + "VeoVideoGenerationNode": VeoVideoGenerationNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "VeoVideoGenerationNode": "Google Veo2 Video Generation", +} diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml new file mode 100644 index 00000000..d9e3cab7 --- /dev/null +++ b/comfy_api_nodes/redocly-dev.yaml @@ -0,0 +1,10 @@ +# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. +# This is used for development purposes to generate stubs for unreleased API endpoints. +apis: + filter: + root: openapi.yaml + decorators: + filter-in: + property: tags + value: ['API Nodes'] + matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml new file mode 100644 index 00000000..d102345b --- /dev/null +++ b/comfy_api_nodes/redocly.yaml @@ -0,0 +1,10 @@ +# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. 
+ +apis: + filter: + root: openapi.yaml + decorators: + filter-in: + property: tags + value: ['API Nodes', 'Released'] + matchStrategy: all diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 184b990c..1f93f87a 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -21,6 +21,21 @@ class String(ComfyNodeABC): return (value,) +class StringMultiline(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.STRING, {"multiline": True,},)}, + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: str) -> tuple[str]: + return (value,) + + class Int(ComfyNodeABC): @classmethod def INPUT_TYPES(cls) -> InputTypeDict: @@ -68,6 +83,7 @@ class Boolean(ComfyNodeABC): NODE_CLASS_MAPPINGS = { "PrimitiveString": String, + "PrimitiveStringMultiline": StringMultiline, "PrimitiveInt": Int, "PrimitiveFloat": Float, "PrimitiveBoolean": Boolean, @@ -75,6 +91,7 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { "PrimitiveString": "String", + "PrimitiveStringMultiline": "String (Multiline)", "PrimitiveInt": "Int", "PrimitiveFloat": "Float", "PrimitiveBoolean": "Boolean", diff --git a/nodes.py b/nodes.py index 92b8ca6a..d31e0774 100644 --- a/nodes.py +++ b/nodes.py @@ -2263,7 +2263,17 @@ def init_builtin_extra_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") api_nodes_files = [ - "nodes_api.py", + "nodes_ideogram.py", + "nodes_openai.py", + "nodes_minimax.py", + "nodes_veo2.py", + "nodes_kling.py", + "nodes_bfl.py", + "nodes_luma.py", + "nodes_recraft.py", + "nodes_pixverse.py", + "nodes_stability.py", + "nodes_pika.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 05ceba00..29cf0e2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.18.6 -comfyui-workflow-templates==0.1.3 
+comfyui-frontend-package==1.18.9 +comfyui-workflow-templates==0.1.11 torch torchsde torchvision diff --git a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py new file mode 100644 index 00000000..69488f69 --- /dev/null +++ b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py @@ -0,0 +1,297 @@ +from typing import Optional +from enum import Enum + +from pydantic import BaseModel, Field + +from comfy.comfy_types.node_typing import IO +from comfy_api_nodes.mapper_utils import model_field_to_node_input + + +def test_model_field_to_float_input(): + """Tests mapping a float field with constraints.""" + + class ModelWithFloatField(BaseModel): + cfg_scale: Optional[float] = Field( + default=0.5, + description="Flexibility in video generation", + ge=0.0, + le=1.0, + multiple_of=0.001, + ) + + expected_output = ( + IO.FLOAT, + { + "default": 0.5, + "tooltip": "Flexibility in video generation", + "min": 0.0, + "max": 1.0, + "step": 0.001, + }, + ) + + actual_output = model_field_to_node_input( + IO.FLOAT, ModelWithFloatField, "cfg_scale" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_float_input_no_constraints(): + """Tests mapping a float field with no constraints.""" + + class ModelWithFloatField(BaseModel): + cfg_scale: Optional[float] = Field(default=0.5) + + expected_output = ( + IO.FLOAT, + { + "default": 0.5, + }, + ) + + actual_output = model_field_to_node_input( + IO.FLOAT, ModelWithFloatField, "cfg_scale" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_int_input(): + """Tests mapping an int field with constraints.""" + + class ModelWithIntField(BaseModel): + num_frames: Optional[int] = Field( + default=10, + description="Number of frames to generate", + ge=1, + le=100, + multiple_of=1, + ) + + expected_output = ( + IO.INT, + { + "default": 10, 
+ "tooltip": "Number of frames to generate", + "min": 1, + "max": 100, + "step": 1, + }, + ) + + actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames") + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_string_input(): + """Tests mapping a string field.""" + + class ModelWithStringField(BaseModel): + prompt: Optional[str] = Field( + default="A beautiful sunset over a calm ocean", + description="A prompt for the video generation", + ) + + expected_output = ( + IO.STRING, + { + "default": "A beautiful sunset over a calm ocean", + "tooltip": "A prompt for the video generation", + }, + ) + + actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt") + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_string_input_multiline(): + """Tests mapping a string field.""" + + class ModelWithStringField(BaseModel): + prompt: Optional[str] = Field( + default="A beautiful sunset over a calm ocean", + description="A prompt for the video generation", + ) + + expected_output = ( + IO.STRING, + { + "default": "A beautiful sunset over a calm ocean", + "tooltip": "A prompt for the video generation", + "multiline": True, + }, + ) + + actual_output = model_field_to_node_input( + IO.STRING, ModelWithStringField, "prompt", multiline=True + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_combo_input(): + """Tests mapping a combo field.""" + + class MockEnum(str, Enum): + option_1 = "option 1" + option_2 = "option 2" + option_3 = "option 3" + + class ModelWithComboField(BaseModel): + model_name: Optional[MockEnum] = Field("option 1", description="Model Name") + + expected_output = ( + IO.COMBO, + { + "options": ["option 1", "option 2", "option 3"], + "default": "option 1", + "tooltip": "Model Name", + }, + ) + + 
actual_output = model_field_to_node_input( + IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_combo_input_no_options(): + """Tests mapping a combo field with no options.""" + + class ModelWithComboField(BaseModel): + model_name: Optional[str] = Field(description="Model Name") + + expected_output = ( + IO.COMBO, + { + "tooltip": "Model Name", + }, + ) + + actual_output = model_field_to_node_input( + IO.COMBO, ModelWithComboField, "model_name" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_image_input(): + """Tests mapping an image field.""" + + class ModelWithImageField(BaseModel): + image: Optional[str] = Field( + default=None, + description="An image for the video generation", + ) + + expected_output = ( + IO.IMAGE, + { + "default": None, + "tooltip": "An image for the video generation", + }, + ) + + actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image") + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_node_input_no_description(): + """Tests mapping a field with no description.""" + + class ModelWithNoDescriptionField(BaseModel): + field: Optional[str] = Field(default="default value") + + expected_output = ( + IO.STRING, + { + "default": "default value", + }, + ) + + actual_output = model_field_to_node_input( + IO.STRING, ModelWithNoDescriptionField, "field" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_node_input_no_default(): + """Tests mapping a field with no default.""" + + class ModelWithNoDefaultField(BaseModel): + field: Optional[str] = Field(description="A field with no default") + + expected_output = ( + IO.STRING, + { + "tooltip": "A field 
with no default", + }, + ) + + actual_output = model_field_to_node_input( + IO.STRING, ModelWithNoDefaultField, "field" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_node_input_no_metadata(): + """Tests mapping a field with no metadata or properties defined on the schema.""" + + class ModelWithNoMetadataField(BaseModel): + field: Optional[str] = Field() + + expected_output = ( + IO.STRING, + {}, + ) + + actual_output = model_field_to_node_input( + IO.STRING, ModelWithNoMetadataField, "field" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1] + + +def test_model_field_to_node_input_default_is_none(): + """ + Tests mapping a field with a default of `None`. + I.e., the default field should be included as the schema explicitly sets it to `None`. + """ + + class ModelWithNoneDefaultField(BaseModel): + field: Optional[str] = Field( + default=None, description="A field with a default of None" + ) + + expected_output = ( + IO.STRING, + { + "default": None, + "tooltip": "A field with a default of None", + }, + ) + + actual_output = model_field_to_node_input( + IO.STRING, ModelWithNoneDefaultField, "field" + ) + + assert actual_output[0] == expected_output[0] + assert actual_output[1] == expected_output[1]