Skip to content

Commit ffbe1ab

Browse files
committed
Merge branch 'feature_textual_inversion_mgr' into dev/diffusers
2 parents 6e4dad6 + 023df37 commit ffbe1ab

File tree

3 files changed

+724
-44
lines changed

3 files changed

+724
-44
lines changed

backend/invoke_ai_web_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,8 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
11001100
get_tokens_for_prompt(self.generate.model, parsed_prompt)
11011101
attention_maps_image_base64_url = None if attention_maps_image is None \
11021102
else image_to_dataURL(attention_maps_image)
1103+
if attention_maps_image is not None:
1104+
attention_maps_image.save(path + '.attention.png', 'PNG')
11031105

11041106
self.socketio.emit(
11051107
"generationResult",

ldm/modules/embedding_manager.py

Lines changed: 183 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os.path
22
from cmath import log
33
import torch
4+
from attr import dataclass
45
from torch import nn
56

67
import sys
@@ -14,35 +15,173 @@
1415
PROGRESSIVE_SCALE = 2000
1516

1617

17-
def get_clip_token_for_string(tokenizer, string):
18-
batch_encoding = tokenizer(
19-
string,
20-
truncation=True,
21-
max_length=77,
22-
return_length=True,
23-
return_overflowing_tokens=False,
24-
padding='max_length',
25-
return_tensors='pt',
26-
)
27-
tokens = batch_encoding['input_ids']
28-
""" assert (
29-
torch.count_nonzero(tokens - 49407) == 2
30-
), f"String '{string}' maps to more than a single token. Please use another string" """
18+
def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
19+
token_id = tokenizer.convert_tokens_to_ids(token_str)
20+
return token_id
3121

32-
return tokens[0, 1]
33-
34-
35-
def get_bert_token_for_string(tokenizer, string):
22+
def get_bert_token_id_for_string(tokenizer, string) -> int:
3623
token = tokenizer(string)
3724
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
38-
3925
token = token[0, 1]
26+
return token.item()
27+
28+
29+
def get_embedding_for_clip_token_id(embedder, token_id):
30+
if type(token_id) is not torch.Tensor:
31+
token_id = torch.tensor(token_id, dtype=torch.int)
32+
return embedder(token_id.unsqueeze(0))[0, 0]
33+
34+
@dataclass
35+
class TextualInversion:
36+
trigger_string: str
37+
token_id: int
38+
embedding: torch.Tensor
39+
40+
@property
41+
def embedding_vector_length(self) -> int:
42+
return self.embedding.shape[0]
43+
44+
class TextualInversionManager():
45+
def __init__(self, clip_embedder):
46+
self.clip_embedder = clip_embedder
47+
default_textual_inversions: list[TextualInversion] = []
48+
self.textual_inversions = default_textual_inversions
49+
50+
def load_textual_inversion(self, ckpt_path, full_precision=True):
51+
52+
scan_result = scan_file_path(ckpt_path)
53+
if scan_result.infected_files == 1:
54+
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
55+
print('### For your safety, InvokeAI will not load this embed.')
56+
return
57+
58+
ckpt = torch.load(ckpt_path, map_location='cpu')
59+
60+
# Handle .pt textual inversion files
61+
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
62+
filename = os.path.basename(ckpt_path)
63+
token_str = '.'.join(filename.split('.')[:-1]) # filename excluding extension
64+
if len(ckpt["string_to_token"]) > 1:
65+
print(f">> {ckpt_path} has >1 embedding, only the first will be used")
4066

41-
return token
67+
string_to_param_dict = ckpt['string_to_param']
68+
embedding = list(string_to_param_dict.values())[0]
69+
self.add_textual_inversion(token_str, embedding, full_precision)
4270

71+
# Handle .bin textual inversion files from Huggingface Concepts
72+
# https://huggingface.co/sd-concepts-library
73+
else:
74+
for token_str in list(ckpt.keys()):
75+
embedding = ckpt[token_str]
76+
self.add_textual_inversion(token_str, embedding, full_precision)
77+
78+
def add_textual_inversion(self, token_str, embedding) -> int:
79+
"""
80+
Add a textual inversion to be recognised.
81+
:param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
82+
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
83+
:return: The token id for the added embedding, either existing or newly-added.
84+
"""
85+
if token_str in [ti.trigger_string for ti in self.textual_inversions]:
86+
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
87+
return
88+
if len(embedding.shape) == 1:
89+
embedding = embedding.unsqueeze(0)
90+
elif len(embedding.shape) > 2:
91+
raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
92+
93+
existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
94+
if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
95+
num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
96+
current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
97+
current_token_count = current_embeddings.num_embeddings
98+
new_token_count = current_token_count + num_tokens_added
99+
self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
100+
101+
token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
102+
self.textual_inversions.append(TextualInversion(
103+
trigger_string=token_str,
104+
token_id=token_id,
105+
embedding=embedding
106+
))
107+
return token_id
108+
109+
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
110+
try:
111+
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
112+
return ti is not None
113+
except StopIteration:
114+
return False
115+
116+
def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
117+
return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
118+
119+
120+
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
121+
return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
122+
123+
def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
124+
"""
125+
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
126+
127+
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
128+
:param pad_token_id: The token id to use to pad out the list to account for textual inversion vector lengths >1.
129+
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
130+
long - caller is reponsible for truncating it if necessary and prepending/appending eos and bos token ids.
131+
"""
132+
if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
133+
raise ValueError("prompt_token_ids must not start with bos_token_id")
134+
if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
135+
raise ValueError("prompt_token_ids must not end with eos_token_id")
136+
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
137+
prompt_token_ids = prompt_token_ids[:]
138+
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
139+
if token_id in textual_inversion_token_ids:
140+
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
141+
for pad_idx in range(1, textual_inversion.embedding_vector_length):
142+
prompt_token_ids.insert(i+1, self.clip_embedder.tokenizer.pad_token_id)
143+
144+
return prompt_token_ids
145+
146+
def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
147+
"""
148+
For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
149+
row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
150+
subsequent rows in `prompt_embeddings` as well.
151+
152+
:param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector lenght
153+
>1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
154+
:param `prompt_embeddings`: Prompt embeddings tensor of shape with indices aligning to token ids in
155+
`prompt_token_ids` (i.e., also already expanded).
156+
:return: `The prompt_embeddings` tensor overwritten as appropriate with the textual inversion embeddings.
157+
"""
158+
if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
159+
raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
160+
if len(prompt_token_ids) > self.clip_embedder.max_length:
161+
raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
162+
if len(prompt_token_ids) < self.clip_embedder.max_length:
163+
raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
164+
if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
165+
raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")
166+
167+
textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
168+
pad_token_id = self.clip_embedder.tokenizer.pad_token_id
169+
overwritten_prompt_embeddings = prompt_embeddings.clone()
170+
for i, token_id in enumerate(prompt_token_ids):
171+
if token_id == pad_token_id:
172+
continue
173+
if token_id in textual_inversion_token_ids:
174+
textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
175+
end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length-1)
176+
count_to_overwrite = end_index - i
177+
for j in range(0, count_to_overwrite):
178+
# only overwrite the textual inversion token id or the padding token id
179+
if prompt_token_ids[i+j] != pad_token_id and prompt_token_ids[i+j] != token_id:
180+
break
181+
overwritten_prompt_embeddings[i+j] = textual_inversion.embedding[j]
182+
183+
return overwritten_prompt_embeddings
43184

44-
def get_embedding_for_clip_token(embedder, token):
45-
return embedder(token.unsqueeze(0))[0, 0]
46185

47186
class EmbeddingManager(nn.Module):
48187
def __init__(
@@ -77,38 +216,36 @@ def __init__(
77216
embedder, 'tokenizer'
78217
): # using Stable Diffusion's CLIP encoder
79218
self.is_clip = True
80-
get_token_for_string = partial(
81-
get_clip_token_for_string, embedder.tokenizer
219+
get_token_id_for_string = partial(
220+
get_clip_token_id_for_string, embedder.tokenizer
82221
)
83-
get_embedding_for_tkn = partial(
84-
get_embedding_for_clip_token,
222+
get_embedding_for_tkn_id = partial(
223+
get_embedding_for_clip_token_id,
85224
embedder.transformer.text_model.embeddings,
86225
)
87226
# per bug report #572
88227
#token_dim = 1280
89228
token_dim = 768
90229
else: # using LDM's BERT encoder
91230
self.is_clip = False
92-
get_token_for_string = partial(
93-
get_bert_token_for_string, embedder.tknz_fn
231+
get_token_id_for_string = partial(
232+
get_bert_token_id_for_string, embedder.tknz_fn
94233
)
95-
get_embedding_for_tkn = embedder.transformer.token_emb
234+
get_embedding_for_tkn_id = embedder.transformer.token_emb
96235
token_dim = 1280
97236

98237
if per_image_tokens:
99238
placeholder_strings.extend(per_img_token_list)
100239

101240
for idx, placeholder_string in enumerate(placeholder_strings):
102241

103-
token = get_token_for_string(placeholder_string)
242+
token_id = get_token_id_for_string(placeholder_string)
104243

105244
if initializer_words and idx < len(initializer_words):
106-
init_word_token = get_token_for_string(initializer_words[idx])
245+
init_word_token_id = get_token_id_for_string(initializer_words[idx])
107246

108247
with torch.no_grad():
109-
init_word_embedding = get_embedding_for_tkn(
110-
init_word_token.cpu()
111-
)
248+
init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
112249

113250
token_params = torch.nn.Parameter(
114251
init_word_embedding.unsqueeze(0).repeat(
@@ -132,7 +269,7 @@ def __init__(
132269
)
133270
)
134271

135-
self.string_to_token_dict[placeholder_string] = token
272+
self.string_to_token_dict[placeholder_string] = token_id
136273
self.string_to_param_dict[placeholder_string] = token_params
137274

138275
def forward(
@@ -241,7 +378,7 @@ def load(self, ckpt_paths, full=True):
241378
# both will be stored in this dictionary
242379
for term in self.string_to_param_dict.keys():
243380
term = term.strip('<').strip('>')
244-
self.concepts_loaded[term] = True
381+
self.concepts_loaded[term] = True
245382
print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}')
246383

247384
def _expand_directories(self, paths:list[str]):
@@ -262,7 +399,7 @@ def _load(self, ckpt_path, full=True):
262399
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
263400
print('### For your safety, InvokeAI will not load this embed.')
264401
return
265-
402+
266403
ckpt = torch.load(ckpt_path, map_location='cpu')
267404

268405
# Handle .pt textual inversion files
@@ -292,14 +429,16 @@ def add_embedding(self, token_str, embedding, full):
292429
if len(embedding.shape) == 1:
293430
embedding = embedding.unsqueeze(0)
294431

295-
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
296-
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
297-
current_token_count = current_embeddings.num_embeddings
298-
new_token_count = current_token_count + num_tokens_added
299-
self.embedder.transformer.resize_token_embeddings(new_token_count)
432+
existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
433+
if existing_token_id == self.embedder.tokenizer.unk_token_id:
434+
num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
435+
current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
436+
current_token_count = current_embeddings.num_embeddings
437+
new_token_count = current_token_count + num_tokens_added
438+
self.embedder.transformer.resize_token_embeddings(new_token_count)
300439

301-
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
302-
self.string_to_token_dict[token_str] = token
440+
token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
441+
self.string_to_token_dict[token_str] = token_id
303442
self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)
304443

305444
def has_embedding_for_token(self, token_str):

0 commit comments

Comments
 (0)