Allow zeroing out of embeds with unused attention mask.
This commit is contained in:
parent
b4c2d03d47
commit
ce649d61c0
@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens = torch.LongTensor(tokens).to(device)
|
tokens = torch.LongTensor(tokens).to(device)
|
||||||
|
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
if self.enable_attention_masks:
|
if self.enable_attention_masks or self.zero_out_masked:
|
||||||
attention_mask = torch.zeros_like(tokens)
|
attention_mask = torch.zeros_like(tokens)
|
||||||
end_token = self.special_tokens.get("end", -1)
|
end_token = self.special_tokens.get("end", -1)
|
||||||
for x in range(attention_mask.shape[0]):
|
for x in range(attention_mask.shape[0]):
|
||||||
@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
if tokens[x, y] == end_token:
|
if tokens[x, y] == end_token:
|
||||||
break
|
break
|
||||||
|
|
||||||
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
attention_mask_model = None
|
||||||
|
if self.enable_attention_masks:
|
||||||
|
attention_mask_model = attention_mask
|
||||||
|
|
||||||
|
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||||
self.transformer.set_input_embeddings(backup_embeds)
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
else:
|
else:
|
||||||
z = outputs[1].float()
|
z = outputs[1].float()
|
||||||
|
|
||||||
if self.zero_out_masked and attention_mask is not None:
|
if self.zero_out_masked:
|
||||||
z *= attention_mask.unsqueeze(-1).float()
|
z *= attention_mask.unsqueeze(-1).float()
|
||||||
|
|
||||||
pooled_output = None
|
pooled_output = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user