Save v pred zsnr metadata (#7840)
This commit is contained in:
parent
cb9ac3db58
commit
30159a7fe6
@ -111,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
self.linear_start = linear_start
|
self.linear_start = linear_start
|
||||||
self.linear_end = linear_end
|
self.linear_end = linear_end
|
||||||
|
self.zsnr = zsnr
|
||||||
|
|
||||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||||
if zsnr:
|
if self.zsnr:
|
||||||
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
|
||||||
|
|
||||||
self.set_sigmas(sigmas)
|
self.set_sigmas(sigmas)
|
||||||
|
@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
|||||||
metadata["modelspec.predict_key"] = "epsilon"
|
metadata["modelspec.predict_key"] = "epsilon"
|
||||||
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
|
||||||
metadata["modelspec.predict_key"] = "v"
|
metadata["modelspec.predict_key"] = "v"
|
||||||
|
extra_keys["v_pred"] = torch.tensor([])
|
||||||
|
if getattr(model_sampling, "zsnr", False):
|
||||||
|
extra_keys["ztsnr"] = torch.tensor([])
|
||||||
|
|
||||||
if not args.disable_metadata:
|
if not args.disable_metadata:
|
||||||
metadata["prompt"] = prompt_info
|
metadata["prompt"] = prompt_info
|
||||||
|
Loading…
x
Reference in New Issue
Block a user