Adding Prompt lookup decoding #27775

Merged
merged 17 commits on Jan 13, 2024
Changes from all commits
92 changes: 92 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -219,6 +219,98 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = 2,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # No match found: fall back to a single dummy token so downstream code does not error out
            chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device)

        # Now we need to extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here, so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return


def _crop_past_key_values(model, past_key_values, maximum_length):
"""Crops the past key values up to a certain maximum length."""
new_past = []
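To illustrate the matching logic above, here is a minimal, self-contained sketch of the same ngram lookup that `get_candidates` performs for a batch of size 1. It is not part of the PR, and the helper name `prompt_lookup_candidates` is made up for the example.

```python
import torch


def prompt_lookup_candidates(input_ids: torch.LongTensor, max_ngram: int = 2, num_out: int = 10) -> torch.LongTensor:
    """Toy version of PromptLookupCandidateGenerator.get_candidates (batch size 1)."""
    input_length = input_ids.size(1)
    for ngram_size in range(min(max_ngram, input_length - 1), 0, -1):
        # All sliding windows of length ngram_size over the prompt
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
        # Positions where a window equals the trailing ngram of the prompt
        matches = (windows == input_ids[0, -ngram_size:]).all(dim=2)
        for idx in matches.nonzero(as_tuple=True)[1]:
            start = int(idx) + ngram_size
            end = min(start + num_out, input_length)
            if start < end:
                # The tokens that followed the matched ngram become the draft continuation
                return input_ids[0, start:end]
    # No match: return a single dummy token, as the real implementation does
    return torch.zeros(1, dtype=torch.long, device=input_ids.device)


prompt = torch.tensor([[5, 6, 7, 8, 5, 6]])
print(prompt_lookup_candidates(prompt, max_ngram=2, num_out=3))  # tensor([7, 8, 5])
```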
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -320,6 +320,9 @@ def __init__(self, **kwargs):
        self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
        self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

        # Prompt lookup decoding
        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

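Since the new flag defaults to `None`, prompt lookup decoding stays off unless explicitly requested. A minimal sketch of setting it through `GenerationConfig` (the value 10 is illustrative):

```python
from transformers import GenerationConfig

# prompt_lookup_num_tokens defaults to None, so prompt lookup decoding is disabled unless set
generation_config = GenerationConfig(prompt_lookup_num_tokens=10)
print(generation_config.prompt_lookup_num_tokens)  # 10
```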
24 changes: 15 additions & 9 deletions src/transformers/generation/utils.py
@@ -40,6 +40,7 @@
from .candidate_generator import (
    AssistedCandidateGenerator,
    CandidateGenerator,
    PromptLookupCandidateGenerator,
    _crop_past_key_values,
    _prepare_attention_mask,
    _prepare_token_type_ids,
@@ -908,14 +909,19 @@ def _get_candidate_generator(
"""
Returns the candidate generator to be used in `assisted_generation`
"""
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
)
if generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
num_output_tokens=generation_config.prompt_lookup_num_tokens,
)
else:
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
generation_config=generation_config,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
)
return candidate_generator

def _get_logits_warper(
@@ -995,7 +1001,7 @@ def _get_generation_mode(
                generation_mode = GenerationMode.BEAM_SEARCH

        # Assisted generation may extend some generation modes
        if assistant_model is not None:
        if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
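Putting the pieces together, prompt lookup decoding is requested at generation time simply by passing `prompt_lookup_num_tokens` to `generate()`, with no assistant model. A hedged sketch, assuming a standard causal LM checkpoint such as `gpt2` (used here purely for illustration):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over the lazy dog. The quick brown", return_tensors="pt")

# Passing prompt_lookup_num_tokens routes assisted generation through PromptLookupCandidateGenerator,
# so candidate tokens are drafted from the prompt itself instead of from an assistant model.
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, prompt_lookup_num_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```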
60 changes: 60 additions & 0 deletions tests/generation/test_utils.py
@@ -1569,6 +1569,66 @@ def test_assisted_decoding_matches_greedy_search(self):
            for output in (output_greedy, output_assisted):
                self._check_outputs(output, input_ids, model.config, use_cache=True)

    @is_flaky()
    def test_prompt_lookup_decoding_matches_greedy_search(self):
        # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
        # This test is mostly a copy of test_assisted_decoding_matches_greedy_search

        for model_class in self.all_generative_model_classes:
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
            if any(
                model_name in model_class.__name__.lower()
                for model_name in [
                    "bigbirdpegasus",
                    "led",
                    "mega",
                    "speech2text",
                    "git",
                    "prophetnet",
                    "seamlessm4t",
                    "clvp",
                ]
            ):
                self.skipTest("May fix in the future: need model-specific fixes")

            # enable cache
            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)

            # NOTE: assisted generation only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                self.skipTest("This model doesn't support caching")

            config.use_cache = True
            config.is_decoder = True
            model = model_class(config).to(torch_device).eval()
            # Sets assisted generation arguments such that:
            # a) no EOS is generated, to ensure generation doesn't break early
            # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
            #    prompt lookup is correct
            # c) there are at least two forward passes in the main model, to ensure the input preparation of
            #    the main model is correct
            generation_kwargs = {
                "eos_token_id": -1,  # see a)
                "max_new_tokens": 4,  # see c)
                "num_beams": 1,
                "do_sample": False,
                "output_scores": True,
                "output_hidden_states": True,
                "output_attentions": True,
                "return_dict_in_generate": True,
            }

            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

            generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
            output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

            # The two outputs must match and their shape must be as expected
            self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
            for output in (output_greedy, output_prompt_lookup):
                self._check_outputs(output, input_ids, model.config, use_cache=True)

    def test_assisted_decoding_sample(self):
        # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
        # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with