From e1e322cf69319d125680d791822d8f4733fea027 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 28 Dec 2023 21:41:10 -0500
Subject: [PATCH] Load weights that can't be lowvramed to target device.

---
 comfy/model_management.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3adc4270..c0cb4130 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -259,6 +259,14 @@ print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
+def module_size(module):
+    module_mem = 0
+    sd = module.state_dict()
+    for k in sd:
+        t = sd[k]
+        module_mem += t.nelement() * t.element_size()
+    return module_mem
+
 class LoadedModel:
     def __init__(self, model):
         self.model = model
@@ -296,14 +304,14 @@ class LoadedModel:
                 if hasattr(m, "comfy_cast_weights"):
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
-                    module_mem = 0
-                    sd = m.state_dict()
-                    for k in sd:
-                        t = sd[k]
-                        module_mem += t.nelement() * t.element_size()
+                    module_mem = module_size(m)
                     if mem_counter + module_mem < lowvram_model_memory:
                         m.to(self.device)
                         mem_counter += module_mem
+                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
+                    m.to(self.device)
+                    mem_counter += module_size(m)
+                    print("lowvram: loaded module regularly", m)
 
             self.model_accelerated = True
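
For context, a minimal standalone sketch (not part of the patch) of what the new module_size() helper measures: it sums t.nelement() * t.element_size() over every tensor in a module's state_dict, i.e. the byte footprint that the lowvram loading loop compares against lowvram_model_memory before deciding whether a module fits on the device. The nn.Linear example below is hypothetical and only assumes PyTorch is installed.

import torch.nn as nn

# Helper as added by the patch: total byte size of all tensors in a module's state_dict.
def module_size(module):
    module_mem = 0
    sd = module.state_dict()
    for k in sd:
        t = sd[k]
        module_mem += t.nelement() * t.element_size()
    return module_mem

# Hypothetical example module: fp32 Linear with 1024*1024 weight elements plus 1024 bias elements.
layer = nn.Linear(1024, 1024)
print(module_size(layer))  # 4198400 bytes (~4 MB), since each fp32 element is 4 bytes

The new elif branch in model_load follows the same accounting: a module that exposes a weight but has no comfy_cast_weights attribute cannot be kept in lowvram cast mode, so the patch moves it to the target device up front, adds its module_size() to mem_counter, and logs that it was loaded regularly.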