diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 6e8e0618..056e101a 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,7 +1,6 @@ import torch from torch import nn import comfy.ldm.modules.attention -from comfy.ldm.genmo.joint_model.layers import RMSNorm import comfy.ldm.common_dit from einops import rearrange import math @@ -262,8 +261,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)