From 49c14ff0926809c924e5da4cde53d4f977fcd436 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Mon, 23 Feb 2026 06:45:39 -0800 Subject: [PATCH] Int8 + microscaling support for kv cache formats. Right now multiplication is done by converting to corresponding float format. Can yield up to 2x improvements for membw constrained shapes PiperOrigin-RevId: 874047973 --- compression/compress-inl.h | 136 ++++++++++++++++++++++++++++++++++ compression/compress_test.cc | 4 + compression/types.h | 15 +++- gemma/flash_attention.cc | 70 +++++++++++++++++ gemma/flash_attention_test.cc | 136 ++++++++++++++++++++++++++++++++++ gemma/gemma_args.h | 1 + gemma/kv_cache.cc | 6 +- gemma/kv_cache.h | 1 + gemma/tiled_attention.cc | 93 +++++++++++++++++++++-- gemma/tiled_attention_test.cc | 114 +++++++++++++++++++++++++++- util/mat.h | 5 ++ 11 files changed, 565 insertions(+), 16 deletions(-) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index e7bb9d68..5f09df39 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -444,6 +444,142 @@ struct CompressTraits { } }; +template <> +struct CompressTraits { + using Packed = int8_t; + + static size_t CompressBound(size_t num) { return num * sizeof(Packed); } + + template + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, + size_t num, CompressPerThread& /*tls*/, + const PackedSpan& packed, + const size_t packed_ofs) { + const hn::Repartition di32; + const hn::Repartition di16; + const hn::Repartition di8; + const auto di16_16 = hn::Half(); + const auto di8_16 = hn::Half(); + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + const VF v0 = hn::LoadU(df, raw + i); + const VF v1 = hn::LoadU(df, raw + i + NF); + const auto vi32_0 = hn::NearestInt(v0); + const auto vi32_1 = hn::NearestInt(v1); + const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1); + const auto vi8 = hn::OrderedDemote2To( + di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16)); + hn::StoreU(vi8, di8_16, packed.ptr + packed_ofs + i); + } + } + const size_t remaining = num - i; + if (remaining > 0) { + HWY_ALIGN float buf[2 * NF]; + hwy::ZeroBytes(buf, 2 * NF * sizeof(float)); + for (size_t j = 0; j < remaining; ++j) buf[j] = raw[i + j]; + const VF v0 = hn::LoadU(df, buf); + const VF v1 = hn::LoadU(df, buf + NF); + const auto vi32_0 = hn::NearestInt(v0); + const auto vi32_1 = hn::NearestInt(v1); + const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1); + const auto vi8 = hn::OrderedDemote2To( + di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16)); + hn::StoreN(vi8, di8_16, packed.ptr + packed_ofs + i, remaining); + } + } + + static float ToFloatSlow(const Packed x) { return static_cast(x); } + + template + static HWY_INLINE void Load2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition di32; + const hn::Repartition di16; + const hn::Rebind di8_half; + + const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs); + const auto vec_i16 = hn::PromoteTo(di16, vec_i8); + const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16); + const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16); + + raw0 = hn::ConvertTo(df, vec_i32_0); + raw1 = hn::ConvertTo(df, vec_i32_1); + } + + template + static HWY_INLINE void Load2(DBF dbf, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + VF f0, f1, f2, f3; + Load2(df, packed, packed_ofs, f0, f1); + Load2(df, packed, packed_ofs + 2 * NF, f2, f3); + + raw0 = hn::OrderedDemote2To(dbf, f0, f1); + raw1 = hn::OrderedDemote2To(dbf, f2, f3); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + DF df, const PackedSpan& packed, const size_t packed_ofs, + float* HWY_RESTRICT raw, size_t num) { + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF raw0, raw1; + Load2(df, packed, packed_ofs + i, raw0, raw1); + hn::StoreU(raw0, df, raw + i); + hn::StoreU(raw1, df, raw + i + NF); + } + } + + const size_t remaining = num - i; + if (HWY_UNLIKELY(remaining != 0)) { + for (size_t j = 0; j < remaining; ++j) { + raw[i + j] = static_cast(packed.ptr[packed_ofs + i + j]); + } + } + } + + template + static HWY_INLINE void DecompressAndZeroPad( + DBF dbf, const PackedSpan& packed, const size_t packed_ofs, + BF16* HWY_RESTRICT raw, size_t num) { + const hn::Repartition df; + const size_t NF = hn::Lanes(df); + size_t i = 0; + const size_t NBF = hn::Lanes(dbf); + if (num >= NBF) { + for (; i <= num - NBF; i += NBF) { + hn::Vec f0, f1; + Load2(df, packed, packed_ofs + i, f0, f1); + auto vbf = hn::OrderedDemote2To(dbf, f0, f1); + hn::StoreU(vbf, dbf, raw + i); + } + } + const size_t remaining = num - i; + if (remaining > 0) { + HWY_ALIGN float buf[2 * hn::MaxLanes(df)]; + DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining); + auto f0 = hn::LoadU(df, buf); + auto f1 = hn::LoadU(df, buf + NF); + auto vbf = hn::OrderedDemote2To(dbf, f0, f1); + hn::StoreN(vbf, dbf, raw + i, remaining); + } + } +}; + // Integer quantization. template <> struct CompressTraits { diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 421492e9..4002184e 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -126,6 +126,8 @@ struct TestDecompress2 { HWY_ASSERT(stats.L1().Max() <= 0.08f); HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1())); HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + HWY_ASSERT(stats.L1().Max() <= 0.6f); } else { HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); } @@ -200,6 +202,8 @@ struct TestShortLengths { HWY_ASSERT(stats.L1().Max() <= 0.14f); HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1())); HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + HWY_ASSERT(stats.L1().Max() <= 0.6f); } else { HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); } diff --git a/compression/types.h b/compression/types.h index dc22f4ca..e7f9bda0 100644 --- a/compression/types.h +++ b/compression/types.h @@ -192,6 +192,11 @@ constexpr bool IsF32() { return hwy::IsSame, float>(); } +template +constexpr bool IsInt8() { + return hwy::IsSame, int8_t>(); +} + template constexpr bool IsBF16() { return hwy::IsSame, BF16>(); @@ -231,12 +236,13 @@ enum class Type { kI8, kU16, kU8, + kInt8, }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "u32", "u64", - "i8", "u16", "u8"}; +static constexpr const char* kTypeStrings[] = { + "unknown", "f32", "bf16", "sfp", "nuq", "f64", + "u32", "u64", "i8", "u16", "u8", "int8"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = { @@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = { 8 * sizeof(I8Stream), 8 * sizeof(uint16_t), 8 * sizeof(uint8_t), + 8 * sizeof(int8_t), }; static inline bool EnumValid(Type type) { @@ -281,6 +288,8 @@ constexpr Type TypeEnum() { return Type::kU16; } else if constexpr (hwy::IsSame()) { return Type::kU8; + } else if constexpr (hwy::IsSame()) { + return Type::kInt8; } else { return Type::kUnknown; } diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 0401a1f5..c7cc3dac 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -1260,6 +1260,52 @@ static HWY_NOINLINE void ApplyMasking( } } +template > +static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0, + VF& x0_p1, VF& x1_p0, VF& x1_p1, + VF& x2_p0, VF& x2_p1, VF& x3_p0, + VF& x3_p1, VF& x4_p0, VF& x4_p1, + VF& x5_p0, VF& x5_p1, VF& x6_p0, + VF& x6_p1, VF& x7_p0, VF& x7_p1) { + const size_t kTileSize = hn::Lanes(df); + const PackedSpan scales_span = + MakeConstSpan(scales, 2 * kTileSize); + VF scales_p0, scales_p1; + Decompress2(df, scales_span, 0, scales_p0, scales_p1); + if constexpr (kNumQueries >= 1) { + x0_p0 = hn::Mul(x0_p0, scales_p0); + x0_p1 = hn::Mul(x0_p1, scales_p1); + } + if constexpr (kNumQueries >= 2) { + x1_p0 = hn::Mul(x1_p0, scales_p0); + x1_p1 = hn::Mul(x1_p1, scales_p1); + } + if constexpr (kNumQueries >= 3) { + x2_p0 = hn::Mul(x2_p0, scales_p0); + x2_p1 = hn::Mul(x2_p1, scales_p1); + } + if constexpr (kNumQueries >= 4) { + x3_p0 = hn::Mul(x3_p0, scales_p0); + x3_p1 = hn::Mul(x3_p1, scales_p1); + } + if constexpr (kNumQueries >= 5) { + x4_p0 = hn::Mul(x4_p0, scales_p0); + x4_p1 = hn::Mul(x4_p1, scales_p1); + } + if constexpr (kNumQueries >= 6) { + x5_p0 = hn::Mul(x5_p0, scales_p0); + x5_p1 = hn::Mul(x5_p1, scales_p1); + } + if constexpr (kNumQueries >= 7) { + x6_p0 = hn::Mul(x6_p0, scales_p0); + x6_p1 = hn::Mul(x6_p1, scales_p1); + } + if constexpr (kNumQueries >= 8) { + x7_p0 = hn::Mul(x7_p0, scales_p0); + x7_p1 = hn::Mul(x7_p1, scales_p1); + } +} + // Performs tiled flash attention for arbitrary number of queries // It depends on kv being tiled. // Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time). @@ -1400,6 +1446,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( false, "Query type type not supported, only float and BF16 are supported"); } + // microscaling + // TODO: Change to more generic function to inform if we should use + // microscaling or not. + constexpr bool kUseMicroScaling = IsInt8(); + if constexpr (kUseMicroScaling) { + // After end of the tile, we have kTileSize * 2 bfloat16 for the + // microscaling scales for K and V. + const BF16* microscaling_scales_k = + reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + pos_in_tile; + MultiplyByScale(df, microscaling_scales_k, x_0_p_0, x_0_p_1, + x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, + x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, + x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); constexpr int kSecondHalfAmountOfQueries = @@ -1433,6 +1494,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx, kNumQueriesPerGroup); + if constexpr (kUseMicroScaling) { + const BF16* microscaling_scales_v = + reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + kTileSize + pos_in_tile; + MultiplyByScale(df, microscaling_scales_v, x_0_p_0, x_0_p_1, + x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, + x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, + x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1); + } if constexpr (IsF32()) { MulByConstAndAddTileUpTo8( df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index fd693d98..76be5d6e 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -490,6 +490,139 @@ void TestTiledFlashAttentionBF16() { } } +void TestTiledFlashAttentionInt8() { + int qkv_dim = 64; + int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by + // tiles size to test the padding logic. + int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + float att_cap = 10.0f; + int num_queries = 8; + int num_queries_per_timestep = 4; + int num_tokens = num_queries / num_queries_per_timestep; + int kv_seq_end = + kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + + int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; + int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + + 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; + + MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), + ctx.allocator, MatPadding::kPacked); + + // fill in kvs with predictable, synthetic data + for (int i = 0; i < padded_kv_seq_len; ++i) { + int tile_idx = i / gcpp::KVCache::kTileSize; + int in_tile_offset = i % gcpp::KVCache::kTileSize; + int8_t* tile_ptr = kv.Row(tile_idx); + BF16* scales_ptr = HWY_RCAST_ALIGNED( + BF16*, tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize); + + // Generate float values for K and V + std::vector k_vals(qkv_dim); + std::vector v_vals(qkv_dim); + float max_abs_k = 0.0f; + float max_abs_v = 0.0f; + + for (int j = 0; j < qkv_dim; ++j) { + k_vals[j] = 0.01f * (i + 1) / (j + 1); + v_vals[j] = 0.02f * (i + 1) / (j + 1); + max_abs_k = std::max(max_abs_k, std::abs(k_vals[j])); + max_abs_v = std::max(max_abs_v, std::abs(v_vals[j])); + } + + // Quantize K + float scale_k = max_abs_k / 127.0f; + if (scale_k == 0.0f) scale_k = 1.0f; + scales_ptr[in_tile_offset] = hwy::ConvertScalarTo(scale_k); + for (int j = 0; j < qkv_dim; ++j) { + int val = std::round(k_vals[j] / scale_k); + val = std::max(-127, std::min(127, val)); + tile_ptr[j * gcpp::KVCache::kTileSize + in_tile_offset] = + static_cast(val); + } + + // Quantize V + float scale_v = max_abs_v / 127.0f; + if (scale_v == 0.0f) scale_v = 1.0f; + scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] = + hwy::ConvertScalarTo(scale_v); + size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; + for (int j = 0; j < qkv_dim; ++j) { + int val = std::round(v_vals[j] / scale_v); + val = std::max(-127, std::min(127, val)); + tile_ptr[v_offset + in_tile_offset * qkv_dim + j] = + static_cast(val); + } + } + + std::vector q_float(4 * qkv_dim); + std::vector q_float2(4 * qkv_dim); + // fill in qs with predictable, synthetic data + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < qkv_dim; j++) { + float val_1 = 0.01f * (i + 1) / (j + 1); + float val_2 = 0.01f * (i + 4 + 1) / (j + 1); + q_float[j * 4 + i] = val_1; + q_float2[j * 4 + i] = val_2; + } + } + const float* q_T[2] = {q_float.data(), q_float2.data()}; + + MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), + ctx.allocator, MatPadding::kPacked); + using DF = hn::ScalableTag; + const DF df; + HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df); + size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes); + std::vector exp_denominator_sums(num_queries_rounded_to_laness); + std::vector max_logits(num_queries_rounded_to_laness); + for (size_t i = 0; i < num_queries; ++i) { + hwy::ZeroBytes(att_out.Row(i), + qkv_dim * sizeof(decltype(att_out.Row(i)[0]))); + exp_denominator_sums[i] = 0.0f; + max_logits[i] = -std::numeric_limits::max() / 2.0f; + } + std::vector> start_pos_per_query; + std::vector> last_pos_per_query; + start_pos_per_query.reserve(num_queries); + last_pos_per_query.reserve(num_queries); + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + ssize_t query_last_pos = kv_seq_end + token_idx; + ssize_t query_start_pos = + std::max(query_last_pos - 100000 + 1, static_cast(0)); + for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + ++q_head_idx) { + start_pos_per_query.push_back(query_start_pos); + last_pos_per_query.push_back(query_last_pos); + } + } + + hwy::Span kvs(&kv, 1); + DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( + kvs, num_queries, hwy::Span(q_T, 2), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + + // TODO: Replace with Other implementation for generating goldens. + // Current values are taken from a point in time where code was run with gemma + // and output looked good. Not ideal but should be good enough to test the + // plumbing and detect regressions. + PrintMatPtr(att_out); + for (int i = 0; i < num_queries; ++i) { + std::cerr << "exp_d: " << exp_denominator_sums[i] + << " max_logit: " << max_logits[i] << std::endl; + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-2f) + << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; + for (int j = 0; j < qkv_dim; ++j) { + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f); + } + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -500,6 +633,9 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_BEFORE_TEST(FlashAttentionTest); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8); HWY_AFTER_TEST(); } // namespace gcpp diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index b7cfcb22..8d6b0b5c 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -152,6 +152,7 @@ struct RuntimeConfig { // If not set, it will be set based on the attention_impl. // F32 for tiled // BF16 for tiled bf16 + // Int8 works for both tiled and tiled bf16. // If you want to use type other than F32 or BF16, you might need to update // call upcasted. std::optional kv_cache_type = {}; diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index f33cd218..967070cc 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -85,7 +85,6 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const size_t num_tiles = hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize); tiled_seq_len = num_tiles * kTileSize; - int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize; Type kv_cache_type; if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 || hwy::IsSame()) { @@ -93,6 +92,11 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, } else { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32); } + + int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize; + if (kv_cache_type == Type::kInt8) { + tile_length += 2 * sizeof(BF16) * kTileSize; + } auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size, size_t max_seq_len) { return hwy::DivCeil( diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 5fe1f1e9..0ca26b06 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -31,6 +31,7 @@ namespace gcpp { using KV_t = BF16; +using KV_microscale_t = BF16; struct KVCache; // A non-owning view of a KVCache. diff --git a/gemma/tiled_attention.cc b/gemma/tiled_attention.cc index a824dbf0..a5b0e81e 100644 --- a/gemma/tiled_attention.cc +++ b/gemma/tiled_attention.cc @@ -69,6 +69,27 @@ static HWY_INLINE void MergeOnlineSoftmax( accumulator_softmax_d = d_new; } +template +T AbsMaxOfSpan(hwy::Span span) { + hn::ScalableTag dt; + using VT = hn::Vec; + VT max_vec = hn::Set(dt, 0.0f); + const size_t lanes = hn::Lanes(dt); + size_t i = 0; + // Process full vectors using LoadU. + for (; i + lanes <= span.size(); i += lanes) { + const VT vec = hn::Abs(hn::LoadU(dt, span.data() + i)); + max_vec = hn::Max(max_vec, vec); + } + // Process remaining elements using LoadN. + const size_t remaining = span.size() - i; + if (HWY_UNLIKELY(remaining != 0)) { + const VT vec = hn::Abs(hn::LoadN(dt, span.data() + i, remaining)); + max_vec = hn::Max(max_vec, vec); + } + return hn::ReduceMax(dt, max_vec); +} + // Forked from ComputeQKV. But it stores the K/V in the tiled format // KV_T is type stored in the KV cache (typically float or BF16). template @@ -168,9 +189,9 @@ static HWY_INLINE void ComputeQKVTransposedTile( const float* kv_row = kv_out_data + (token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols; - const float* k_ptr = kv_row + kv_head * 2 * qkv_dim; - const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim; - hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float)); + const float* k_values = kv_row + kv_head * 2 * qkv_dim; + const float* v_values = kv_row + kv_head * 2 * qkv_dim + qkv_dim; + hwy::CopyBytes(k_values, k_f32, qkv_dim * sizeof(float)); if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32, @@ -183,7 +204,53 @@ static HWY_INLINE void ComputeQKVTransposedTile( /*mul=*/1.0f); const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize; - if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + // `v_cache_values` is a pointer to the V data that will be + // compressed and stored in the KV cache. By default, it points to + // the raw `v_values`. + const float* v_cache_values = v_values; + // `v_buf` is a temporary buffer used only when quantizing V values + // to int8_t. + HWY_ALIGN float v_buf[kMaxQKVDim]; + + if constexpr (IsInt8()) { + BF16* scales_ptr = HWY_RCAST_ALIGNED( + BF16*, tile_ptr + 2 * qkv_dim * KVCache::kTileSize); + + auto scale_and_store = [&](float* values, int dim, + size_t scale_idx) HWY_ATTR { + const float max_abs = + AbsMaxOfSpan(hwy::Span(values, dim)); + float scale = max_abs / 127.0f; + if (scale == 0.0f) scale = 1.0f; + scales_ptr[scale_idx] = hwy::ConvertScalarTo(scale); + const float inv_scale = 1.0f / scale; + const hn::Vec v_inv_scale = + hn::Set(df, inv_scale); + const size_t lanes = hn::Lanes(df); + size_t i = 0; + for (; i + lanes <= dim; i += lanes) { + hn::StoreU(hn::Mul(hn::LoadU(df, values + i), v_inv_scale), + df, values + i); + } + if (HWY_UNLIKELY(i < dim)) { + hn::StoreN( + hn::Mul(hn::LoadN(df, values + i, dim - i), v_inv_scale), + df, values + i, dim - i); + } + }; + + // K Scaling + scale_and_store(k_f32, qkv_dim, in_tile_idx); + + // V Scaling: Copy `v_values` to `v_buf`, scale `v_buf` in-place, + // and then update `v_cache_values` to point to `v_buf`. + hwy::CopyBytes(v_values, v_buf, qkv_dim * sizeof(float)); + scale_and_store(v_buf, qkv_dim, KVCache::kTileSize + in_tile_idx); + v_cache_values = v_buf; + } + + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 && + !IsInt8()) { const int in_tile_idx_mod_2 = in_tile_idx % 2; for (int dim = 0; dim < qkv_dim; dim += 2) { const int dim_mod_2 = dim % 2; @@ -196,16 +263,17 @@ static HWY_INLINE void ComputeQKVTransposedTile( in_tile_idx * 2 + 1] = k_f32[dim + 1]; // Pack v's in pairs v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim + - dim * 2 + in_tile_idx_mod_2] = v_ptr[dim]; + dim * 2 + in_tile_idx_mod_2] = v_cache_values[dim]; v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim + - (dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1]; + (dim + 1) * 2 + in_tile_idx_mod_2] = + v_cache_values[dim + 1]; } } else { for (int i = 0; i < qkv_dim; ++i) { k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i]; } - Compress(v_ptr, qkv_dim, tls, tile_packed_span, + Compress(v_cache_values, qkv_dim, tls, tile_packed_span, qkv_dim * (KVCache::kTileSize + in_tile_idx)); } @@ -640,12 +708,21 @@ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, "query heads must be a multiple of key-value heads"); (void)layer_config; // only used in HWY_DASSERT + if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) { ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, activations, qbatch, flags, env); - } else { + } else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kF32) { ComputeQKVTransposedTile(num_tokens, layer_idx, layer, attention_impl, activations, qbatch, flags, env); + } else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == + Type::kInt8) { + ComputeQKVTransposedTile(num_tokens, layer_idx, layer, + attention_impl, activations, qbatch, flags, + env); + } else { + HWY_ABORT("Unsupported KV cache type: %d", + qbatch.KV(0).cache->compact_kv_cache_ptr.GetType()); } RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer.query_norm_scale, layer_idx, activations, diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 46d276f6..84354ca2 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -1,7 +1,10 @@ #include +#include +#include #include #include +#include #include #include @@ -42,7 +45,7 @@ struct AttentionTestEnv { int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads, int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx, int layers_total, int qbatch_size, AttentionImpl attention_impl, - ) + std::optional kv_cache_type = {} ) : ctx(threading_args), env(ctx) { layer_config.heads = num_heads; layer_config.kv_heads = num_kv_heads; @@ -65,6 +68,7 @@ struct AttentionTestEnv { *tensor_info_registry); runtime_config.attention_impl = attention_impl; + runtime_config.kv_cache_type = kv_cache_type; inference_args.seq_len = kv_seq_len; all_queries.Reserve(qbatch_size); @@ -72,7 +76,8 @@ struct AttentionTestEnv { for (int q = 0; q < qbatch_size; ++q) { kv_caches.emplace_back(model_config, inference_args, runtime_config, ctx.allocator); - if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) { + if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 && + kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kBF16) { MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; for (int i = 0; i < compact_kv_cache.Rows(); ++i) { for (int j = 0; j < compact_kv_cache.Cols(); ++j) { @@ -98,8 +103,65 @@ struct AttentionTestEnv { } } } else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) { - MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; - FillMatPtrT(compact_kv_cache); + if (kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kInt8) { + MatPtrT compact_kv_cache = + kv_caches.back().compact_kv_cache_ptr; + for (int i = 0; i < compact_kv_cache.Rows(); ++i) { + BF16* scales_ptr = HWY_RCAST_ALIGNED( + BF16*, compact_kv_cache.Row(i) + + 2 * qkv_dim * gcpp::KVCache::kTileSize); + for (int in_tile_idx = 0; in_tile_idx < gcpp::KVCache::kTileSize; + ++in_tile_idx) { + // Compute scale and fill K + float max_k = 0.0f; + for (int dim = 0; dim < qkv_dim; ++dim) { + int j = dim * gcpp::KVCache::kTileSize + in_tile_idx; + float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); + max_k = std::max(max_k, expected); + } + float scale_k = max_k / 127.0f; + if (scale_k == 0.0f) scale_k = 1.0f; + scales_ptr[in_tile_idx] = hwy::ConvertScalarTo(scale_k); + + for (int dim = 0; dim < qkv_dim; ++dim) { + int j = dim * gcpp::KVCache::kTileSize + in_tile_idx; + float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); + compact_kv_cache.Row(i)[j] = + static_cast(std::round(expected / scale_k)); + } + + // Compute scale and fill V + float max_v = 0.0f; + for (int dim = 0; dim < qkv_dim; ++dim) { + int j = qkv_dim * gcpp::KVCache::kTileSize + + in_tile_idx * qkv_dim + dim; + float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); + max_v = std::max(max_v, expected); + } + float scale_v = max_v / 127.0f; + if (scale_v == 0.0f) scale_v = 1.0f; + scales_ptr[gcpp::KVCache::kTileSize + in_tile_idx] = + hwy::ConvertScalarTo(scale_v); + + for (int dim = 0; dim < qkv_dim; ++dim) { + int j = qkv_dim * gcpp::KVCache::kTileSize + + in_tile_idx * qkv_dim + dim; + float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); + compact_kv_cache.Row(i)[j] = + static_cast(std::round(expected / scale_v)); + } + } + } + } else if (kv_caches.back().compact_kv_cache_ptr.GetType() == + Type::kBF16) { + MatPtrT compact_kv_cache = + kv_caches.back().compact_kv_cache_ptr; + FillMatPtrT(compact_kv_cache); + } else { + MatPtrT compact_kv_cache = + kv_caches.back().compact_kv_cache_ptr; + FillMatPtrT(compact_kv_cache); + } } else { FillMatPtrT(kv_caches.back().kv_cache); } @@ -725,6 +787,50 @@ void TestAttentionMultipleTokensBF16() { } } +void TestAttentionMultipleTokensInt8() { + int qkv_dim = 64; + int kv_seq_len = 64; + int num_kv_heads = 2; + int num_heads = 4; + int num_tokens = 2; + int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 + // will have 64 tokens to attend to. + float att_cap = 10.0f; + int layer_idx = 0; + int layers_total = 1; + int qbatch_size = 2; + AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; + AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, + num_heads, num_tokens, last_pos, att_cap, layer_idx, + layers_total, qbatch_size, attention_impl, + Type::kInt8); + test_env.SetupWeights(); + FillMatPtrT(test_env.activations->attention.pre_att_rms_out); + FillMatPtrT(test_env.activations->attention.q); + FillMatPtrT(test_env.activations->attention.vit_Q); + FillMatPtrT(test_env.activations->attention.vit_K); + FillMatPtrT(test_env.activations->attention.att); + FillMatPtrT(test_env.activations->attention.att_out); + FillMatPtrT(test_env.activations->attention.softmax_max); + FillMatPtrT(test_env.activations->attention.softmax_d); + + int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); + TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, + test_env.activations->attention, *test_env.qbatch, + test_env.env, flags); + std::cerr << "att_out\n"; + PrintMatPtr(test_env.activations->attention.att_out); + for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { + EXPECT_TRUE(hwy::CompareArraySimilar( + AttentionMultipleTokensAttentionGoldens.data() + + i * test_env.activations->attention.att_out.Cols(), + test_env.activations->attention.att_out.Row(i), + test_env.activations->attention.att_out.Cols(), 1e-1, + hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) + << "att_out mismatch for query: " << i; + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/util/mat.h b/util/mat.h index 9bb72821..5213ee61 100644 --- a/util/mat.h +++ b/util/mat.h @@ -508,6 +508,11 @@ decltype(auto) CallUpcastedKVs(hwy::Span base, const Func& func, auto matptrs = make_matptr_vec(BF16{}); hwy::Span> matptrs_span(matptrs.data(), matptrs.size()); return func(matptrs_span, std::forward(args)...); + } else if (type == Type::kInt8) { + auto matptrs = make_matptr_vec(int8_t{}); + hwy::Span> matptrs_span(matptrs.data(), + matptrs.size()); + return func(matptrs_span, std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(type)); }