Support dora_scale on both axis (#7727)
This commit is contained in:
parent
454a635c1b
commit
a8f63c0d5b
@ -43,6 +43,15 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
|||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
lora_diff *= alpha
|
lora_diff *= alpha
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
|
|
||||||
|
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
|
||||||
|
if wd_on_output_axis:
|
||||||
|
weight_norm = (
|
||||||
|
weight.reshape(weight.shape[0], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
|
||||||
|
)
|
||||||
|
else:
|
||||||
weight_norm = (
|
weight_norm = (
|
||||||
weight_calc.transpose(0, 1)
|
weight_calc.transpose(0, 1)
|
||||||
.reshape(weight_calc.shape[1], -1)
|
.reshape(weight_calc.shape[1], -1)
|
||||||
@ -50,6 +59,7 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
|||||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
.transpose(0, 1)
|
.transpose(0, 1)
|
||||||
)
|
)
|
||||||
|
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
|
||||||
|
|
||||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
if strength != 1.0:
|
if strength != 1.0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user