diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index b79af1e9..7e729147 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module): self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end + self.zsnr = zsnr # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - if zsnr: + if self.zsnr: sigmas = rescale_zero_terminal_snr_sigmas(sigmas) self.set_sigmas(sigmas) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index ccf60115..78d28488 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: metadata["modelspec.predict_key"] = "v" + extra_keys["v_pred"] = torch.tensor([]) + if getattr(model_sampling, "zsnr", False): + extra_keys["ztsnr"] = torch.tensor([]) if not args.disable_metadata: metadata["prompt"] = prompt_info