diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 85b017e0..a95616f1 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -34,11 +34,8 @@ class ClipVisionModel():
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
-        self.dtype = torch.float32
-        if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
-            self.dtype = torch.float16
-
-        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init)
+        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+        self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
         self.model.eval()

         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -47,15 +44,8 @@ class ClipVisionModel():
     def encode_image(self, image):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device))
-
-        if self.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-        with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
-            out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        pixel_values = clip_preprocess(image.to(self.load_device)).float()
+        out = self.model(pixel_values=pixel_values, intermediate_output=-2)

         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
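
For context on why the torch.autocast wrapper can be dropped: switching the ops class to comfy.ops.manual_cast means each layer casts its stored weights to the activation dtype at forward time, so fp32 pixel_values can flow through a model whose weights live in fp16. The following is a minimal sketch of that manual-cast idea in plain PyTorch; ManualCastLinear is a hypothetical name for illustration, not ComfyUI's actual implementation.

    # Sketch of the "manual cast" technique enabled by comfy.ops.manual_cast:
    # instead of wrapping the forward pass in torch.autocast, each layer casts
    # its (possibly fp16) weights to the input's dtype on the fly.
    import torch
    import torch.nn as nn

    class ManualCastLinear(nn.Linear):
        def forward(self, x):
            # Cast weight/bias to match the incoming activation's dtype and
            # device, so an fp32 input works with fp16-stored parameters.
            weight = self.weight.to(dtype=x.dtype, device=x.device)
            bias = None
            if self.bias is not None:
                bias = self.bias.to(dtype=x.dtype, device=x.device)
            return nn.functional.linear(x, weight, bias)

    # Usage: a layer stored in fp16 still accepts the fp32 pixel_values
    # produced by clip_preprocess(...).float() in the diff above.
    layer = ManualCastLinear(4, 4).half()
    out = layer(torch.randn(1, 4))  # fp32 input
    print(out.dtype)                # torch.float32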