From 2dc84d14447782683862616eaf8c19c0c1feacf3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Jul 2024 04:06:03 -0400 Subject: [PATCH] Add a way to set the timestep multiplier in the flow sampling. --- comfy/model_sampling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 6bd3a5d7..2d95a83d 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -190,11 +190,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): else: sampling_settings = {} - self.set_parameters(shift=sampling_settings.get("shift", 1.0)) + self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000)) - def set_parameters(self, shift=1.0, timesteps=1000): + def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000): self.shift = shift - ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.multiplier = multiplier + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier) self.register_buffer('sigmas', ts) @property @@ -206,10 +207,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): return self.sigmas[-1] def timestep(self, sigma): - return sigma * 1000 + return sigma * self.multiplier def sigma(self, timestep): - return time_snr_shift(self.shift, timestep / 1000) + return time_snr_shift(self.shift, timestep / self.multiplier) def percent_to_sigma(self, percent): if percent <= 0.0: