Disable xformers in VAE when xformers == 0.0.18
This commit is contained in:
parent
af291e6f69
commit
e46b1c3034
@ -9,7 +9,7 @@ from typing import Optional, Any
|
|||||||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||||
import model_management
|
import model_management
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled_vae():
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
|||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||||
if model_management.xformers_enabled() and attn_type == "vanilla":
|
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
||||||
attn_type = "vanilla-xformers"
|
attn_type = "vanilla-xformers"
|
||||||
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
|
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
|
||||||
attn_type = "vanilla-pytorch"
|
attn_type = "vanilla-pytorch"
|
||||||
|
@ -199,11 +199,25 @@ def get_autocast_device(dev):
|
|||||||
return dev.type
|
return dev.type
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
if vram_state == CPU:
|
if vram_state == CPU:
|
||||||
return False
|
return False
|
||||||
return XFORMERS_IS_AVAILBLE
|
return XFORMERS_IS_AVAILBLE
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_enabled_vae():
|
||||||
|
enabled = xformers_enabled()
|
||||||
|
if not enabled:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
#0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
|
||||||
|
if xformers.version.__version__ == "0.0.18":
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return enabled
|
||||||
|
|
||||||
def pytorch_attention_enabled():
|
def pytorch_attention_enabled():
|
||||||
return ENABLE_PYTORCH_ATTENTION
|
return ENABLE_PYTORCH_ATTENTION
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user