From 8b65f5de54426f25cc7c08928332e5b7bf0fd25f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 22 Oct 2023 03:59:53 -0400
Subject: [PATCH] attention_basic now works with hypertile.

---
 comfy/ldm/modules/attention.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f8391e19..dcf46748 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -95,9 +95,19 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 def attention_basic(q, k, v, heads, mask=None):
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    scale = dim_head ** -0.5
+
     h = heads
-    scale = (q.shape[-1] // heads) ** -0.5
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )

     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
@@ -119,7 +129,12 @@ def attention_basic(q, k, v, heads, mask=None):
     sim = sim.softmax(dim=-1)

     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    out = (
+        out.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
     return out

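
Reviewer note (not part of the patch): the reshape/permute chains introduced above are intended as drop-in replacements for the einops rearrange calls they remove, keeping the same head split and merge while avoiding the pattern-string parsing. Below is a minimal, self-contained sketch that checks this equivalence; the shapes (b=2, n=77, heads=8, dim_head=64) are illustrative assumptions, not values taken from the patch.

# Hypothetical equivalence check, assuming torch and einops are installed.
import torch
from einops import rearrange

b, n, heads, dim_head = 2, 77, 8, 64
t = torch.randn(b, n, heads * dim_head)

# Old path: the rearrange call removed by the patch, 'b n (h d) -> (b h) n d'.
old = rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

# New path: the explicit reshape/permute chain added by the patch.
new = (
    t.unsqueeze(3)
    .reshape(b, -1, heads, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b * heads, -1, dim_head)
    .contiguous()
)
assert torch.equal(old, new)

# Inverse mapping applied to the attention output, '(b h) n d -> b n (h d)'.
out = torch.randn(b * heads, n, dim_head)
old_out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
new_out = (
    out.unsqueeze(0)
    .reshape(b, heads, -1, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b, -1, heads * dim_head)
)
assert torch.equal(old_out, new_out)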