From 3ad0191bfb7674486734db98769ab466f27e9362 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 6 Jan 2024 04:33:03 -0500
Subject: [PATCH] Implement attention mask on xformers.

---
 comfy/ldm/modules/attention.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3e12886b..14d41a8c 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -294,11 +294,14 @@ def attention_xformers(q, k, v, heads, mask=None):
         (q, k, v),
     )
 
-    # actually compute the attention, what we cannot get enough of
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    if mask is not None:
+        pad = 8 - q.shape[1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :mask.shape[-1]] = mask
+        mask = mask_out[:, :, :mask.shape[-1]]
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
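
For reference, a minimal standalone sketch of the padding trick the added hunk relies on: xformers' memory_efficient_attention expects the attn_bias rows to sit on an aligned stride, so the mask is copied into a buffer whose last dimension is rounded up to a multiple of 8 and then sliced back to the original width, leaving the padded storage underneath. The helper name pad_mask_for_xformers and the toy shapes below are illustrative assumptions, not part of the patch; the sketch only needs torch and skips the actual xformers call.

    import torch

    def pad_mask_for_xformers(mask, seq_len):
        # Mirror the patch: round the last dimension up to a multiple of 8.
        # (When seq_len is already a multiple of 8 this still pads by a full 8,
        # which wastes a little memory but is otherwise harmless.)
        pad = 8 - seq_len % 8
        padded = torch.empty(
            mask.shape[0], seq_len, seq_len + pad,
            dtype=mask.dtype, device=mask.device,
        )
        padded[:, :, :mask.shape[-1]] = mask
        # Slice back to the logical width; the storage stays padded, so the
        # row stride of the returned view remains a multiple of 8.
        return padded[:, :, :mask.shape[-1]]

    if __name__ == "__main__":
        b, heads, seq = 2, 8, 77              # 77 is not a multiple of 8
        bias = torch.zeros(b * heads, seq, seq)
        aligned = pad_mask_for_xformers(bias, seq)
        print(aligned.shape)                  # torch.Size([16, 77, 77])
        print(aligned.stride())               # (6160, 80, 1): row stride 80 is a multiple of 8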