diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 45396d86..6f511338 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -19,6 +19,11 @@ from typing import Optional, NamedTuple, Protocol, List from torch import Tensor from typing import List +try: + OOM_EXCEPTION = torch.cuda.OutOfMemoryError +except: + OOM_EXCEPTION = Exception + def dynamic_slice( x: Tensor, starts: List[int], @@ -151,7 +156,7 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except torch.cuda.OutOfMemoryError: + except OOM_EXCEPTION: print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") torch.exp(attn_scores, out=attn_scores) summed = torch.sum(attn_scores, dim=-1, keepdim=True)