import os.path
from cmath import log
import torch
+ from attr import dataclass
from torch import nn

import sys
PROGRESSIVE_SCALE = 2000


- def get_clip_token_for_string(tokenizer, string):
-     batch_encoding = tokenizer(
-         string,
-         truncation=True,
-         max_length=77,
-         return_length=True,
-         return_overflowing_tokens=False,
-         padding='max_length',
-         return_tensors='pt',
-     )
-     tokens = batch_encoding['input_ids']
-     """ assert (
-         torch.count_nonzero(tokens - 49407) == 2
-     ), f"String '{string}' maps to more than a single token. Please use another string" """
+ def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
+     token_id = tokenizer.convert_tokens_to_ids(token_str)
+     return token_id

-     return tokens[0, 1]
-
-
- def get_bert_token_for_string(tokenizer, string):
+ def get_bert_token_id_for_string(tokenizer, string) -> int:
    token = tokenizer(string)
    # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
-
    token = token[0, 1]
+     return token.item()
+
+
+ def get_embedding_for_clip_token_id(embedder, token_id):
+     if type(token_id) is not torch.Tensor:
+         token_id = torch.tensor(token_id, dtype=torch.int)
+     return embedder(token_id.unsqueeze(0))[0, 0]
+
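+ # A TextualInversion bundles a trigger string with its tokenizer token id and the
+ # embedding tensor to substitute wherever that token id appears in a prompt.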
+ @dataclass
+ class TextualInversion:
+     trigger_string: str
+     token_id: int
+     embedding: torch.Tensor
+
+     @property
+     def embedding_vector_length(self) -> int:
+         return self.embedding.shape[0]
+
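+ # Keeps track of loaded textual inversions and patches their embeddings into the
+ # CLIP conditioning produced for prompts that contain their trigger strings.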
+ class TextualInversionManager():
+     def __init__(self, clip_embedder):
+         self.clip_embedder = clip_embedder
+         default_textual_inversions: list[TextualInversion] = []
+         self.textual_inversions = default_textual_inversions
+
+     def load_textual_inversion(self, ckpt_path, full_precision=True):
+
+         scan_result = scan_file_path(ckpt_path)
+         if scan_result.infected_files == 1:
+             print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
+             print('### For your safety, InvokeAI will not load this embed.')
+             return
+
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+
+         # Handle .pt textual inversion files
+         if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
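+             # A .pt embedding stores 'string_to_token'/'string_to_param' dicts; the trigger
+             # string is taken from the filename (minus extension) rather than the stored key.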
+             filename = os.path.basename(ckpt_path)
+             token_str = '.'.join(filename.split('.')[:-1])  # filename excluding extension
+             if len(ckpt["string_to_token"]) > 1:
+                 print(f">> {ckpt_path} has >1 embedding, only the first will be used")

-     return token
+             string_to_param_dict = ckpt['string_to_param']
+             embedding = list(string_to_param_dict.values())[0]
+             self.add_textual_inversion(token_str, embedding, full_precision)

+         # Handle .bin textual inversion files from Huggingface Concepts
+         # https://huggingface.co/sd-concepts-library
+         else:
+             for token_str in list(ckpt.keys()):
+                 embedding = ckpt[token_str]
+                 self.add_textual_inversion(token_str, embedding, full_precision)
+
+     def add_textual_inversion(self, token_str, embedding, full_precision=True) -> int:
+         """
+         Add a textual inversion to be recognised.
+         :param token_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
+         :param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
+         :return: The token id for the added embedding, either existing or newly-added.
+         """
+         if token_str in [ti.trigger_string for ti in self.textual_inversions]:
+             print(f">> TextualInversionManager refusing to overwrite already-loaded token '{token_str}'")
+             return
+         if len(embedding.shape) == 1:
+             embedding = embedding.unsqueeze(0)
+         elif len(embedding.shape) > 2:
+             raise ValueError(f"embedding shape {embedding.shape} is incorrect - must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2")
+
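+         # Register the trigger with the tokenizer only if it is not already a known token,
+         # then grow the text encoder's embedding matrix so the new id has a row of its own.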
+         existing_token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
+         if existing_token_id == self.clip_embedder.tokenizer.unk_token_id:
+             num_tokens_added = self.clip_embedder.tokenizer.add_tokens(token_str)
+             current_embeddings = self.clip_embedder.transformer.resize_token_embeddings(None)
+             current_token_count = current_embeddings.num_embeddings
+             new_token_count = current_token_count + num_tokens_added
+             self.clip_embedder.transformer.resize_token_embeddings(new_token_count)
+
+         token_id = get_clip_token_id_for_string(self.clip_embedder.tokenizer, token_str)
+         self.textual_inversions.append(TextualInversion(
+             trigger_string=token_str,
+             token_id=token_id,
+             embedding=embedding
+         ))
+         return token_id
+
+     def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
+         try:
+             ti = self.get_textual_inversion_for_trigger_string(trigger_string)
+             return ti is not None
+         except StopIteration:
+             return False
+
+     def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
+         return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
+
+
+     def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
+         return next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+
+     def expand_textual_inversion_token_ids(self, prompt_token_ids: list[int]) -> list[int]:
+         """
+         Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
+
+         :param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
+         :return: The prompt token ids with any necessary padding to account for textual inversions inserted. Padding uses
+             the embedder tokenizer's pad_token_id. The result may be too long - the caller is responsible for truncating
+             it if necessary and for prepending/appending the bos and eos token ids.
+         """
+         if prompt_token_ids[0] == self.clip_embedder.tokenizer.bos_token_id:
+             raise ValueError("prompt_token_ids must not start with bos_token_id")
+         if prompt_token_ids[-1] == self.clip_embedder.tokenizer.eos_token_id:
+             raise ValueError("prompt_token_ids must not end with eos_token_id")
+         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+         prompt_token_ids = prompt_token_ids[:]
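+         # Walk the prompt in reverse so that inserting pad tokens after a multi-vector
+         # trigger does not shift the indices of entries that have not been visited yet.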
+         for i, token_id in reversed(list(enumerate(prompt_token_ids))):
+             if token_id in textual_inversion_token_ids:
+                 textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
+                 for pad_idx in range(1, textual_inversion.embedding_vector_length):
+                     prompt_token_ids.insert(i + 1, self.clip_embedder.tokenizer.pad_token_id)
+
+         return prompt_token_ids
+
+     def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], prompt_embeddings: torch.Tensor) -> torch.Tensor:
+         """
+         For each token id in prompt_token_ids that refers to a loaded textual inversion, overwrite the corresponding
+         row in `prompt_embeddings` with the textual inversion embedding. If the embedding has vector length >1, overwrite
+         subsequent rows in `prompt_embeddings` as well.
+
+         :param `prompt_token_ids`: Prompt token ids, already expanded to account for any textual inversions with vector length
+             >1 (call `expand_textual_inversion_token_ids()` to do this) and including bos and eos markers.
+         :param `prompt_embeddings`: Prompt embeddings tensor whose rows align with the token ids in
+             `prompt_token_ids` (i.e., also already expanded).
+         :return: The `prompt_embeddings` tensor, overwritten as appropriate with the textual inversion embeddings.
+         """
+         if prompt_embeddings.shape[0] != self.clip_embedder.max_length:  # typically 77
+             raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
+         if len(prompt_token_ids) > self.clip_embedder.max_length:
+             raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
+         if len(prompt_token_ids) < self.clip_embedder.max_length:
+             raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
+         if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
+             raise ValueError("prompt_token_ids must start with the bos token id and end with the eos token id")
+
+         textual_inversion_token_ids = [ti.token_id for ti in self.textual_inversions]
+         pad_token_id = self.clip_embedder.tokenizer.pad_token_id
+         overwritten_prompt_embeddings = prompt_embeddings.clone()
+         for i, token_id in enumerate(prompt_token_ids):
+             if token_id == pad_token_id:
+                 continue
+             if token_id in textual_inversion_token_ids:
+                 textual_inversion = next(ti for ti in self.textual_inversions if ti.token_id == token_id)
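+                 # Overwrite the trigger token's row and, for multi-vector embeddings, the pad
+                 # rows that follow it, clamped so the write never runs off the end of the tensor.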
+                 end_index = min(i + textual_inversion.embedding_vector_length, self.clip_embedder.max_length - 1)
+                 count_to_overwrite = end_index - i
+                 for j in range(0, count_to_overwrite):
+                     # only overwrite the textual inversion token id or the padding token id
+                     if prompt_token_ids[i + j] != pad_token_id and prompt_token_ids[i + j] != token_id:
+                         break
+                     overwritten_prompt_embeddings[i + j] = textual_inversion.embedding[j]
+
+         return overwritten_prompt_embeddings
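+ # Intended call sequence (a sketch only; `clip_embedder` and the variable names below
+ # are illustrative, not part of this diff):
+ #   manager = TextualInversionManager(clip_embedder)
+ #   manager.load_textual_inversion('my-embedding.pt')
+ #   expanded_ids = manager.expand_textual_inversion_token_ids(prompt_token_ids)
+ #   # ...pad expanded_ids to max_length and add bos/eos ids, then embed them...
+ #   conditioning = manager.overwrite_textual_inversion_embeddings(padded_ids, embeddings)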

- def get_embedding_for_clip_token(embedder, token):
-     return embedder(token.unsqueeze(0))[0, 0]

class EmbeddingManager(nn.Module):
    def __init__(
@@ -77,38 +216,36 @@ def __init__(
            embedder, 'tokenizer'
        ):  # using Stable Diffusion's CLIP encoder
            self.is_clip = True
-             get_token_for_string = partial(
-                 get_clip_token_for_string, embedder.tokenizer
+             get_token_id_for_string = partial(
+                 get_clip_token_id_for_string, embedder.tokenizer
            )
-             get_embedding_for_tkn = partial(
-                 get_embedding_for_clip_token,
+             get_embedding_for_tkn_id = partial(
+                 get_embedding_for_clip_token_id,
                embedder.transformer.text_model.embeddings,
            )
            # per bug report #572
            #token_dim = 1280
            token_dim = 768
        else:  # using LDM's BERT encoder
            self.is_clip = False
-             get_token_for_string = partial(
-                 get_bert_token_for_string, embedder.tknz_fn
+             get_token_id_for_string = partial(
+                 get_bert_token_id_for_string, embedder.tknz_fn
            )
-             get_embedding_for_tkn = embedder.transformer.token_emb
+             get_embedding_for_tkn_id = embedder.transformer.token_emb
            token_dim = 1280

        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):

-             token = get_token_for_string(placeholder_string)
+             token_id = get_token_id_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
-                 init_word_token = get_token_for_string(initializer_words[idx])
+                 init_word_token_id = get_token_id_for_string(initializer_words[idx])

                with torch.no_grad():
-                     init_word_embedding = get_embedding_for_tkn(
-                         init_word_token.cpu()
-                     )
+                     init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

                token_params = torch.nn.Parameter(
                    init_word_embedding.unsqueeze(0).repeat(
@@ -132,7 +269,7 @@ def __init__(
                    )
                )

-             self.string_to_token_dict[placeholder_string] = token
+             self.string_to_token_dict[placeholder_string] = token_id
            self.string_to_param_dict[placeholder_string] = token_params

    def forward(
@@ -241,7 +378,7 @@ def load(self, ckpt_paths, full=True):
        # both will be stored in this dictionary
        for term in self.string_to_param_dict.keys():
            term = term.strip('<').strip('>')
-             self.concepts_loaded[term] = True
+             self.concepts_loaded[term] = True
        print(f'>> Current embedding manager terms: {", ".join(self.string_to_param_dict.keys())}')

    def _expand_directories(self, paths: list[str]):
@@ -262,7 +399,7 @@ def _load(self, ckpt_path, full=True):
            print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
            print('### For your safety, InvokeAI will not load this embed.')
            return
-
+
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # Handle .pt textual inversion files
@@ -292,14 +429,16 @@ def add_embedding(self, token_str, embedding, full):
        if len(embedding.shape) == 1:
            embedding = embedding.unsqueeze(0)

-         num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
-         current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
-         current_token_count = current_embeddings.num_embeddings
-         new_token_count = current_token_count + num_tokens_added
-         self.embedder.transformer.resize_token_embeddings(new_token_count)
+         existing_token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
+         if existing_token_id == self.embedder.tokenizer.unk_token_id:
+             num_tokens_added = self.embedder.tokenizer.add_tokens(token_str)
+             current_embeddings = self.embedder.transformer.resize_token_embeddings(None)
+             current_token_count = current_embeddings.num_embeddings
+             new_token_count = current_token_count + num_tokens_added
+             self.embedder.transformer.resize_token_embeddings(new_token_count)

-         token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
-         self.string_to_token_dict[token_str] = token
+         token_id = get_clip_token_id_for_string(self.embedder.tokenizer, token_str)
+         self.string_to_token_dict[token_str] = token_id
        self.string_to_param_dict[token_str] = torch.nn.Parameter(embedding)

    def has_embedding_for_token(self, token_str):