From a8f63c0d5b40b4ed12faa1376e973b0e790b1c0d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 22 Apr 2025 17:01:27 +0800 Subject: [PATCH] Support dora_scale on both axis (#7727) --- comfy/weight_adapter/base.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 54af3bab..29873519 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -43,13 +43,23 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) + + wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0] + if wd_on_output_axis: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[0], *[1] * (weight.dim() - 1)) + ) + else: + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + weight_norm = weight_norm + torch.finfo(weight.dtype).eps weight_calc *= (dora_scale / weight_norm).type(weight.dtype) if strength != 1.0: