Lower cosmos diffusion model memory usage.
parent 4758fb64b9
commit 25683b5b02
@@ -168,14 +168,18 @@ class Attention(nn.Module):
         k = self.to_k[1](k)
         v = self.to_v[1](v)
         if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
-            q = apply_rotary_pos_emb(q, rope_emb)
-            k = apply_rotary_pos_emb(k, rope_emb)
-        return q, k, v
-
-    def cal_attn(self, q, k, v, mask=None):
-        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
-        out = rearrange(out, " b n s c -> s b (n c)")
-        return self.to_out(out)
+            # apply_rotary_pos_emb inlined
+            q_shape = q.shape
+            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
+
+            # apply_rotary_pos_emb inlined
+            k_shape = k.shape
+            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+        return q, k, v
 
     def forward(
         self,
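The inlined statements above reproduce, per the inline comments, the rotation that the removed apply_rotary_pos_emb(q, rope_emb) and apply_rotary_pos_emb(k, rope_emb) calls performed, so the hot path no longer routes through the helper and its temporaries. Below is a minimal standalone sketch of the same arithmetic; the function name, the out_dtype argument (standing in for the .to(x.dtype) cast), and the toy shapes in the demo are illustrative assumptions, not the repository's own helper or tensor layout.

import torch

def rotate_with_rope(t: torch.Tensor, rope_emb: torch.Tensor, out_dtype=None) -> torch.Tensor:
    # Mirrors the inlined lines: pair up the last dimension, apply the 2x2
    # rotation entries carried by rope_emb, then restore the original layout.
    t_shape = t.shape
    t = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)  # (..., D) -> (..., D/2, 1, 2)
    t = rope_emb[..., 0] * t[..., 0] + rope_emb[..., 1] * t[..., 1]    # broadcasts against a (..., D/2, 2, 2) table
    t = t.movedim(-1, -2).reshape(*t_shape)                            # back to (..., D)
    return t.to(out_dtype) if out_dtype is not None else t

if __name__ == "__main__":
    # Toy check with assumed shapes: 3 positions, head dim 8 (4 rotation pairs).
    q = torch.randn(3, 8)
    theta = torch.rand(3, 4)
    cos, sin = torch.cos(theta), torch.sin(theta)
    rope_emb = torch.stack([torch.stack([cos, -sin], dim=-1),
                            torch.stack([sin, cos], dim=-1)], dim=-2)  # (3, 4, 2, 2)
    print(rotate_with_rope(q, rope_emb, out_dtype=torch.float32).shape)  # torch.Size([3, 8])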
@@ -191,7 +195,10 @@ class Attention(nn.Module):
             context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
         """
         q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
-        return self.cal_attn(q, k, v, mask)
+        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+        del q, k, v
+        out = rearrange(out, " b n s c -> s b (n c)")
+        return self.to_out(out)
 
 
 class FeedForward(nn.Module):
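With cal_attn folded into forward, the explicit del q, k, v drops the last references to the query, key, and value tensors as soon as optimized_attention has consumed them, so their buffers can be reused before rearrange and to_out allocate new ones rather than all of them coexisting at the peak; the del inputs added further down follows the same reasoning. A small self-contained illustration of the effect is sketched below; the sizes and helper name are made up, and the reported numbers depend on the allocator's state, so treat it as a sketch rather than a benchmark.

import torch

def peak_mib(fn):
    # Illustrative helper: run fn on the GPU and report peak allocated memory in MiB.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 2**20

def keep_refs():
    q = torch.randn(1024, 1024, device="cuda")
    k = torch.randn(1024, 1024, device="cuda")
    v = torch.randn(1024, 1024, device="cuda")
    out = q @ k
    out = out @ v          # q and k are still referenced here, so their memory stays live
    return out

def drop_refs():
    q = torch.randn(1024, 1024, device="cuda")
    k = torch.randn(1024, 1024, device="cuda")
    v = torch.randn(1024, 1024, device="cuda")
    out = q @ k
    del q, k               # last references gone: those blocks can back the next allocation
    out = out @ v
    del v
    return out

if torch.cuda.is_available():
    print("peak, refs held:", peak_mib(keep_refs))   # typically higher
    print("peak, early del:", peak_mib(drop_refs))   # typically lower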
@@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module):
         crossattn_mask: Optional[torch.Tensor] = None,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_3D: Optional[torch.Tensor] = None,
-        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if extra_per_block_pos_emb is not None:
-            x = x + extra_per_block_pos_emb
         for block in self.blocks:
             x = block(
                 x,
@@ -168,7 +168,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-        self.build_pos_embed(device=device)
+        self.build_pos_embed(device=device, dtype=dtype)
         self.block_x_format = block_x_format
         self.use_adaln_lora = use_adaln_lora
         self.adaln_lora_dim = adaln_lora_dim
@@ -210,7 +210,7 @@ class GeneralDIT(nn.Module):
             operations=operations,
         )
 
-    def build_pos_embed(self, device=None):
+    def build_pos_embed(self, device=None, dtype=None):
         if self.pos_emb_cls == "rope3d":
             cls_type = VideoRopePosition3DEmb
         else:
@@ -242,6 +242,7 @@ class GeneralDIT(nn.Module):
             kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
             kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
             kwargs["device"] = device
+            kwargs["dtype"] = dtype
             self.extra_pos_embedder = LearnablePosEmbAxis(
                 **kwargs,
             )
@@ -476,6 +477,8 @@ class GeneralDIT(nn.Module):
             inputs["original_shape"],
         )
         extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+        del inputs
+
         if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
             assert (
                 x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +489,8 @@ class GeneralDIT(nn.Module):
                 self.blocks["block0"].x_format == block.x_format
             ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
 
+            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
             x = block(
                 x,
                 affline_emb_B_D,
@@ -493,7 +498,6 @@ class GeneralDIT(nn.Module):
                 crossattn_mask,
                 rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                 adaln_lora_B_3D=adaln_lora_B_3D,
-                extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             )
 
         x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
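Together, the three hunks above move the extra per-block positional embedding out of GeneralDITTransformerBlock: the block no longer accepts extra_per_block_pos_emb, and the caller adds it once per block iteration with an in-place x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D. The out-of-place x = x + ... the block used to perform allocates a fresh activation-sized tensor for every block, whereas += writes into the existing buffer, which is presumably acceptable here because this is an inference path with no autograd history to preserve. A tiny sketch of the difference, with made-up sizes:

import torch

x = torch.zeros(4, 8)
pos = torch.ones(4, 8)

# Out-of-place add: a fresh tensor is allocated and the name rebound (the old block-level pattern).
y = x + pos
print(y.data_ptr() == x.data_ptr())     # False: new storage

# In-place add: the existing buffer is updated (the pattern now used in the outer loop).
before = x.data_ptr()
x += pos
print(x.data_ptr() == before)           # True: same storage, no extra activation-sized allocation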
@@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         len_w: int,
         len_t: int,
         device=None,
+        dtype=None,
         **kwargs,
     ):
         """
@@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb):
         self.interpolation = interpolation
         assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
 
-        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
-        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
-        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
 
 
     def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
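The remaining hunks thread a dtype argument from the GeneralDIT constructor through build_pos_embed into LearnablePosEmbAxis, so the three learnable positional-embedding tables are created directly in the requested dtype instead of the float32 default, avoiding a larger parameter plus a later cast when the model runs in half precision. A minimal sketch of what that buys, using made-up table sizes and bfloat16 as an assumed target dtype:

import torch
import torch.nn as nn

len_h, model_channels = 128, 4096   # illustrative sizes only

table_fp32 = nn.Parameter(torch.empty(len_h, model_channels))                        # default float32
table_bf16 = nn.Parameter(torch.empty(len_h, model_channels, dtype=torch.bfloat16))  # allocated in target dtype

print(table_fp32.numel() * table_fp32.element_size())  # 4 bytes per element
print(table_bf16.numel() * table_bf16.element_size())  # 2 bytes per element: half the footprint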