Support dora_scale on both axis (#7727)
This commit is contained in:
parent
454a635c1b
commit
a8f63c0d5b
@ -43,6 +43,15 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||
|
||||
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
|
||||
if wd_on_output_axis:
|
||||
weight_norm = (
|
||||
weight.reshape(weight.shape[0], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
|
||||
)
|
||||
else:
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
@ -50,6 +59,7 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
|
Loading…
x
Reference in New Issue
Block a user