From 656c0b5d90239efb8be4281d2c16d52ca722064c Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 6 Nov 2023 13:43:50 -0500
Subject: [PATCH] CLIP code refactor and improvements.

More generic clip model class that can be used on more types of text
encoders.

Don't apply weighting algorithm when weight is 1.0

Don't compute an empty token output when it's not needed.
---
 comfy/sd1_clip.py  | 84 ++++++++++++++++++++++++++++++++--------------
 comfy/sd2_clip.py  |  3 +-
 comfy/sdxl_clip.py |  8 ++---
 3 files changed, 62 insertions(+), 33 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4761230a..7db7ee0f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -8,32 +8,54 @@ import zipfile
 from . import model_management
 import contextlib
 
+def gen_empty_tokens(special_tokens, length):
+    start_token = special_tokens.get("start", None)
+    end_token = special_tokens.get("end", None)
+    pad_token = special_tokens.get("pad")
+    output = []
+    if start_token is not None:
+        output.append(start_token)
+    if end_token is not None:
+        output.append(end_token)
+    output += [pad_token] * (length - len(output))
+    return output
+
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        to_encode = list(self.empty_tokens)
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
         for x in token_weight_pairs:
             tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
             to_encode.append(tokens)
 
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
         out, pooled = self.encode(to_encode)
-        z_empty = out[0:1]
-        if pooled.shape[0] > 1:
-            first_pooled = pooled[1:2]
+        if pooled is not None:
+            first_pooled = pooled[0:1].cpu()
         else:
-            first_pooled = pooled[0:1]
+            first_pooled = pooled
 
         output = []
-        for k in range(1, out.shape[0]):
+        for k in range(0, sections):
             z = out[k:k+1]
-            for i in range(len(z)):
-                for j in range(len(z[i])):
-                    weight = token_weight_pairs[k - 1][j][1]
-                    z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
             output.append(z)
 
         if (len(output) == 0):
-            return z_empty.cpu(), first_pooled.cpu()
-        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+            return out[-1:].cpu(), first_pooled
+        return torch.cat(output, dim=-2).cpu(), first_pooled
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
+                 model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
         if textmodel_path is not None:
-            self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
+            self.transformer = model_class.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            config = config_class.from_json_file(textmodel_json_config)
             self.num_layers = config.num_hidden_layers
             with comfy.ops.use_comfy_ops(device, dtype):
                 with modeling_utils.no_init_weights():
-                    self.transformer = CLIPTextModel(config)
+                    self.transformer = model_class(config)
 
+        self.inner_name = inner_name
         if dtype is not None:
             self.transformer.to(dtype)
-            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
-            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+            inner_model = getattr(self.transformer, self.inner_name)
+            if hasattr(inner_model, "embeddings"):
+                inner_model.embeddings.to(torch.float32)
+            else:
+                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
 
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
-        self.empty_tokens = [[49406] + [49407] * 76]
+        self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = False
 
-        self.layer_norm_hidden_state = True
+        self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= self.num_layers
@@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     else:
                         print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             while len(tokens_temp) < len(x):
-                tokens_temp += [self.empty_tokens[0][-1]]
+                tokens_temp += [self.special_tokens["pad"]]
             out_tokens += [tokens_temp]
 
         n = token_dict_size
@@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
+        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
@@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             else:
                 z = outputs.hidden_states[self.layer_idx]
                 if self.layer_norm_hidden_state:
-                    z = self.transformer.text_model.final_layer_norm(z)
+                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
 
-            pooled_output = outputs.pooler_output
-            if self.text_projection is not None:
+            if hasattr(outputs, "pooler_output"):
+                pooled_output = outputs.pooler_output.float()
+            else:
+                pooled_output = None
+
+            if self.text_projection is not None and pooled_output is not None:
                 pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-        return z.float(), pooled_output.float()
+        return z.float(), pooled_output
 
     def encode(self, tokens):
         return self(tokens)
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index ebabf7cc..2ee0ca05 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
             layer_idx=23
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 4c508a0e..673399e2 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.layer_norm_hidden_state = False
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -38,8 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
-        self.clip_l.layer_norm_hidden_state = False
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
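
Note (not part of the patch): a quick sketch of what the new padding helper produces. With the SD1 defaults the empty sequence is start/end followed by end-token padding, while SD2/SDXL pass pad=0; encode_token_weights only appends this empty sequence (and only runs the re-weighting loop) when some token weight differs from 1.0, so plain unweighted prompts skip both.

# Illustrative sketch only; assumes the comfy package from this repository is importable.
from comfy.sd1_clip import gen_empty_tokens

# SD1-style special tokens: padding uses the end token.
print(gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 5))
# [49406, 49407, 49407, 49407, 49407]

# SD2/SDXL-style special tokens: padding uses token id 0.
print(gen_empty_tokens({"start": 49406, "end": 49407, "pad": 0}, 5))
# [49406, 49407, 0, 0, 0]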
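
A minimal sketch of how a subclass might target a different text encoder through the new constructor arguments alone, in the same way SD2ClipHModel and SDXLClipG above now pass special_tokens and layer_norm_hidden_state instead of overwriting self.empty_tokens after construction. T5Config and T5EncoderModel are real transformers classes, but the config file name, token ids and inner_name="encoder" are illustrative assumptions, not something this commit adds or guarantees to work end to end.

# Hypothetical example, not part of this commit.
import os
from transformers import T5Config, T5EncoderModel
from comfy import sd1_clip

class HypotheticalT5Model(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
        # Assumed config file shipped alongside this class.
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx,
                         textmodel_json_config=textmodel_json_config, dtype=dtype,
                         special_tokens={"end": 1, "pad": 0},  # no "start" token; gen_empty_tokens handles its absence
                         layer_norm_hidden_state=False,        # skip final_layer_norm over hidden states
                         config_class=T5Config,                # used by from_json_file
                         model_class=T5EncoderModel,           # used by from_pretrained / construction
                         inner_name="encoder")                 # attribute holding the inner transformer stack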