Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,142 @@ struct CompressTraits<SfpStream> {
}
};

// Plain int8 quantization: one byte per value, rounded to nearest. Unlike
// I8Stream, there is no per-block scale; callers are responsible for
// pre-scaling values into the representable range (the demotions saturate).
template <>
struct CompressTraits<int8_t> {
  using Packed = int8_t;

  // Bytes required to store `num` values.
  static size_t CompressBound(size_t num) { return num * sizeof(Packed); }

  // Rounds `num` floats from `raw` to nearest int8 and stores them starting
  // at `packed_ofs` within `packed`.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
                                  size_t num, CompressPerThread& /*tls*/,
                                  const PackedSpan<Packed>& packed,
                                  const size_t packed_ofs) {
    const hn::Repartition<int16_t, DF> di16;
    const hn::Repartition<int8_t, DF> di8;
    const auto di16_16 = hn::Half<decltype(di16)>();
    const auto di8_16 = hn::Half<decltype(di8)>();
    using VF = hn::Vec<DF>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 2 * NF) {
      for (; i <= num - 2 * NF; i += 2 * NF) {
        const VF v0 = hn::LoadU(df, raw + i);
        const VF v1 = hn::LoadU(df, raw + i + NF);
        const auto vi32_0 = hn::NearestInt(v0);
        const auto vi32_1 = hn::NearestInt(v1);
        const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
        // Lower half first: OrderedDemote2To concatenates its first operand
        // into the lower result lanes, so this keeps bytes in ascending lane
        // order, matching `Load2`/`DecompressAndZeroPad` which promote in
        // ascending order.
        const auto vi8 = hn::OrderedDemote2To(
            di8_16, hn::LowerHalf(di16_16, vi16), hn::UpperHalf(di16_16, vi16));
        hn::StoreU(vi8, di8_16, packed.ptr + packed_ofs + i);
      }
    }
    const size_t remaining = num - i;
    if (remaining > 0) {
      // MaxLanes is a compile-time bound; Lanes(df) is a runtime value, so
      // `2 * NF` here would be a non-standard VLA.
      HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
      hwy::ZeroBytes(buf, sizeof(buf));
      for (size_t j = 0; j < remaining; ++j) buf[j] = raw[i + j];
      const VF v0 = hn::LoadU(df, buf);
      const VF v1 = hn::LoadU(df, buf + NF);
      const auto vi32_0 = hn::NearestInt(v0);
      const auto vi32_1 = hn::NearestInt(v1);
      const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
      // Same ascending order as the main loop above.
      const auto vi8 = hn::OrderedDemote2To(
          di8_16, hn::LowerHalf(di16_16, vi16), hn::UpperHalf(di16_16, vi16));
      hn::StoreN(vi8, di8_16, packed.ptr + packed_ofs + i, remaining);
    }
  }

  // Scalar reference conversion for tests/debugging.
  static float ToFloatSlow(const Packed x) { return static_cast<float>(x); }

  // Decompresses 2 * Lanes(df) consecutive values starting at `packed_ofs`
  // into two float vectors, preserving order.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
                               const size_t packed_ofs, hn::Vec<DF>& raw0,
                               hn::Vec<DF>& raw1) {
    const hn::Repartition<int32_t, DF> di32;
    const hn::Repartition<int16_t, DF> di16;
    // Same lane count as di16, i.e. 2 * Lanes(df) int8 loaded at once.
    const hn::Rebind<int8_t, decltype(di16)> di8_half;

    const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs);
    const auto vec_i16 = hn::PromoteTo(di16, vec_i8);
    const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16);
    const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16);

    raw0 = hn::ConvertTo(df, vec_i32_0);
    raw1 = hn::ConvertTo(df, vec_i32_1);
  }

  // BF16 variant: decompresses 2 * Lanes(dbf) values by decompressing four
  // float vectors and demoting them in order.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Load2(DBF dbf, const PackedSpan<const Packed>& packed,
                               const size_t packed_ofs, hn::Vec<DBF>& raw0,
                               hn::Vec<DBF>& raw1) {
    const hn::Repartition<float, DBF> df;
    using VF = hn::Vec<decltype(df)>;
    const size_t NF = hn::Lanes(df);

    VF f0, f1, f2, f3;
    Load2(df, packed, packed_ofs, f0, f1);
    Load2(df, packed, packed_ofs + 2 * NF, f2, f3);

    raw0 = hn::OrderedDemote2To(dbf, f0, f1);
    raw1 = hn::OrderedDemote2To(dbf, f2, f3);
  }

  // Decompresses `num` values into `raw` and zero-pads up to a multiple of
  // Lanes(df), as the name promises; `raw` must have that much capacity.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DecompressAndZeroPad(
      DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
      float* HWY_RESTRICT raw, size_t num) {
    using VF = hn::Vec<decltype(df)>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 2 * NF) {
      for (; i <= num - 2 * NF; i += 2 * NF) {
        VF raw0, raw1;
        Load2(df, packed, packed_ofs + i, raw0, raw1);
        hn::StoreU(raw0, df, raw + i);
        hn::StoreU(raw1, df, raw + i + NF);
      }
    }

    const size_t remaining = num - i;
    if (HWY_UNLIKELY(remaining != 0)) {
      for (size_t j = 0; j < remaining; ++j) {
        raw[i + j] = static_cast<float>(packed.ptr[packed_ofs + i + j]);
      }
      // Zero-pad to a whole vector so callers may load full vectors without
      // reading uninitialized data.
      const size_t padded = hwy::RoundUpTo(num, NF);
      for (size_t j = num; j < padded; ++j) raw[j] = 0.0f;
    }
  }

  // BF16 output variant: decompresses via float, demotes in order, and
  // zero-pads the final partial vector up to Lanes(dbf).
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void DecompressAndZeroPad(
      DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
      BF16* HWY_RESTRICT raw, size_t num) {
    const hn::Repartition<float, DBF> df;
    const size_t NF = hn::Lanes(df);
    size_t i = 0;
    const size_t NBF = hn::Lanes(dbf);
    if (num >= NBF) {
      for (; i <= num - NBF; i += NBF) {
        hn::Vec<decltype(df)> f0, f1;
        Load2(df, packed, packed_ofs + i, f0, f1);
        hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
      }
    }
    const size_t remaining = num - i;
    if (remaining > 0) {
      // Zero the buffer so lanes beyond `remaining` demote to 0, then store
      // the full vector: this zero-pads `raw` up to a multiple of Lanes(dbf).
      HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
      hwy::ZeroBytes(buf, sizeof(buf));
      DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
      const auto f0 = hn::LoadU(df, buf);
      const auto f1 = hn::LoadU(df, buf + NF);
      hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
    }
  }
};

// Integer quantization.
template <>
struct CompressTraits<I8Stream> {
Expand Down
4 changes: 4 additions & 0 deletions compression/compress_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ struct TestDecompress2 {
HWY_ASSERT(stats.L1().Max() <= 0.08f);
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
HWY_ASSERT(stats.L1().Max() <= 0.6f);
} else {
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
}
Expand Down Expand Up @@ -200,6 +202,8 @@ struct TestShortLengths {
HWY_ASSERT(stats.L1().Max() <= 0.14f);
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
HWY_ASSERT(stats.L1().Max() <= 0.6f);
} else {
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
}
Expand Down
15 changes: 12 additions & 3 deletions compression/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}

template <typename Packed>
constexpr bool IsInt8() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
}

template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
Expand Down Expand Up @@ -231,12 +236,13 @@ enum class Type {
kI8,
kU16,
kU8,
kInt8,
};
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "u32", "u64",
"i8", "u16", "u8"};
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64",
"u32", "u64", "i8", "u16", "u8", "int8"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {
Expand All @@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = {
8 * sizeof(I8Stream),
8 * sizeof(uint16_t),
8 * sizeof(uint8_t),
8 * sizeof(int8_t),
};

static inline bool EnumValid(Type type) {
Expand Down Expand Up @@ -281,6 +288,8 @@ constexpr Type TypeEnum() {
return Type::kU16;
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
return Type::kU8;
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
return Type::kInt8;
} else {
return Type::kUnknown;
}
Expand Down
70 changes: 70 additions & 0 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,52 @@ static HWY_NOINLINE void ApplyMasking(
}
}

// Decompresses 2 * Lanes(df) BF16 scale factors from `scales` and multiplies
// the first `kNumQueries` (p0, p1) accumulator pairs by them elementwise.
// Pairs beyond `kNumQueries` are left untouched.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
                                       VF& x0_p1, VF& x1_p0, VF& x1_p1,
                                       VF& x2_p0, VF& x2_p1, VF& x3_p0,
                                       VF& x3_p1, VF& x4_p0, VF& x4_p1,
                                       VF& x5_p0, VF& x5_p1, VF& x6_p0,
                                       VF& x6_p1, VF& x7_p0, VF& x7_p1) {
  const size_t kTileSize = hn::Lanes(df);
  const PackedSpan<const BF16> scales_span =
      MakeConstSpan(scales, 2 * kTileSize);
  VF scales_p0, scales_p1;
  Decompress2(df, scales_span, 0, scales_p0, scales_p1);
  // Scales one query's pair of accumulators in place.
  const auto scale_pair = [&scales_p0, &scales_p1](VF& p0, VF& p1) HWY_ATTR {
    p0 = hn::Mul(p0, scales_p0);
    p1 = hn::Mul(p1, scales_p1);
  };
  if constexpr (kNumQueries >= 1) scale_pair(x0_p0, x0_p1);
  if constexpr (kNumQueries >= 2) scale_pair(x1_p0, x1_p1);
  if constexpr (kNumQueries >= 3) scale_pair(x2_p0, x2_p1);
  if constexpr (kNumQueries >= 4) scale_pair(x3_p0, x3_p1);
  if constexpr (kNumQueries >= 5) scale_pair(x4_p0, x4_p1);
  if constexpr (kNumQueries >= 6) scale_pair(x5_p0, x5_p1);
  if constexpr (kNumQueries >= 7) scale_pair(x6_p0, x6_p1);
  if constexpr (kNumQueries >= 8) scale_pair(x7_p0, x7_p1);
}

// Performs tiled flash attention for arbitrary number of queries
// It depends on kv being tiled.
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
Expand Down Expand Up @@ -1400,6 +1446,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
false,
"Query type type not supported, only float and BF16 are supported");
}
// microscaling
// TODO: Change to more generic function to inform if we should use
// microscaling or not.
constexpr bool kUseMicroScaling = IsInt8<KV_T>();
if constexpr (kUseMicroScaling) {
// After end of the tile, we have kTileSize * 2 bfloat16 for the
// microscaling scales for K and V.
const BF16* microscaling_scales_k =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
pos_in_tile;
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}

constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
Expand Down Expand Up @@ -1433,6 +1494,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
kNumQueriesPerGroup);
if constexpr (kUseMicroScaling) {
const BF16* microscaling_scales_v =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
kTileSize + pos_in_tile;
MultiplyByScale<kNumQueries>(df, microscaling_scales_v, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}
if constexpr (IsF32<Q_T>()) {
MulByConstAndAddTileUpTo8<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
Expand Down
Loading
Loading