Change cosmos and hydit models to use the native RMSNorm. (#7934)

Author: comfyanonymous (committed by GitHub)
Date:   2025-05-04 03:26:20 -07:00
Commit: 9187a09483, parent 3041e5c354
3 changed files with 8 additions and 11 deletions
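
For context on what "use the native RMSNorm" means here: instead of importing a standalone RMSNorm helper from comfy.ldm.modules.diffusionmodules.mmdit, each model now asks its injected `operations` class for the norm, the same way it already obtains its Linear layers. The sketch below is an illustrative assumption about that pattern, built on PyTorch's native nn.RMSNorm (available since torch 2.4); it is not ComfyUI's actual comfy.ops source.

    import torch
    import torch.nn as nn

    class disable_weight_init:
        # Illustrative operations container: nested layer classes whose
        # reset_parameters is a no-op, because checkpoint loading overwrites
        # the weights anyway and default initialization is wasted work.
        class Linear(nn.Linear):
            def reset_parameters(self):
                return None

        class RMSNorm(nn.RMSNorm):  # native RMSNorm, torch >= 2.4
            def reset_parameters(self):
                return None

    # Models receive such a class as `operations` and build layers through it.
    # The weight is left uninitialized here; real code loads it from a checkpoint.
    ops = disable_weight_init
    norm = ops.RMSNorm(320, elementwise_affine=True, eps=1e-6)
    print(norm(torch.randn(2, 77, 320)).shape)  # torch.Size([2, 77, 320])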

File 1 of 3 (cosmos attention blocks)

@@ -23,7 +23,6 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")
@@ -120,15 +119,15 @@ class Attention(nn.Module):
         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_out = nn.Sequential(
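
A hedged usage sketch of the updated helper: get_normalization now threads both `weight_args` and `operations` through, so the RMS norm's weight is created with the same dtype/device kwargs as the surrounding Linear layers. Names mirror the hunk above; the `ops` stand-in is an assumption, not a real comfy.ops class.

    import torch
    import torch.nn as nn

    class ops:  # stand-in for the injected comfy.ops operations class
        RMSNorm = nn.RMSNorm  # torch >= 2.4

    def get_normalization(name: str, channels: int, weight_args={}, operations=None):
        if name == "I":
            return nn.Identity()
        elif name == "R":
            return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
        else:
            raise ValueError(f"Normalization {name} not found")

    # weight_args carries dtype/device, exactly as the Linear layers receive them:
    norm = get_normalization("R", 128, weight_args={"dtype": torch.bfloat16}, operations=ops)
    print(norm.weight.dtype)  # torch.bfloat16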

File 2 of 3 (cosmos GeneralDIT model)

@@ -27,8 +27,6 @@ from torchvision import transforms
 from enum import Enum
 import logging
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()

File 3 of 3 (hydit model)

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")
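
And the hydit side after the change, as a self-contained sketch: the norm class is now looked up on `operations` for both branches, so "layer" and "rms" are handled symmetrically. The `operations` stand-in, the helper name, and the channel count below are assumptions for illustration.

    import torch.nn as nn

    class operations:  # stand-in; the real class comes from comfy.ops
        LayerNorm = nn.LayerNorm
        RMSNorm = nn.RMSNorm  # torch >= 2.4

    def pick_norm_layer(norm_type: str):
        # Mirrors the selection logic in the hunk above.
        if norm_type == "layer":
            return operations.LayerNorm
        elif norm_type == "rms":
            return operations.RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

    norm = pick_norm_layer("rms")(1408, elementwise_affine=True, eps=1e-6)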