Fix HunyuanDiT text encoder weights always being loaded in fp32.
This commit is contained in:
parent
2c038ccef0
commit
a5991a7aa6
@ -52,8 +52,8 @@ class HyditTokenizer:
|
|||||||
class HyditModel(torch.nn.Module):
|
class HyditModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hydit_clip = HyditBertModel()
|
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||||
self.mt5xl = MT5XLModel()
|
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||||
|
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user