diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index dd0b72f7..35da91ee 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn
 
-from .math import attention
+from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
diff --git a/comfy/ldm/chroma/math.py b/comfy/ldm/chroma/math.py
deleted file mode 100644
index 36b67931..00000000
--- a/comfy/ldm/chroma/math.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-from einops import rearrange
-from torch import Tensor
-
-from comfy.ldm.modules.attention import optimized_attention
-import comfy.model_management
-
-
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
-    q_shape = q.shape
-    k_shape = k.shape
-
-    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
-    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
-    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
-
-    heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
-    return x
-
-
-def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
-    assert dim % 2 == 0
-    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
-        device = torch.device("cpu")
-    else:
-        device = pos.device
-
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
-    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.to(dtype=torch.float32, device=pos.device)
-
-
-def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
diff --git a/comfy/lora.py b/comfy/lora.py
index dbabd333..fff524be 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -252,7 +252,7 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
 
-    if isinstance(model, comfy.model_base.Flux) or isinstance(model, comfy.model_base.Chroma): #Diffusers lora Flux or a diffusers lora Chroma
+    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
         diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
             if k.endswith(".weight"):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 1a06bb50..3d33086d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -787,8 +787,8 @@ class PixArt(BaseModel):
         return out
 
 class Flux(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
+        super().__init__(model_config, model_type, device=device, unet_model=unet_model)
 
     def concat_cond(self, **kwargs):
         try:
@@ -1110,63 +1110,14 @@ class HiDream(BaseModel):
             out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
         return out
 
-class Chroma(BaseModel):
+class Chroma(Flux):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
 
-    def concat_cond(self, **kwargs):
-        try:
-            #Handle Flux control loras dynamically changing the img_in weight.
-            num_channels = self.diffusion_model.img_in.weight.shape[1]
-        except:
-            #Some cases like tensorrt might not have the weights accessible
-            num_channels = self.model_config.unet_config["in_channels"]
-
-        out_channels = self.model_config.unet_config["out_channels"]
-
-        if num_channels <= out_channels:
-            return None
-
-        image = kwargs.get("concat_latent_image", None)
-        noise = kwargs.get("noise", None)
-        device = kwargs["device"]
-
-        if image is None:
-            image = torch.zeros_like(noise)
-
-        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-        image = utils.resize_to_batch_size(image, noise.shape[0])
-        image = self.process_latent_in(image)
-        if num_channels <= out_channels * 2:
-            return image
-
-        #inpaint model
-        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
-        if mask is None:
-            mask = torch.ones_like(noise)[:, :1]
-
-        mask = torch.mean(mask, dim=1, keepdim=True)
-        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
-        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
-        mask = utils.resize_to_batch_size(mask, noise.shape[0])
-        return torch.cat((image, mask), dim=1)
-
-
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
-        # upscale the attention mask, since now we
-        attention_mask = kwargs.get("attention_mask", None)
-        if attention_mask is not None:
-            shape = kwargs["noise"].shape
-            mask_ref_size = kwargs["attention_mask_img_shape"]
-            # the model will pad to the patch size, and then divide
-            # essentially dividing and rounding up
-            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
-            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
-            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
-        guidance = 0.0
-        out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor((guidance,)))
+
+        guidance = kwargs.get("guidance", 0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index daf6d04e..9254843e 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -154,32 +154,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config
 
-    if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
-        dit_config = {}
-        dit_config["image_model"] = "chroma"
-        dit_config["depth"] = 48
-        dit_config["in_channels"] = 64
-        patch_size = 2
-        dit_config["patch_size"] = patch_size
-        in_key = "{}img_in.weight".format(key_prefix)
-        if in_key in state_dict_keys:
-            dit_config["in_channels"] = state_dict[in_key].shape[1]
-        dit_config["out_channels"] = 64
-        dit_config["context_in_dim"] = 4096
-        dit_config["hidden_size"] = 3072
-        dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = 24
-        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
-        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        dit_config["axes_dim"] = [16, 56, 56]
-        dit_config["theta"] = 10000
-        dit_config["qkv_bias"] = True
-        dit_config["in_dim"] = 64
-        dit_config["out_dim"] = 3072
-        dit_config["hidden_dim"] = 5120
-        dit_config["n_layers"] = 5
-        return dit_config
-
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
@@ -190,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         if in_key in state_dict_keys:
             dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
         dit_config["out_channels"] = 16
-        dit_config["vec_in_dim"] = 768
+        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
+        if vec_in_key in state_dict_keys:
+            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
@@ -200,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True
-        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+            dit_config["image_model"] = "chroma"
+            dit_config["in_channels"] = 64
+            dit_config["out_channels"] = 64
+            dit_config["in_dim"] = 64
+            dit_config["out_dim"] = 3072
+            dit_config["hidden_dim"] = 5120
+            dit_config["n_layers"] = 5
+        else:
+            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config
 
     if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
diff --git a/comfy/sd.py b/comfy/sd.py
index 454b5929..da9b36d0 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -42,7 +42,6 @@ import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
 import comfy.text_encoders.hidream
-import comfy.text_encoders.chroma
 
 import comfy.model_patcher
 import comfy.lora
@@ -820,7 +819,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.LTXV:
             clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
-        elif clip_type == CLIPType.PIXART:
+        elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
             clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
         elif clip_type == CLIPType.WAN:
@@ -831,9 +830,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
                                clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
-        elif clip_type == CLIPType.CHROMA:
-            clip_target.clip = comfy.text_encoders.chroma.chroma_te(**t5xxl_detect(clip_data))
-            clip_target.tokenizer = comfy.text_encoders.chroma.ChromaT5Tokenizer
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index f03f2790..d5210cfa 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -17,7 +17,6 @@ import comfy.text_encoders.hunyuan_video
 import comfy.text_encoders.cosmos
 import comfy.text_encoders.lumina2
 import comfy.text_encoders.wan
-import comfy.text_encoders.chroma
 
 from . import supported_models_base
 from . import latent_formats
@@ -1095,7 +1094,7 @@ class Chroma(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.chroma.ChromaTokenizer, comfy.text_encoders.chroma.chroma_te(**t5_detect))
+        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
 
 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma]
 
diff --git a/comfy/text_encoders/chroma.py b/comfy/text_encoders/chroma.py
deleted file mode 100644
index aa8dffb2..00000000
--- a/comfy/text_encoders/chroma.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from comfy import sd1_clip
-import comfy.text_encoders.t5
-import os
-from transformers import T5TokenizerFast
-
-
-class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
-            model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
-
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
-
-
-class ChromaT5XXL(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
-        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
-
-
-class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
-
-
-class ChromaT5Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-
-
-def chroma_te(dtype_t5=None, t5xxl_scaled_fp8=None):
-    class ChromaTEModel_(ChromaT5XXL):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
-            if dtype is None:
-                dtype = dtype_t5
-            super().__init__(device=device, dtype=dtype, model_options=model_options)
-    return ChromaTEModel_
diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py
index 10295873..e7c851ca 100644
--- a/comfy_extras/nodes_optimalsteps.py
+++ b/comfy_extras/nodes_optimalsteps.py
@@ -20,7 +20,7 @@ def loglinear_interp(t_steps, num_steps):
 
 NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
 "Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
-"Chroma": [0.9919999837875366, 0.9900000095367432, 0.9879999756813049, 0.9850000143051147, 0.9819999933242798, 0.9779999852180481, 0.9729999899864197, 0.9679999947547913, 0.9610000252723694, 0.953000009059906, 0.9430000185966492, 0.9309999942779541, 0.9169999957084656, 0.8999999761581421, 0.8809999823570251, 0.8579999804496765, 0.8320000171661377, 0.8019999861717224, 0.7689999938011169, 0.7310000061988831, 0.6899999976158142, 0.6460000276565552, 0.5989999771118164, 0.550000011920929, 0.5009999871253967, 0.45100000500679016, 0.4020000100135803, 0.35499998927116394, 0.3109999895095825, 0.27000001072883606, 0.23199999332427979, 0.19900000095367432, 0.16899999976158142, 0.14300000667572021, 0.11999999731779099, 0.10100000351667404, 0.08399999886751175, 0.07000000029802322, 0.057999998331069946, 0.04800000041723251, 0.0],
+"Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001],
 }
 
 class OptimalStepsScheduler: