Fix control loras breaking.
This commit is contained in:
parent
db8b59ecff
commit
448d9263a2
@ -201,7 +201,7 @@ class ControlNet(ControlBase):
|
|||||||
super().cleanup()
|
super().cleanup()
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module):
|
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
device=None, dtype=None) -> None:
|
device=None, dtype=None) -> None:
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
@ -220,7 +220,7 @@ class ControlLoraOps:
|
|||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
40
comfy/ops.py
40
comfy/ops.py
@ -31,13 +31,13 @@ def cast_bias_weight(s, input):
|
|||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
|
class CastWeightBiasOp:
|
||||||
|
comfy_cast_weights = False
|
||||||
|
weight_function = None
|
||||||
|
bias_function = None
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
class Linear(torch.nn.Linear):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -51,11 +51,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Conv2d):
|
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -69,11 +65,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Conv3d(torch.nn.Conv3d):
|
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -87,11 +79,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class GroupNorm(torch.nn.GroupNorm):
|
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -106,11 +94,7 @@ class disable_weight_init:
|
|||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm):
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -128,11 +112,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||||
comfy_cast_weights = False
|
|
||||||
weight_function = None
|
|
||||||
bias_function = None
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user