From d0165d819afe76bd4e6bdd710eb5f3e571b6a804 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 24 Dec 2023 07:06:59 -0500 Subject: [PATCH] Fix SVD lowvram mode. --- comfy/ldm/modules/diffusionmodules/util.py | 6 +++--- comfy/ldm/modules/temporal_ae.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 68175b62..ac7e2717 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -51,9 +51,9 @@ class AlphaBlender(nn.Module): if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor + alpha = self.mix_factor.to(image_only_indicator.device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor) + alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": @@ -61,7 +61,7 @@ class AlphaBlender(nn.Module): alpha = torch.where( image_only_indicator.bool(), torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 7ea68dc9..2992aeaf 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock): x = self.time_stack(x, temb) - alpha = self.get_alpha(bs=b // timesteps) + alpha = self.get_alpha(bs=b // timesteps).to(x.device) x = alpha * x + (1.0 - alpha) * x_mix x = rearrange(x, "b c t h w -> (b t) c h w") return x -class AE3DConv(torch.nn.Conv2d): +class AE3DConv(ops.Conv2d): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): super().__init__(in_channels, out_channels, *args, **kwargs) if isinstance(video_kernel_size, Iterable): @@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d): else: padding = int(video_kernel_size // 2) - self.time_mix_conv = torch.nn.Conv3d( + self.time_mix_conv = ops.Conv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=video_kernel_size, @@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock): emb = emb[:, None, :] x_mix = x_mix + emb - alpha = self.get_alpha() + alpha = self.get_alpha().to(x.device) x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge