diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 74a2fd99..2ce99d46 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -313,9 +313,19 @@ except: def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads + + disabled_xformers = False + if BROKEN_XFORMERS: if b * heads > 65535: - return attention_pytorch(q, k, v, heads, mask) + disabled_xformers = True + + if not disabled_xformers: + if torch.jit.is_tracing() or torch.jit.is_scripting(): + disabled_xformers = True + + if disabled_xformers: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( lambda t: t.reshape(b, -1, heads, dim_head),