Change cosmos and hydit models to use the native RMSNorm. (#7934)
parent 3041e5c354
commit 9187a09483
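The change drops the standalone RMSNorm imported from comfy.ldm.modules.diffusionmodules.mmdit and instead builds the norm through the `operations` object, so the layer goes through ComfyUI's operations wrappers (which can defer to PyTorch's native torch.nn.RMSNorm) and its weight is handled with the same dtype/device machinery as the other layers. Below is a minimal, purely illustrative sketch of an RMSNorm with the (dim, elementwise_affine, eps) construction signature the diff relies on; it is not the actual comfy.ops implementation.

import torch
import torch.nn as nn

class OpsRMSNorm(nn.Module):
    # Illustrative only: mirrors the (dim, elementwise_affine, eps, device, dtype) signature used in the diff.
    def __init__(self, dim, elementwise_affine=True, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

    def forward(self, x):
        # RMSNorm: scale x by the reciprocal root-mean-square over its last dimension
        y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return y * self.weight if self.weight is not None else y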
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
 
 
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out
 
 
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
    if name == "I":
        return nn.Identity()
    elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    else:
        raise ValueError(f"Normalization {name} not found")
 
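With the extra parameters, callers now thread both the weight kwargs (device/dtype) and the operations namespace into the factory. A hedged usage sketch, assuming get_normalization from the Cosmos blocks module is in scope and comfy.ops.disable_weight_init as a typical operations object; the dimension and device values are illustrative.

import torch
import comfy.ops

weight_args = {"device": torch.device("cpu"), "dtype": torch.float32}

# "R" now builds the RMSNorm through the operations namespace rather than the
# module-local class the removed import used to provide.
rms = get_normalization("R", 512, weight_args=weight_args, operations=comfy.ops.disable_weight_init)

# "I" still returns nn.Identity() and simply ignores the extra arguments.
ident = get_normalization("I", 512, weight_args=weight_args, operations=comfy.ops.disable_weight_init)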
@@ -120,15 +119,15 @@ class Attention(nn.Module):
 
        self.to_q = nn.Sequential(
            operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_k = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_v = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
        )
 
        self.to_out = nn.Sequential(
@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
    FinalLayer,
    GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
 
        if self.affline_emb_norm:
            logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
        else:
            self.affline_norm = nn.Identity()
 
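Besides routing through `operations`, the GeneralDIT change forwards device and dtype so the affine weight is created directly where the rest of the model lives. Numerically, an RMSNorm built this way should match PyTorch's native implementation; a small equivalence check, assuming PyTorch 2.4 or newer for torch.nn.RMSNorm.

import torch

x = torch.randn(2, 16, 256)

# native RMSNorm (PyTorch >= 2.4), affine weight initialised to ones
norm = torch.nn.RMSNorm(256, eps=1e-6, elementwise_affine=True)

# the same computation written out: x / sqrt(mean(x^2) + eps) * weight
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * norm.weight

assert torch.allclose(norm(x), manual, atol=1e-5)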
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 
 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
 
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
        if norm_type == "layer":
            norm_layer = operations.LayerNorm
        elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")
 
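In the HunYuanDiT block the norm class is chosen as a factory and instantiated later, so swapping `RMSNorm` for `operations.RMSNorm` works as long as both accept the same (dim, elementwise_affine, eps) construction arguments. A sketch of that factory pattern with stock PyTorch classes; the dimensions are hypothetical, and the comfy.ops wrappers follow the same calling convention while adding weight-casting behaviour.

import torch.nn as nn

def build_norm(norm_type: str, dim: int):
    # both factories share the (dim, elementwise_affine=..., eps=...) signature,
    # which is what lets the diff swap RMSNorm implementations without touching call sites
    if norm_type == "layer":
        norm_layer = nn.LayerNorm
    elif norm_type == "rms":
        norm_layer = nn.RMSNorm  # PyTorch >= 2.4
    else:
        raise ValueError(f"Unknown norm_type: {norm_type}")
    return norm_layer(dim, elementwise_affine=True, eps=1e-6)

print(build_norm("rms", 1024))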