From 53f326a3d8cfcab008d00a7603de3c90fe7f6288 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Aug 2023 12:22:46 -0400 Subject: [PATCH] Support diffusers mini controlnets. --- comfy/model_detection.py | 40 +++++++++++++++++++++++++++++++--------- comfy/sd.py | 5 ++++- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 49ee9ea7..d18e019f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -121,9 +121,20 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): return model_config_from_unet_config(unet_config) -def model_config_from_diffusers_unet(state_dict, use_fp16): +def unet_config_from_diffusers_unet(state_dict, use_fp16): match = {} - match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] + attention_resolutions = [] + + attn_res = 1 + for i in range(5): + k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i) + if k in state_dict: + match["context_dim"] = state_dict[k].shape[1] + attention_resolutions.append(attn_res) + attn_res *= 2 + + match["attention_resolutions"] = attention_resolutions + match["model_channels"] = state_dict["conv_in.weight"].shape[0] match["in_channels"] = state_dict["conv_in.weight"].shape[1] match["adm_in_channels"] = None @@ -135,22 +146,22 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} + 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} + 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, @@ -160,9 +171,14 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl] + SDXL_mini_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mini_cnet] for unet_config in supported_models: matches = True @@ -171,5 +187,11 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): matches = False break if matches: - return model_config_from_unet_config(unet_config) + return unet_config + return None + +def model_config_from_diffusers_unet(state_dict, use_fp16): + unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16) + if unet_config is not None: + return model_config_from_unet_config(unet_config) return None diff --git a/comfy/sd.py b/comfy/sd.py index bff9ee14..06b64096 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -835,7 +835,7 @@ def load_controlnet(ckpt_path, model=None): controlnet_config = None if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config + controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) diffusers_keys = utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -874,6 +874,9 @@ def load_controlnet(ckpt_path, model=None): if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight'