diff --git a/comfy/samplers.py b/comfy/samplers.py index 22a9b68a..964febb2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -136,10 +136,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): out_cond = torch.zeros_like(x_in) - out_count = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 COND = 0 UNCOND = 1 @@ -239,9 +239,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod del out_count out_uncond /= out_uncond_count del out_uncond_count - - torch.nan_to_num(out_cond, nan=0.0, posinf=0.0, neginf=0.0, out=out_cond) #in case out_count or out_uncond_count had some zeros - torch.nan_to_num(out_uncond, nan=0.0, posinf=0.0, neginf=0.0, out=out_uncond) return out_cond, out_uncond