From aaa9017302b75cf4453b5f8a58788e121f8e0a39 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 7 Jan 2024 04:13:58 -0500
Subject: [PATCH] Add attention mask support to sub quad attention.

---
 comfy/ldm/modules/attention.py               |  1 +
 comfy/ldm/modules/sub_quadratic_attention.py | 30 +++++++++++++++++---
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a18a6929..8015a307 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -177,6 +177,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
         kv_chunk_size_min=kv_chunk_size_min,
         use_checkpoint=False,
         upcast_attention=upcast_attention,
+        mask=mask,
     )

     hidden_states = hidden_states.to(dtype)
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 8e8e8054..cb0896b0 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -61,6 +61,7 @@ def _summarize_chunk(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> AttnChunk:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -84,6 +85,8 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     attn_weights -= max_score
+    if mask is not None:
+        attn_weights += mask
     torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
@@ -96,11 +99,12 @@ def _query_chunk_attention(
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
+    mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape

-    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
         key_chunk = dynamic_slice(
             key_t,
             (0, 0, chunk_idx),
@@ -111,10 +115,13 @@ def _query_chunk_attention(
             (0, chunk_idx, 0),
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
-        return summarize_chunk(query, key_chunk, value_chunk)
+        if mask is not None:
+            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+
+        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

     chunks: List[AttnChunk] = [
-        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
     chunk_values, chunk_weights, chunk_max = acc_chunk
@@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking(
     value: Tensor,
     scale: float,
     upcast_attention: bool,
+    mask,
 ) -> Tensor:
     if upcast_attention:
         with torch.autocast(enabled=False, device_type = 'cuda'):
@@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking(
             beta=0,
         )

+    if mask is not None:
+        attn_scores += mask
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
@@ -183,6 +193,7 @@ def efficient_dot_product_attention(
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
+    mask = None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
       This is efficient version of attention presented in
@@ -209,13 +220,22 @@ def efficient_dot_product_attention(
     if kv_chunk_size_min is not None:
         kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

+    if mask is not None and len(mask.shape) == 2:
+        mask = mask.unsqueeze(0)
+
     def get_query_chunk(chunk_idx: int) -> Tensor:
         return dynamic_slice(
             query,
             (0, chunk_idx, 0),
             (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
         )
-
+
+    def get_mask_chunk(chunk_idx: int) -> Tensor:
+        if mask is None:
+            return None
+        chunk = min(query_chunk_size, q_tokens)
+        return mask[:,chunk_idx:chunk_idx + chunk]
+
     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -237,6 +257,7 @@ def efficient_dot_product_attention(
             query=query,
             key_t=key_t,
             value=value,
+            mask=mask,
         )

     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
@@ -246,6 +267,7 @@ def efficient_dot_product_attention(
             query=get_query_chunk(i * query_chunk_size),
             key_t=key_t,
             value=value,
+            mask=get_mask_chunk(i * query_chunk_size)
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
     return res
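Usage note (not part of the patch): a minimal sketch of how the new mask argument can be exercised by calling efficient_dot_product_attention directly. The tensor layout and example sizes below are assumptions inferred from the diff (query/value as (batch*heads, tokens, channels), key passed pre-transposed, mask added to the scores before softmax); a 2-D (q_tokens, k_tokens) mask is accepted and unsqueezed to 3-D by the patch.

# Hypothetical usage sketch; shapes and sizes are assumptions, not part of the commit.
import torch
from comfy.ldm.modules.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, head_dim = 16, 4096, 77, 64
query = torch.randn(batch_x_heads, q_tokens, head_dim)
key_t = torch.randn(batch_x_heads, head_dim, k_tokens)  # key is passed pre-transposed
value = torch.randn(batch_x_heads, k_tokens, head_dim)

# Additive mask: 0.0 where attention is allowed, -inf (or a large negative
# value) where it is blocked, so masked keys get zero softmax weight.
mask = torch.zeros(batch_x_heads, q_tokens, k_tokens)
mask[..., 40:] = float("-inf")  # e.g. hide padded key tokens past index 39

# use_checkpoint=False mirrors what attention_sub_quad() passes in the patch.
out = efficient_dot_product_attention(query, key_t, value,
                                      use_checkpoint=False, mask=mask)
print(out.shape)  # torch.Size([16, 4096, 64])

With these sizes both mask code paths are hit: q_tokens exceeds the default query_chunk_size, so get_mask_chunk() slices the mask per query chunk, and k_tokens exceeds the derived kv_chunk_size, so _summarize_chunk() receives the per-kv-chunk slice.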