diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index b6491090..a1420c6c 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -236,4 +236,6 @@ class T5(torch.nn.Module): def forward(self, input_ids, *args, **kwargs): x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) + if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + x = torch.nan_to_num(x) #Fix for fp8 T5 base return self.encoder(x, *args, **kwargs)