Support for text encoder models that need attention_mask.
parent 0d8f376446
commit 44361f6344
@@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.empty_tokens = [[49406] + [49407] * 76]
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+        self.enable_attention_masks = False
 
         self.layer_norm_hidden_state = True
         if layer == "hidden":
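This hunk only adds the enable_attention_masks flag, defaulting to False so existing CLIP-style encoders behave exactly as before; an encoder that needs masking opts in by setting it. A minimal sketch of opting in, using a hypothetical subclass (the name is illustrative and not part of this commit):

# Hypothetical wrapper for a text encoder that requires attention_mask.
class MaskedTextEncoderModel(SD1ClipModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Opt in: padding after the end token will be masked out during encoding.
        self.enable_attention_masks = True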
@@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             precision_scope = lambda a, b: contextlib.nullcontext(a)
 
         with precision_scope(model_management.get_autocast_device(device), torch.float32):
-            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            attention_mask = None
+            if self.enable_attention_masks:
+                attention_mask = torch.zeros_like(tokens)
+                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+                for x in range(attention_mask.shape[0]):
+                    for y in range(attention_mask.shape[1]):
+                        attention_mask[x, y] = 1
+                        if tokens[x, y] == max_token:
+                            break
+
+            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
 
             if self.layer == "last":
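The added block builds the mask row by row: each position is marked 1, and once the first max_token (the last index of the embedding table, i.e. the end/pad token used to fill empty_tokens above) is reached, the rest of the row stays 0, so the transformer ignores padding past the end of the prompt. When enable_attention_masks is False, attention_mask stays None and the transformer call is effectively unchanged. A self-contained sketch of the same masking logic on a toy batch (token values chosen only for illustration):

import torch

def build_attention_mask(tokens: torch.Tensor, max_token: int) -> torch.Tensor:
    # Mirror of the loop above: attend (1) up to and including the first
    # max_token in each row, leave everything after it masked (0).
    attention_mask = torch.zeros_like(tokens)
    for x in range(attention_mask.shape[0]):
        for y in range(attention_mask.shape[1]):
            attention_mask[x, y] = 1
            if tokens[x, y] == max_token:
                break
    return attention_mask

# Toy prompt: 49406 as the start token and 49407 as the end/pad token, matching
# the empty_tokens layout above, so max_token works out to 49407 here.
tokens = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407]])
print(build_attention_mask(tokens, 49407))
# tensor([[1, 1, 1, 1, 0, 0]])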