Controlnet union model basic implementation.

This is only the model code itself, it currently defaults to an empty
embedding [0] * 6 which seems to work better than treating it like a
regular controlnet.

TODO: Add nodes to select the image type.
This commit is contained in:
comfyanonymous 2024-07-08 23:26:34 -04:00
parent bb663bcd6c
commit faa57430b0
2 changed files with 118 additions and 1 deletions

View File

@ -13,7 +13,47 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists from ..ldm.util import exists
from ..ldm.cascade.common import OptimizedAttention
from collections import OrderedDict
import comfy.ops import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.c = c
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.in_proj(x)
q, k, v = x.split(self.c, dim=2)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResBlockUnionControlnet(nn.Module):
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
self.mlp = nn.Sequential(
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
def attention(self, x: torch.Tensor):
return self.attn(x)
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class ControlledUnetModel(UNetModel): class ControlledUnetModel(UNetModel):
#implemented in the ldm unet #implemented in the ldm unet
@ -53,6 +93,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None, transformer_depth_middle=None,
transformer_depth_output=None, transformer_depth_output=None,
attn_precision=None, attn_precision=None,
union_controlnet=False,
device=None, device=None,
operations=comfy.ops.disable_weight_init, operations=comfy.ops.disable_weight_init,
**kwargs, **kwargs,
@ -280,6 +321,65 @@ class ControlNet(nn.Module):
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch self._feature_size += ch
if union_controlnet:
self.num_control_type = 6
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
# task_scale_factor = num_trans_channel ** 0.5
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
#-----------------------------------------------------------------------------------------------------
control_add_embed_dim = 256
class ControlAddEmbedding(nn.Module):
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
super().__init__()
self.num_control_type = num_control_type
self.in_dim = in_dim
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
def forward(self, control_type, dtype, device):
c_type = torch.zeros((self.num_control_type,), device=device)
c_type[control_type] = 1.0
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
else:
self.task_embedding = None
self.control_add_embedding = None
def union_controlnet_merge(self, hint, control_type, emb, context):
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
inputs = []
condition_list = []
for idx in range(min(1, len(control_type))):
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
if idx < len(control_type):
feat_seq += self.task_embedding[control_type[idx]]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(controlnet_cond)
x = torch.cat(inputs, dim=1)
x = self.transformer_layes(x)
controlnet_cond_fuser = None
for idx in range(len(control_type)):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
o = condition_list[idx] + alpha
if controlnet_cond_fuser is None:
controlnet_cond_fuser = o
else:
controlnet_cond_fuser += o
return controlnet_cond_fuser
def make_zero_conv(self, channels, operations=None, dtype=None, device=None): def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
@ -287,7 +387,18 @@ class ControlNet(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context) guided_hint = None
if self.control_add_embedding is not None:
control_type = kwargs.get("control_type", [])
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)
out_output = [] out_output = []
out_middle = [] out_middle = []

View File

@ -413,6 +413,12 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data: if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k) new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
controlnet_config["union_controlnet"] = True
for k in list(controlnet_data.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
new_sd[new_k] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys() leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0: if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys)) logging.warning("leftover keys: {}".format(leftover_keys))