From f87810cd3ed2cc3922811422181d0572f98b103d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 Jul 2024 10:52:09 -0400 Subject: [PATCH] Let tokenizers return weights to be stored in the saved checkpoint. --- comfy/sd.py | 6 +++++- comfy/sd1_clip.py | 4 ++++ comfy/sdxl_clip.py | 3 +++ comfy/text_encoders/sd3_clip.py | 3 +++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index fe6ce4c9..c6fcd810 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -135,7 +135,11 @@ class CLIP: return self.cond_stage_model.load_sd(sd) def get_sd(self): - return self.cond_stage_model.state_dict() + sd_clip = self.cond_stage_model.state_dict() + sd_tokenizer = self.tokenizer.state_dict() + for k in sd_tokenizer: + sd_clip[k] = sd_tokenizer[k] + return sd_clip def load_model(self): model_management.load_model_gpu(self.patcher) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ea66da00..c7bc1e4d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -519,6 +519,8 @@ class SDTokenizer: def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) + def state_dict(self): + return {} class SD1Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer): @@ -534,6 +536,8 @@ class SD1Tokenizer: def untokenize(self, token_weight_pair): return getattr(self, self.clip).untokenize(token_weight_pair) + def state_dict(self): + return {} class SD1ClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, name=None, **kwargs): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 57c30f2e..6e6b87d6 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -34,6 +34,9 @@ class SDXLTokenizer: def untokenize(self, token_weight_pair): return self.clip_g.untokenize(token_weight_pair) + def state_dict(self): + return {} + class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 286c5dc0..b01fad22 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -34,6 +34,9 @@ class SD3Tokenizer: def untokenize(self, token_weight_pair): return self.clip_g.untokenize(token_weight_pair) + def state_dict(self): + return {} + class SD3ClipModel(torch.nn.Module): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): super().__init__()