From e11052afcfd1e8aa3190d47c44f79b72a7a7a78a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Jun 2024 16:32:30 -0400 Subject: [PATCH] Add ipndm sampler. --- comfy/k_diffusion/sampling.py | 41 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 3 ++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 5bb991e7..21419926 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -841,3 +841,44 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non d_prime = w1 * d + w2 * d_2 + w3 * d_3 x = x + d_prime * dt return x + + +#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +#under Apache 2 license +def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + x_next = x + + buffer_model = [] + for i in trange(len(sigmas) - 1, disable=disable): + t_cur = sigmas[i] + t_next = sigmas[i + 1] + + x_cur = x_next + + denoised = model(x_cur, t_cur * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + d_cur = (x_cur - denoised) / t_cur + + order = min(max_order, i+1) + if order == 1: # First Euler step. + x_next = x_cur + (t_next - t_cur) * d_cur + elif order == 2: # Use one history point. + x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 + elif order == 3: # Use two history points. + x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 + elif order == 4: # Use three history points. + x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 + + if len(buffer_model) == max_order - 1: + for k in range(max_order - 2): + buffer_model[k] = buffer_model[k+1] + buffer_model[-1] = d_cur + else: + buffer_model.append(d_cur) + + return x_next diff --git a/comfy/samplers.py b/comfy/samplers.py index 656e0a28..4de377cc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -539,7 +539,8 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", + "ipndm"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}):