diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index b089eebb..03bc683c 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -96,6 +96,7 @@ class ResnetBlock(nn.Module): self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut + self.swish = torch.nn.SiLU(inplace=True) self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, @@ -106,7 +107,7 @@ class ResnetBlock(nn.Module): self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) + self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, @@ -129,14 +130,14 @@ class ResnetBlock(nn.Module): def forward(self, x, temb): h = x h = self.norm1(h) - h = nonlinearity(h) + h = self.swish(h) h = self.conv1(h) if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) - h = nonlinearity(h) + h = self.swish(h) h = self.dropout(h) h = self.conv2(h)