pop clip vision keys after loading them.
This commit is contained in:
parent
c9e4a8c9e5
commit
cd930d4e7f
@ -21,7 +21,7 @@ class ClipVisionModel():
|
|||||||
size=224)
|
size=224)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
self.model.load_state_dict(sd, strict=False)
|
return self.model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
def encode_image(self, image):
|
def encode_image(self, image):
|
||||||
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
img = torch.clip((255. * image[0]), 0, 255).round().int()
|
||||||
@ -59,7 +59,13 @@ def load_clipvision_from_sd(sd):
|
|||||||
else:
|
else:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
clip.load_sd(sd)
|
m, u = clip.load_sd(sd)
|
||||||
|
u = set(u)
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
if k not in u:
|
||||||
|
t = sd.pop(k)
|
||||||
|
del t
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load(ckpt_path):
|
def load(ckpt_path):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user