Reuse code from flux model.
This commit is contained in:
parent
cce1d9145e
commit
f00f340a56
@ -8,26 +8,12 @@ from einops import repeat
|
|||||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from comfy.ldm.flux.math import apply_rope
|
from comfy.ldm.flux.math import apply_rope, rope
|
||||||
|
from comfy.ldm.flux.layers import LastLayer
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
|
||||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build 2x2 rotation matrices for rotary position embeddings.

    For each position in ``pos`` and each of the ``dim // 2`` frequencies
    ``1 / theta ** (2 * i / dim)``, the angle ``position * frequency`` is
    turned into the matrix ``[[cos, -sin], [sin, cos]]``.

    Args:
        pos: Two-dimensional tensor of positions; the first axis is used as
            the leading size of the result.
        dim: Embedding dimension to cover; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor viewed as ``(pos.shape[0], -1, dim // 2, 2, 2)``.

    Raises:
        AssertionError: If ``dim`` is odd.
    """
    assert dim % 2 == 0, "The dimension must be even."

    # Frequencies are computed in float64 and only cast down at the very end.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    inv_freq = 1.0 / (theta**exponents)

    # Two-element unpack: fails unless `pos` has exactly two dimensions.
    n_batch, _ = pos.shape

    angles = torch.einsum("...n,d->...nd", pos, inv_freq)
    cos, sin = torch.cos(angles), torch.sin(angles)

    # The trailing axis of four entries becomes a row-major 2x2 matrix.
    rotation = torch.stack([cos, -sin, sin, cos], dim=-1)
    return rotation.view(n_batch, -1, dim // 2, 2, 2).float()
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
@ -84,23 +70,6 @@ class TimestepEmbed(nn.Module):
|
|||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
|
|
||||||
class OutEmbed(nn.Module):
    """Output head: normalize, apply shift/scale modulation, then project.

    The shift and scale are produced from a conditioning vector by a
    SiLU + Linear stack and applied to the layer-normalized input before the
    final linear projection to ``patch_size * patch_size * out_channels``
    features.
    """

    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        """Create the norm, projection and modulation submodules.

        Args:
            hidden_size: Feature width of the incoming tokens.
            patch_size: Patch edge length; the projection emits
                ``patch_size * patch_size * out_channels`` features.
            out_channels: Channel count folded into the projection width.
            dtype: Forwarded to every layer built from ``operations``.
            device: Forwarded to every layer built from ``operations``.
            operations: Namespace providing ``LayerNorm`` and ``Linear``
                constructors.
        """
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        projected_width = patch_size * patch_size * out_channels

        # Attribute names and Sequential positions determine state-dict keys,
        # so they are kept exactly as they were.
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.linear = operations.Linear(hidden_size, projected_width, bias=True, **factory_kwargs)
        modulation_proj = operations.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), modulation_proj)

    def forward(self, x, adaln_input):
        """Modulate the normalized ``x`` with ``adaln_input`` and project it.

        Args:
            x: Input passed through ``norm_final``; the shift and scale are
                broadcast over its axis 1.
            adaln_input: Conditioning input; the modulation output is split
                in two along axis 1 into shift and scale.

        Returns:
            The result of ``linear`` applied to the modulated input.
        """
        modulation = self.adaLN_modulation(adaln_input)
        shift, scale = modulation.chunk(2, dim=1)
        # Insert an axis at position 1 so shift/scale broadcast across it.
        modulated = self.norm_final(x) * (1 + scale[:, None]) + shift[:, None]
        return self.linear(modulated)
|
|
||||||
|
|
||||||
|
|
||||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Collapse the last two axes of q/k/v and call ``optimized_attention``.

    Each input is viewed as ``(shape[0], -1, shape[-2] * shape[-1])`` before
    the call, and ``query.shape[2]`` is passed as the fourth positional
    argument.

    Args:
        query: Query tensor; its axis 2 size is forwarded to
            ``optimized_attention``.
        key: Key tensor, reshaped the same way as ``query``.
        value: Value tensor, reshaped the same way as ``query``.

    Returns:
        Whatever ``optimized_attention`` returns for the reshaped inputs.
    """

    def _merge_last_two(t: torch.Tensor) -> torch.Tensor:
        # Keep axis 0, fold the trailing two axes into one, infer the middle.
        return t.view(t.shape[0], -1, t.shape[-1] * t.shape[-2])

    # NOTE(review): presumably the head count, i.e. inputs laid out as
    # (batch, seq, heads, head_dim) - confirm against optimized_attention.
    axis2_size = query.shape[2]
    return optimized_attention(_merge_last_two(query), _merge_last_two(key), _merge_last_two(value), axis2_size)
|
||||||
|
|
||||||
@ -663,7 +632,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
|
||||||
caption_projection = []
|
caption_projection = []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user