From 5b4e3127494674e9a3f2e668e8fb49b278e079a9 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 8 Feb 2023 20:52:02 -0500
Subject: [PATCH] Use inplace operations for less OOM issues.

---
 comfy/ldm/modules/diffusionmodules/model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index b089eebb..03bc683c 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -96,6 +96,7 @@ class ResnetBlock(nn.Module):
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut

+        self.swish = torch.nn.SiLU(inplace=True)
         self.norm1 = Normalize(in_channels)
         self.conv1 = torch.nn.Conv2d(in_channels,
                                      out_channels,
@@ -106,7 +107,7 @@ class ResnetBlock(nn.Module):
             self.temb_proj = torch.nn.Linear(temb_channels,
                                              out_channels)
         self.norm2 = Normalize(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
+        self.dropout = torch.nn.Dropout(dropout, inplace=True)
         self.conv2 = torch.nn.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
@@ -129,14 +130,14 @@ class ResnetBlock(nn.Module):
     def forward(self, x, temb):
         h = x
         h = self.norm1(h)
-        h = nonlinearity(h)
+        h = self.swish(h)
         h = self.conv1(h)

         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(self.swish(temb))[:,:,None,None]

         h = self.norm2(h)
-        h = nonlinearity(h)
+        h = self.swish(h)
         h = self.dropout(h)
         h = self.conv2(h)
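
The patch above replaces the module-level nonlinearity() helper with
torch.nn.SiLU(inplace=True) and makes the Dropout inplace, so those ops reuse
the input tensor's storage instead of allocating a fresh activation map of the
same size. A minimal sketch of that effect, assuming only stock PyTorch and an
arbitrary example tensor shape (not taken from the patch):

    import torch

    x = torch.randn(1, 512, 64, 64)             # example activation map

    silu_copy = torch.nn.SiLU(inplace=False)    # allocates a new output tensor
    silu_inplace = torch.nn.SiLU(inplace=True)  # overwrites the input's storage

    y = silu_copy(x)        # x and y are separate buffers (double the memory)
    z = silu_inplace(x)     # z shares x's memory; no extra allocation
    assert z.data_ptr() == x.data_ptr()

Inplace swaps like this are generally safe when the overwritten activations are
not read again (e.g. inference under torch.no_grad()); with autograd enabled,
an inplace op can raise an error if its input is needed for the backward pass.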