Small cleanup.

This commit is contained in:
comfyanonymous 2024-04-04 11:16:49 -04:00
parent f117566299
commit fcfd2bdf8a
2 changed files with 8 additions and 5 deletions

View File

@ -596,13 +596,16 @@ class CFGGuider:
self.original_conds = {} self.original_conds = {}
self.cfg = 1.0 self.cfg = 1.0
def set_conds(self, conds): def set_conds(self, positive, negative):
for k in conds: self.inner_set_conds({"positive": positive, "negative": negative})
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def set_cfg(self, cfg): def set_cfg(self, cfg):
self.cfg = cfg self.cfg = cfg
def inner_set_conds(self, conds):
for k in conds:
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs) return self.predict_noise(*args, **kwargs)
@ -646,7 +649,7 @@ class CFGGuider:
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model) cfg_guider = CFGGuider(model)
cfg_guider.set_conds({"positive": positive, "negative": negative}) cfg_guider.set_conds(positive, negative)
cfg_guider.set_cfg(cfg) cfg_guider.set_cfg(cfg)
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)

View File

@ -397,7 +397,7 @@ class CFGGuider:
def get_guider(self, model, positive, negative, cfg): def get_guider(self, model, positive, negative, cfg):
guider = comfy.samplers.CFGGuider(model) guider = comfy.samplers.CFGGuider(model)
guider.set_conds({"positive": positive, "negative": negative}) guider.set_conds(positive, negative)
guider.set_cfg(cfg) guider.set_cfg(cfg)
return (guider,) return (guider,)