diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 00000000..f4d52aa1 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 00000000..d7a12969 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + + +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 00000000..477e81f9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." + block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 00000000..dec16952 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. / n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 00000000..f9ce7f7a --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. + """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 00000000..731a25f7 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 00000000..4260fb7c --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. + """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05d..b50db7c2 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc..d7bc5d22 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d..2e66e624 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47..1906c0c7 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ]