From 1b8089528502a881d0ed2918b2abd54441743dd0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 10 Oct 2024 15:06:15 -0400 Subject: [PATCH] Make clip loader nodes support loading sd3 t5xxl in lower precision. Add attention mask support in the SD3 text encoder code. --- comfy/sd.py | 29 +++++++++++++++++------------ comfy/text_encoders/sd3_clip.py | 22 ++++++++++++++-------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 65d8117c..97d6b2e9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -431,6 +431,19 @@ def detect_te_model(sd): return TEModel.T5_BASE return None + +def t5xxl_weight_dtype(clip_data): + weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" + + dtype_t5 = None + for sd in clip_data: + weight = sd.get(weight_name, None) + if weight is not None: + dtype_t5 = weight.dtype + break + return dtype_t5 + + def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = state_dicts @@ -462,9 +475,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer elif te_model == TEModel.T5_XXL: - weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - dtype_t5 = weight.dtype - clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif te_model == TEModel.T5_XL: clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model @@ -482,25 +493,19 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif len(clip_data) == 2: if clip_type == CLIPType.SD3: te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])] - clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models) + clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HUNYUAN_DIT: clip_target.clip = comfy.text_encoders.hydit.HyditModel clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer elif clip_type == CLIPType.FLUX: - weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" - weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None)) - dtype_t5 = None - if weight is not None: - dtype_t5 = weight.dtype - - clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5) + clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data)) clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif len(clip_data) == 3: - clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel + clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer parameters = 0 diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index c54f2885..0340e65b 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -8,9 +8,9 @@ import comfy.model_management import logging class T5XXLModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -39,7 +39,7 @@ class SD3Tokenizer: return {} class SD3ClipModel(torch.nn.Module): - def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}): + def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}): super().__init__() self.dtypes = set() if clip_l: @@ -57,7 +57,8 @@ class SD3ClipModel(torch.nn.Module): if t5: dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) + self.t5_attention_mask = t5_attention_mask + self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask) self.dtypes.add(dtype_t5) else: self.t5xxl = None @@ -87,6 +88,7 @@ class SD3ClipModel(torch.nn.Module): lg_out = None pooled = None out = None + extra = {} if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: if self.clip_l is not None: @@ -111,7 +113,11 @@ class SD3ClipModel(torch.nn.Module): pooled = torch.cat((l_pooled, g_pooled), dim=-1) if self.t5xxl is not None: - t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_out, t5_pooled = t5_output[:2] + if self.t5_attention_mask: + extra["attention_mask"] = t5_output[2]["attention_mask"] + if lg_out is not None: out = torch.cat([lg_out, t5_out], dim=-2) else: @@ -123,7 +129,7 @@ class SD3ClipModel(torch.nn.Module): if pooled is None: pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) - return out, pooled + return out, pooled, extra def load_sd(self, sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -133,8 +139,8 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_