Support dora_scale on both axis (#7727)

This commit is contained in:
Kohaku-Blueleaf 2025-04-22 17:01:27 +08:00 committed by GitHub
parent 454a635c1b
commit a8f63c0d5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -43,6 +43,15 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
if wd_on_output_axis:
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
)
else:
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
@@ -50,6 +59,7 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0: