Implement attention mask on xformers.
parent af94eb14e3
commit 3ad0191bfb
@@ -294,11 +294,14 @@ def attention_xformers(q, k, v, heads, mask=None):
         (q, k, v),
     )
 
     # actually compute the attention, what we cannot get enough of
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    if mask is not None:
+        pad = 8 - q.shape[1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :mask.shape[-1]] = mask
+        mask = mask_out[:, :, :mask.shape[-1]]
 
-    if exists(mask):
-        raise NotImplementedError
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
     out = (
         out.unsqueeze(0)
         .reshape(b, heads, -1, dim_head)
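For context on the change: instead of raising NotImplementedError when a mask is supplied, the mask is now copied into a buffer whose last dimension is padded past the next multiple of 8, and a trimmed view of that buffer is passed as attn_bias to xformers.ops.memory_efficient_attention, presumably so the bias storage meets xformers' alignment expectations. Below is a minimal sketch of that padding pattern, not code from the commit; pad_attn_mask and the (batch * heads, q_len, kv_len) mask layout are assumptions made for illustration.

import torch
import xformers.ops

def pad_attn_mask(mask: torch.Tensor, q_len: int) -> torch.Tensor:
    # Allocate a buffer whose last dimension is padded past the next
    # multiple of 8, copy the mask into it, then return a view trimmed
    # back to the original width so the underlying storage stays aligned.
    pad = 8 - q_len % 8
    mask_out = torch.empty(
        [mask.shape[0], q_len, q_len + pad], dtype=mask.dtype, device=mask.device
    )
    mask_out[:, :, :mask.shape[-1]] = mask
    return mask_out[:, :, :mask.shape[-1]]

# Usage sketch, assuming q, k, v are already reshaped to (batch * heads, seq, dim_head):
# bias = pad_attn_mask(mask, q.shape[1])
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=bias)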