Add support for new colour T2I adapter model.
This commit is contained in:
parent
9d00235b41
commit
16130c7546
@ -612,6 +612,15 @@ class T2IAdapter:
|
|||||||
|
|
||||||
def load_t2i_adapter(ckpt_path, model=None):
|
def load_t2i_adapter(ckpt_path, model=None):
|
||||||
t2i_data = load_torch_file(ckpt_path)
|
t2i_data = load_torch_file(ckpt_path)
|
||||||
|
keys = t2i_data.keys()
|
||||||
|
if "style_embedding" in keys:
|
||||||
|
pass
|
||||||
|
# TODO
|
||||||
|
# model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||||
|
elif "body.0.in_conv.weight" in keys:
|
||||||
|
cin = t2i_data['body.0.in_conv.weight'].shape[1]
|
||||||
|
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
|
||||||
|
else:
|
||||||
cin = t2i_data['conv_in.weight'].shape[1]
|
cin = t2i_data['conv_in.weight'].shape[1]
|
||||||
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
|
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
|
||||||
model_ad.load_state_dict(t2i_data)
|
model_ad.load_state_dict(t2i_data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user