diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index f8df7e00..0401a1f5 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -1739,7 +1739,8 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, size_t tasks_required = params.back().v_tile_size < k4xNFVTileSize ? k4xNFVTileSize : kVTileSize; - if (params.back().v_tile_size + tasks_remaining < tasks_required || + if ((params.back().v_tile_size + tasks_remaining < tasks_required && + params.back().v_tile_size > 0) || params.back().v_tile_size == kVTileSize) { // We don't have enough tasks remaining to fill a tile, or the // current tile is full so start new tile.