diff --git a/comfy/utils.py b/comfy/utils.py index f4c0ab41..294bbb42 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -318,7 +318,9 @@ def bislerp(samples, width, height): coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 - + + orig_dtype = samples.dtype + samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) @@ -347,7 +349,7 @@ def bislerp(samples, width, height): result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) - return result + return result.to(orig_dtype) def lanczos(samples, width, height): images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f65ef05e..48bcc689 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,6 +1,8 @@ import torch +import comfy.utils class PatchModelAddDownscale: + upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -9,13 +11,15 @@ class PatchModelAddDownscale: "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), "downscale_after_skip": ("BOOLEAN", {"default": True}), + "downscale_method": (s.upscale_methods,), + "upscale_method": (s.upscale_methods,), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) @@ -23,12 +27,12 @@ class PatchModelAddDownscale: if transformer_options["block"][1] == block_number: sigma = transformer_options["sigmas"][0].item() if sigma <= sigma_start and sigma >= sigma_end: - h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") return h def output_block_patch(h, hsp, transformer_options): if h.shape[2] != hsp.shape[2]: - h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") return h, hsp m = model.clone()