From aa9d759df36faa2e34cc5722463749a09a7f529b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 1 May 2025 03:33:42 -0700 Subject: [PATCH] Switch ltxv to use the pytorch RMSNorm. (#7897) --- comfy/ldm/lightricks/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 6e8e0618..056e101a 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,7 +1,6 @@ import torch from torch import nn import comfy.ldm.modules.attention -from comfy.ldm.genmo.joint_model.layers import RMSNorm import comfy.ldm.common_dit from einops import rearrange import math @@ -262,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)