diff --git a/gemma/run.cc b/gemma/run.cc index 8d005304..95dec0d7 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -200,15 +200,14 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache, config.wrapping, abs_pos, prompt_string, image_tokens.Rows()); runtime_config.image_tokens = &image_tokens; + // PrefixLM sees/attends to all tokens. + runtime_config.prefill_tbatch_size = prompt.size(); + prompt_size = prompt.size() - image_tokens.Rows(); if (config.wrapping == PromptWrapping::PALIGEMMA) { // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. prefix_end = prompt_size; - // We need to look at all the tokens for the prefix. - // NOTE: Online softmax is on the roadmap, after which this requirement - // can be lifted. - runtime_config.prefill_tbatch_size = prompt_size; } } else { prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),