Open
Description
Here lookahead_generation doesn't take logits_warper as input:
PainlessInferenceAcceleration/pia/lookahead/common/pretrained_model_batch.py
Lines 426 to 439 in 8015f12

    elif generation_mode == GenerationMode.LOOKAHEAD_GENERATION:
        # 11. run greedy search
        return self.lookahead_generation(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

In the original sample path, logits_warper is passed in and used to modify next_tokens_scores:
PainlessInferenceAcceleration/pia/lookahead/common/pretrained_model_batch.py
Lines 474 to 486 in 8015f12

    return self.sample(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )

The warper is what modifies the logits according to temperature, top_k, top_p, and so on:

    if generation_config.temperature is not None and generation_config.temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
    if generation_config.top_k is not None and generation_config.top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
    if generation_config.top_p is not None and generation_config.top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))

None of this is applied inside lookahead_generation, so with do_sample=True the temperature is effectively always 1 and the top_k and top_p settings are ignored.
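To make the effect concrete, here is a minimal, framework-free sketch of what the temperature and top-k warping steps do to a score vector, and what happens when they are skipped. It is not the repository's code or the transformers implementation: the helper names (`softmax`, `temperature_warp`, `top_k_warp`, `warp`) and the example logits are hypothetical and chosen only for illustration.

```python
import math


def softmax(logits):
    """Turn logits into probabilities; -inf entries get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def temperature_warp(logits, temperature):
    """Divide every logit by the temperature (hypothetical helper)."""
    return [x / temperature for x in logits]


def top_k_warp(logits, top_k):
    """Mask everything below the k-th largest logit with -inf (hypothetical helper)."""
    k = min(top_k, len(logits))
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


def warp(logits, temperature=1.0, top_k=0):
    """Chain the warpers using the same enabling conditions as the snippet above."""
    if temperature is not None and temperature != 1.0:
        logits = temperature_warp(logits, temperature)
    if top_k is not None and top_k != 0:
        logits = top_k_warp(logits, top_k)
    return logits


logits = [2.0, 1.0, 0.5, -1.0]  # made-up next-token scores

# Sampling path without a warper: equivalent to temperature == 1, no top-k.
unwarped = softmax(logits)

# Sampling path with the warper applied before softmax.
warped = softmax(warp(logits, temperature=0.5, top_k=2))

print("without warper:", unwarped)
print("with warper:   ", warped)
```

With the warper applied, the probability mass concentrates on the top two tokens and the rest drop to zero; without it, the distribution is the plain softmax of the raw scores regardless of the configured temperature, which is the behavior described for lookahead_generation above.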