Switch mochi and wan models to use pytorch RMSNorm. (#7925)
* Switch genmo model to native RMSNorm.
* Switch WAN to native RMSNorm.
This commit is contained in:
parent 7689917113
commit 3041e5c354
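The change drops the hand-rolled RMSNorm wrappers in the Mochi (genmo) and WAN code paths and builds the norms through the operations objects instead, which resolve to PyTorch's native RMSNorm. As a minimal sketch of why the swap is safe (my own illustration, not part of this commit; assumes PyTorch >= 2.4, where torch.nn.functional.rms_norm exists), the removed wrapper's math matches the native functional form:

# Equivalence sketch: the removed wrapper computes x / sqrt(mean(x^2) + eps) * weight,
# which is the same expression the native functional implements.
import torch
import torch.nn.functional as F

def manual_rms_norm(x, weight, eps=1e-5):
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

x = torch.randn(2, 8, 64)
weight = torch.randn(64)
native = F.rms_norm(x, (64,), weight=weight, eps=1e-5)
assert torch.allclose(manual_rms_norm(x, weight), native, atol=1e-6)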
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )

@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):

         # Query and key normalization for stability.
         assert qk_norm
-        self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
-        self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
+        self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+        self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)

         # Output layers. y features go back down from dim_x -> dim_y.
         self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
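Construction now goes through operations.RMSNorm, with eps=1e-5 passed explicitly to match the default of the wrapper being replaced. A hedged illustration of what an "operations"-style norm typically adds on top of the native module (hypothetical class, not ComfyUI's comfy.ops source): casting the learned weight to the input's dtype and device at forward time so mixed-precision and offloaded weights keep working.

# Hypothetical sketch only, not the actual comfy.ops implementation.
import torch
import torch.nn.functional as F

class CastingRMSNorm(torch.nn.RMSNorm):  # illustrative name
    def forward(self, x):
        weight = None
        if self.weight is not None:
            # cast the stored weight to wherever/whatever the activations are
            weight = self.weight.to(device=x.device, dtype=x.dtype)
        return F.rms_norm(x, self.normalized_shape, weight, self.eps)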
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):

         x = self.norm(x)
         return x
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
-        self.register_parameter("bias", None)
-
-    def forward(self, x):
-        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
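The deleted class registered a single weight parameter and an explicitly-None bias, which is the same parameter layout the native module exposes, so existing checkpoints should keep loading without key remapping. A quick sketch (assuming PyTorch >= 2.4):

# torch.nn.RMSNorm also carries just a "weight" parameter and no bias,
# matching the state-dict layout of the wrapper removed above.
import torch

norm = torch.nn.RMSNorm(64, eps=1e-5)
print(list(norm.state_dict().keys()))  # ['weight']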
@@ -9,7 +9,6 @@ from einops import repeat
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 import comfy.ldm.common_dit
 import comfy.model_management

@@ -49,8 +48,8 @@ class WanSelfAttention(nn.Module):
         self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

     def forward(self, x, freqs):
         r"""
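For WAN the norm is now built through the same operation_settings dictionary that already supplies the Linear layers, with nn.Identity() as the fallback when qk_norm is disabled so the forward path never branches. A small sketch of the pattern (the dictionary values here are stand-ins, not ComfyUI's actual operations backend):

# Illustrative only: plain torch.nn stands in for the real "operations" object,
# and 128 / eps=1e-6 are placeholder values.
import torch
import torch.nn as nn

operation_settings = {
    "operations": torch.nn,
    "device": torch.device("cpu"),
    "dtype": torch.float32,
}

qk_norm = True
norm_q = operation_settings.get("operations").RMSNorm(
    128, eps=1e-6, elementwise_affine=True,
    device=operation_settings.get("device"),
    dtype=operation_settings.get("dtype"),
) if qk_norm else nn.Identity()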
@@ -114,7 +113,7 @@ class WanI2VCrossAttention(WanSelfAttention):
         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         # self.alpha = nn.Parameter(torch.zeros((1, )))
-        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()

     def forward(self, x, context, context_img_len):
         r"""
|