From ac7d8cfa875e08623993da8109cc73a68df42379 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 11 Oct 2023 20:24:17 -0400
Subject: [PATCH] Allow attn_mask in attention_pytorch.

---
 comfy/ldm/modules/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3230cfaf..ac0d9c8c 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
         (q, k, v),
     )
 
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
     if exists(mask):
         raise NotImplementedError
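
Editorial note (not part of the commit): the patched line forwards mask straight into torch.nn.functional.scaled_dot_product_attention, whose attn_mask parameter accepts either a boolean mask (True means the position may be attended to) or an additive float mask (0.0 where allowed, -inf where blocked), broadcastable to (batch, heads, q_len, k_len). Because the trailing context above still shows an if exists(mask) guard after the call, the minimal sketch below exercises SDPA directly rather than the wrapper; all shapes and names here are illustrative, not taken from the repository.

# Sketch of the attn_mask semantics the patched call now relies on.
import torch
import torch.nn.functional as F

b, heads, seq_len, dim_head = 2, 8, 77, 64
q = torch.randn(b, heads, seq_len, dim_head)  # layout produced by the
k = torch.randn(b, heads, seq_len, dim_head)  # view/transpose reshape
v = torch.randn(b, heads, seq_len, dim_head)  # inside attention_pytorch

# Boolean mask: True marks key positions each query may attend to;
# it broadcasts over the batch and head dimensions.
bool_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Equivalent additive float mask: 0.0 where allowed, -inf where blocked.
float_mask = torch.zeros(seq_len, seq_len).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask, dropout_p=0.0, is_causal=False)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask, dropout_p=0.0, is_causal=False)
assert torch.allclose(out_bool, out_float, atol=1e-6)

Both mask forms yield the same output here; the additive float form is what diffusion pipelines typically construct when they need to pad or down-weight conditioning tokens rather than block them outright.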