diff --git a/comfy/ldm/audio/autoencoder.py b/comfy/ldm/audio/autoencoder.py index 7363131e..8123e66a 100644 --- a/comfy/ldm/audio/autoencoder.py +++ b/comfy/ldm/audio/autoencoder.py @@ -75,10 +75,16 @@ class SnakeBeta(nn.Module): return x def WNConv1d(*args, **kwargs): - return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) + try: + return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs)) + except: + return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older def WNConvTranspose1d(*args, **kwargs): - return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) + try: + return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) + except: + return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: if activation == "elu":