
Commit 023df37

cleanup
1 parent 05fac59 commit 023df37

1 file changed (+12, -17 lines)

ldm/modules/embedding_manager.py

Lines changed: 12 additions & 17 deletions
@@ -15,21 +15,21 @@
 PROGRESSIVE_SCALE = 2000
 
 
-def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str):
+def get_clip_token_id_for_string(tokenizer: CLIPTokenizer, token_str: str) -> int:
     token_id = tokenizer.convert_tokens_to_ids(token_str)
     return token_id
 
-def get_bert_token_for_string(tokenizer, string):
+def get_bert_token_id_for_string(tokenizer, string) -> int:
     token = tokenizer(string)
     # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
-
     token = token[0, 1]
-
-    return token
+    return token.item()
 
 
-def get_embedding_for_clip_token(embedder, token):
-    return embedder(token.unsqueeze(0))[0, 0]
+def get_embedding_for_clip_token_id(embedder, token_id):
+    if type(token_id) is not torch.Tensor:
+        token_id = torch.tensor(token_id, dtype=torch.int)
+    return embedder(token_id.unsqueeze(0))[0, 0]
 
 
 @dataclass
 class TextualInversion:
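
After this hunk, the token-id helpers return plain Python ints and the CLIP embedding lookup tensorizes a bare int itself. A minimal usage sketch, not part of the commit: the checkpoint name and the example string are assumptions, and the helpers are imported from the file changed here.

# Hypothetical usage sketch of the helpers above; assumes the standard
# Stable Diffusion CLIP text encoder from Hugging Face transformers.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.modules.embedding_manager import (
    get_clip_token_id_for_string,
    get_embedding_for_clip_token_id,
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Returns a plain int (the unk id if the string is not a single vocabulary token).
token_id = get_clip_token_id_for_string(tokenizer, "*")

# Accepts either that int or an equivalent tensor; ints are wrapped internally.
embedding = get_embedding_for_clip_token_id(
    text_model.text_model.embeddings, token_id
)
print(embedding.shape)  # torch.Size([768]) for ViT-L/14
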
@@ -183,9 +183,6 @@ def overwrite_textual_inversion_embeddings(self, prompt_token_ids: list[int], pr
         return overwritten_prompt_embeddings
 
 
-
-
-
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
@@ -222,8 +219,8 @@ def __init__(
             get_token_id_for_string = partial(
                 get_clip_token_id_for_string, embedder.tokenizer
             )
-            get_embedding_for_tkn = partial(
-                get_embedding_for_clip_token,
+            get_embedding_for_tkn_id = partial(
+                get_embedding_for_clip_token_id,
                 embedder.transformer.text_model.embeddings,
             )
             # per bug report #572
@@ -232,9 +229,9 @@ def __init__(
         else:  # using LDM's BERT encoder
             self.is_clip = False
             get_token_id_for_string = partial(
-                get_bert_token_for_string, embedder.tknz_fn
+                get_bert_token_id_for_string, embedder.tknz_fn
             )
-            get_embedding_for_tkn = embedder.transformer.token_emb
+            get_embedding_for_tkn_id = embedder.transformer.token_emb
             token_dim = 1280
 
         if per_image_tokens:
@@ -248,9 +245,7 @@ def __init__(
                 init_word_token_id = get_token_id_for_string(initializer_words[idx])
 
                 with torch.no_grad():
-                    init_word_embedding = get_embedding_for_tkn(
-                        init_word_token_id.cpu()
-                    )
+                    init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)
 
                 token_params = torch.nn.Parameter(
                     init_word_embedding.unsqueeze(0).repeat(
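
Taken together, these hunks let EmbeddingManager.__init__ hand a plain int token id straight to the embedding lookup, with no .cpu() round-trip. A standalone sketch of that wiring under the same assumptions as above: the checkpoint name and initializer word are illustrative, and the objects below stand in for the FrozenCLIPEmbedder-style embedder the real constructor receives.

# Standalone sketch of the post-commit wiring; names mirror the diff.
from functools import partial

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.modules.embedding_manager import (
    get_clip_token_id_for_string,
    get_embedding_for_clip_token_id,
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
transformer = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

get_token_id_for_string = partial(get_clip_token_id_for_string, tokenizer)
get_embedding_for_tkn_id = partial(
    get_embedding_for_clip_token_id,
    transformer.text_model.embeddings,
)

init_word_token_id = get_token_id_for_string("flower")  # illustrative initializer word
with torch.no_grad():
    # The id is already a plain int and the helper tensorizes it, so the
    # .cpu() call removed in the last hunk is no longer needed here.
    init_word_embedding = get_embedding_for_tkn_id(init_word_token_id)

token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0), requires_grad=True)
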
