Better vace memory estimation. (#7875)
This commit is contained in:
parent
7ee96455e2
commit
dbc726f80c
@ -631,6 +631,7 @@ class VaceWanModel(WanModel):
|
|||||||
if ii is not None:
|
if ii is not None:
|
||||||
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
|
||||||
x += c_skip * vace_strength
|
x += c_skip * vace_strength
|
||||||
|
del c_skip
|
||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
@ -993,6 +993,10 @@ class WAN21_Vace(WAN21_T2V):
|
|||||||
"model_type": "vace",
|
"model_type": "vace",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
self.memory_usage_factor = 1.2 * self.memory_usage_factor
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
||||||
return out
|
return out
|
||||||
|
Loading…
x
Reference in New Issue
Block a user