 PROGRESSIVE_SCALE = 2000


-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id

-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
-
     token = token[0, 1]
-
-    return token
+    return token.item()


-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]

 @dataclass
 class TextualInversion:
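A quick usage sketch of the reworked helpers (not part of the diff). It assumes the functions live in ldm.modules.embedding_manager and that the Hugging Face openai/clip-vit-large-patch14 checkpoint is the relevant CLIP text encoder; both details are assumptions, not something this diff states.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.modules.embedding_manager import (  # module path assumed
    get_clip_token_id_for_string,
    get_embedding_for_clip_token_id,
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

token_id = get_clip_token_id_for_string(tokenizer, "*")
assert isinstance(token_id, int)  # the helper now advertises and returns a plain int

# The embedding helper promotes a plain int to a tensor itself, so callers no
# longer need to build a tensor (or call .cpu()) before the lookup.
embedding = get_embedding_for_clip_token_id(text_model.text_model.embeddings, token_id)
print(embedding.shape)  # 1-D embedding vector for the single token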
@@ -183,9 +183,6 @@ def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], pr
         return overwritten_prompt_embeddings


-
-
-
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ def __init__(
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
@@ -232,9 +229,9 @@ def __init__(
         else:  # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280

         if per_image_tokens:
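Bringing the pieces together, a minimal sketch (an assumption, not diff content) of how the renamed get_embedding_for_tkn_id callable is consumed on the CLIP branch, mirroring the wiring above and the initializer-word lookup in the next hunk. tokenizer and text_model are the Hugging Face objects from the earlier sketch, and "cat" merely stands in for one of the initializer_words.

from functools import partial

import torch

get_token_id_for_string = partial(get_clip_token_id_for_string, tokenizer)
get_embedding_for_tkn_id = partial(
    get_embedding_for_clip_token_id, text_model.text_model.embeddings
)

init_word_token_id = get_token_id_for_string("cat")  # plain int, no tensor involved
with torch.no_grad():
    # The int is passed straight through; the old tensor/.cpu() round-trip is gone.
    init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

# __init__ then wraps this in torch.nn.Parameter and goes on to .repeat() it
# (the repeat arguments fall outside the hunk shown below).
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0))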
@@ -248,9 +245,7 @@ def __init__(
             init_word_token_id = get_token_id_for_string(initializer_words[idx])

             with torch.no_grad():
-                init_word_embedding = get_embedding_for_tkn(
-                    init_word_token_id.cpu()
-                )
+                init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

             token_params = torch.nn.Parameter(
                 init_word_embedding.unsqueeze(0).repeat(