Optimize first attention block in cosmos VAE.
parent bfd5dfd611
commit 008761166f
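In short: the cosmos VAE's CausalAttnBlock previously computed spatial self-attention by hand with reshape/bmm/softmax. This commit moves attention-backend selection into a shared vae_attention() helper in comfy/ldm/modules/diffusionmodules/model.py (xformers, PyTorch, or split attention, in that order of preference), has the existing AttnBlock use it in place of its inlined if/elif/else, and wires CausalAttnBlock to the same helper so the cosmos VAE picks up the optimized kernels.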
--- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
@@ -30,6 +30,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import logging
 
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
 from .patching import (
     Patcher,
     Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
             in_channels, in_channels, kernel_size=1, stride=1, padding=0
         )
 
+        self.optimized_attention = vae_attention()
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h_ = x
         h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
         v, batch_size = time2batch(v)
 
         b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)
-        k = k.reshape(b, c, h * w)
-        w_ = torch.bmm(q, k)
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = F.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)
-        h_ = torch.bmm(v, w_)
-        h_ = h_.reshape(b, c, h, w)
+        h_ = self.optimized_attention(q, k, v)
 
         h_ = batch2time(h_, batch_size)
         h_ = self.proj_out(h_)
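Context for the deleted lines (a sketch, not part of the commit): time2batch folds the time axis into the batch dimension so each frame is attended as a plain 2D feature map, and the hand-rolled bmm/softmax arithmetic is exactly what the optimized backends compute in fused form, without materializing the (hw x hw) attention matrix explicitly. A minimal check of that equivalence against torch.nn.functional.scaled_dot_product_attention, which is the kind of fused call the pytorch backend makes; manual_attention and sdpa_attention are illustrative names, not ComfyUI functions:

import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    # The deleted path: explicit bmm/softmax over flattened spatial positions.
    b, c, h, w = q.shape
    q_ = q.reshape(b, c, h * w).permute(0, 2, 1)    # b, hw, c
    k_ = k.reshape(b, c, h * w)                     # b, c, hw
    w_ = F.softmax(torch.bmm(q_, k_) * (int(c) ** (-0.5)), dim=2)
    v_ = v.reshape(b, c, h * w)
    h_ = torch.bmm(v_, w_.permute(0, 2, 1))         # b, c, hw
    return h_.reshape(b, c, h, w)

def sdpa_attention(q, k, v):
    # The same computation as one fused call (single head, head_dim = c,
    # default scale 1/sqrt(c) matches the manual int(c) ** (-0.5)).
    b, c, h, w = q.shape
    q_, k_, v_ = (t.reshape(b, 1, c, h * w).transpose(2, 3) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_)
    return out.transpose(2, 3).reshape(b, c, h, w)

q, k, v = (torch.randn(2, 16, 8, 8) for _ in range(3))
assert torch.allclose(manual_attention(q, k, v), sdpa_attention(q, k, v), atol=1e-5)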
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
     return out
 
 
+def vae_attention():
+    if model_management.xformers_enabled_vae():
+        logging.info("Using xformers attention in VAE")
+        return xformers_attention
+    elif model_management.pytorch_attention_enabled():
+        logging.info("Using pytorch attention in VAE")
+        return pytorch_attention
+    else:
+        logging.info("Using split attention in VAE")
+        return normal_attention
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels, conv_op=ops.Conv2d):
         super().__init__()
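vae_attention() resolves the backend once and returns a plain function, so callers pay no per-forward branching. A hypothetical usage sketch (assumes it runs inside this module, where model_management and the three backends are in scope):

import torch

# Resolve the backend once; the choice is logged at this point and the
# returned function is then reused for every forward pass.
attn = vae_attention()

# All three backends share one signature: (b, c, h, w) maps in, same shape out.
q = k = v = torch.randn(1, 512, 32, 32)
out = attn(q, k, v)  # -> (1, 512, 32, 32)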
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
                                 stride=1,
                                 padding=0)
 
-        if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
-            self.optimized_attention = xformers_attention
-        elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
-            self.optimized_attention = pytorch_attention
-        else:
-            logging.info("Using split attention in VAE")
-            self.optimized_attention = normal_attention
+        self.optimized_attention = vae_attention()
 
     def forward(self, x):
         h_ = x
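Design note: backend selection, previously inlined in AttnBlock.__init__, now lives in one place; both AttnBlock and the cosmos CausalAttnBlock store the returned function and call it in forward(), so supporting a new backend only means extending vae_attention().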