Squash depreciation warning on new pytorch.
This commit is contained in:
parent
ca9d300a80
commit
8ddc151a4c
@ -75,10 +75,16 @@ class SnakeBeta(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def WNConv1d(*args, **kwargs):
|
def WNConv1d(*args, **kwargs):
|
||||||
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))
|
try:
|
||||||
|
return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
||||||
|
except:
|
||||||
|
return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||||
|
|
||||||
def WNConvTranspose1d(*args, **kwargs):
|
def WNConvTranspose1d(*args, **kwargs):
|
||||||
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
try:
|
||||||
|
return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
||||||
|
except:
|
||||||
|
return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
|
||||||
|
|
||||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
if activation == "elu":
|
if activation == "elu":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user