From c24f897352238f040e162a81d253c290635d44fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 Jul 2024 02:00:19 -0400 Subject: [PATCH] Fix to get fp8 working on T5 base. --- comfy/text_encoders/t5.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index b6491090..a1420c6c 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -236,4 +236,6 @@ class T5(torch.nn.Module): def forward(self, input_ids, *args, **kwargs): x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) + if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + x = torch.nan_to_num(x) #Fix for fp8 T5 base return self.encoder(x, *args, **kwargs)