use super::super::blocks::sequence_section::ModeType;
use super::super::blocks::sequence_section::Sequence;
use super::super::blocks::sequence_section::SequencesHeader;
use super::bit_reader_reverse::{BitReaderReversed, GetBitsError};
use super::scratch::FSEScratch;
use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};
use crate::fse::{FSEDecoder, FSEDecoderError, FSETableError};
use alloc::vec::Vec;
/// Errors that can occur while decoding the sequences section of a block.
///
/// The first three variants wrap errors from lower layers (see the `From`
/// impls in this file); the remaining ones are produced by this module itself.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecodeSequenceError {
/// Wraps an error coming from the reversed bit reader.
GetBitsError(GetBitsError),
/// Wraps an error from the FSE decoder (e.g. while initialising decoder state).
FSEDecoderError(FSEDecoderError),
/// Wraps an error raised while building an FSE table.
FSETableError(FSETableError),
/// More than 8 bits were skipped while looking for the first set bit of the bitstream.
ExtraPadding { skipped_bits: i32 },
/// An offset code larger than `MAX_OFFSET_CODE` was encountered.
UnsupportedOffset { offset_code: u8 },
/// A decoded offset evaluated to zero.
ZeroOffset,
/// The bit reader ran past the available bits before all sequences were decoded.
NotEnoughBytesForNumSequences,
/// Bits were left over in the bitstream after all sequences had been decoded.
ExtraBits { bits_remaining: isize },
/// The sequences header did not carry compression modes.
MissingCompressionMode,
/// RLE mode was selected for the literal-lengths table but no byte was available.
MissingByteForRleLlTable,
/// RLE mode was selected for the offsets table but no byte was available.
MissingByteForRleOfTable,
/// RLE mode was selected for the match-lengths table but no byte was available,
/// or the RLE match-length code was larger than `MAX_MATCH_LENGTH_CODE`.
MissingByteForRleMlTable,
}
#[cfg(feature = "std")]
impl std::error::Error for DecodeSequenceError {
    /// Returns the wrapped lower-level error for the three wrapping variants,
    /// and `None` for every variant that originates in this module.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::GetBitsError(inner) => inner,
            Self::FSEDecoderError(inner) => inner,
            Self::FSETableError(inner) => inner,
            _ => return None,
        };
        Some(cause)
    }
}
impl core::fmt::Display for DecodeSequenceError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
DecodeSequenceError::GetBitsError(e) => write!(f, "{:?}", e),
DecodeSequenceError::FSEDecoderError(e) => write!(f, "{:?}", e),
DecodeSequenceError::FSETableError(e) => write!(f, "{:?}", e),
DecodeSequenceError::ExtraPadding { skipped_bits } => {
write!(f,
"Padding at the end of the sequence_section was more than a byte long: {} bits. Probably caused by data corruption",
skipped_bits,
)
}
DecodeSequenceError::UnsupportedOffset { offset_code } => {
write!(
f,
"Do not support offsets bigger than 1<<32; got: {}",
offset_code,
)
}
DecodeSequenceError::ZeroOffset => write!(
f,
"Read an offset == 0. That is an illegal value for offsets"
),
DecodeSequenceError::NotEnoughBytesForNumSequences => write!(
f,
"Bytestream did not contain enough bytes to decode num_sequences"
),
DecodeSequenceError::ExtraBits { bits_remaining } => write!(f, "{}", bits_remaining),
DecodeSequenceError::MissingCompressionMode => write!(
f,
"compression modes are none but they must be set to something"
),
DecodeSequenceError::MissingByteForRleLlTable => {
write!(f, "Need a byte to read for RLE ll table")
}
DecodeSequenceError::MissingByteForRleOfTable => {
write!(f, "Need a byte to read for RLE of table")
}
DecodeSequenceError::MissingByteForRleMlTable => {
write!(f, "Need a byte to read for RLE ml table")
}
}
}
}
impl From<GetBitsError> for DecodeSequenceError {
fn from(val: GetBitsError) -> Self {
Self::GetBitsError(val)
}
}
impl From<FSETableError> for DecodeSequenceError {
fn from(val: FSETableError) -> Self {
Self::FSETableError(val)
}
}
impl From<FSEDecoderError> for DecodeSequenceError {
fn from(val: FSEDecoderError) -> Self {
Self::FSEDecoderError(val)
}
}
/// Decodes all sequences described by `section` from `source` into `target`.
///
/// `source` starts with whatever table descriptions the header's modes call
/// for; those are consumed first (updating `scratch`), and the rest of the
/// slice is handed to a [`BitReaderReversed`] as the sequence bitstream.
///
/// # Errors
///
/// Returns [`DecodeSequenceError::ExtraPadding`] when no set bit is found
/// within the first 8 bits read from the bitstream, and otherwise forwards any
/// error from updating the tables or from decoding the sequences.
pub fn decode_sequences(
    section: &SequencesHeader,
    source: &[u8],
    scratch: &mut FSEScratch,
    target: &mut Vec<Sequence>,
) -> Result<(), DecodeSequenceError> {
    let bytes_read = maybe_update_fse_tables(section, source, scratch)?;
    vprintln!("Updating tables used {} bytes", bytes_read);

    let mut br = BitReaderReversed::new(&source[bytes_read..]);

    // Skip the padding: consume bits until the first set bit, giving up once
    // a ninth bit has been read.
    let mut skipped_bits = 0;
    while skipped_bits <= 8 {
        let bit = br.get_bits(1);
        skipped_bits += 1;
        if bit == 1 {
            break;
        }
    }
    if skipped_bits > 8 {
        return Err(DecodeSequenceError::ExtraPadding { skipped_bits });
    }

    // The RLE-aware loop has extra branches per symbol; only pay for them when
    // at least one of the three tables is actually in RLE mode.
    let any_rle = [scratch.ll_rle, scratch.ml_rle, scratch.of_rle]
        .iter()
        .any(Option::is_some);
    if any_rle {
        decode_sequences_with_rle(section, &mut br, scratch, target)
    } else {
        decode_sequences_without_rle(section, &mut br, scratch, target)
    }
}
/// Sequence decoding loop used when at least one of the three symbol streams
/// (literal lengths, match lengths, offsets) is in RLE mode.
///
/// A stream in RLE mode always yields its fixed code from `scratch` and never
/// touches its FSE decoder; the other streams are decoded through FSE.
/// `target` is cleared and then filled with `section.num_sequences` entries.
///
/// # Errors
///
/// * [`DecodeSequenceError::UnsupportedOffset`] for an offset code above `MAX_OFFSET_CODE`
/// * [`DecodeSequenceError::ZeroOffset`] if a computed offset is zero
/// * [`DecodeSequenceError::NotEnoughBytesForNumSequences`] if the reader runs out of bits
/// * [`DecodeSequenceError::ExtraBits`] if bits remain after the last sequence
/// * any error from initialising the FSE decoder states
fn decode_sequences_with_rle(
    section: &SequencesHeader,
    br: &mut BitReaderReversed<'_>,
    scratch: &FSEScratch,
    target: &mut Vec<Sequence>,
) -> Result<(), DecodeSequenceError> {
    let (ll_rle, ml_rle, of_rle) = (scratch.ll_rle, scratch.ml_rle, scratch.of_rle);
    let total = section.num_sequences as usize;

    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
    let mut of_dec = FSEDecoder::new(&scratch.offsets);

    // Initial states are read in the order ll, of, ml; RLE streams have none.
    if ll_rle.is_none() {
        ll_dec.init_state(br)?;
    }
    if of_rle.is_none() {
        of_dec.init_state(br)?;
    }
    if ml_rle.is_none() {
        ml_dec.init_state(br)?;
    }

    target.clear();
    target.reserve(total);

    for _ in 0..section.num_sequences {
        let ll_code = match ll_rle {
            Some(code) => code,
            None => ll_dec.decode_symbol(),
        };
        let ml_code = match ml_rle {
            Some(code) => code,
            None => ml_dec.decode_symbol(),
        };
        let of_code = match of_rle {
            Some(code) => code,
            None => of_dec.decode_symbol(),
        };

        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code);
        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code);

        if of_code > MAX_OFFSET_CODE {
            return Err(DecodeSequenceError::UnsupportedOffset {
                offset_code: of_code,
            });
        }

        // Extra bits are read in the order offset, match length, literal length.
        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
        let offset = obits as u32 + (1u32 << of_code);
        if offset == 0 {
            return Err(DecodeSequenceError::ZeroOffset);
        }

        target.push(Sequence {
            ll: ll_value + ll_add as u32,
            ml: ml_value + ml_add as u32,
            of: offset,
        });

        // No state update after the final sequence; updates go ll, ml, of.
        if target.len() < total {
            if ll_rle.is_none() {
                ll_dec.update_state(br);
            }
            if ml_rle.is_none() {
                ml_dec.update_state(br);
            }
            if of_rle.is_none() {
                of_dec.update_state(br);
            }
        }

        if br.bits_remaining() < 0 {
            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
        }
    }

    match br.bits_remaining() {
        bits_remaining if bits_remaining > 0 => {
            Err(DecodeSequenceError::ExtraBits { bits_remaining })
        }
        _ => Ok(()),
    }
}
/// Sequence decoding loop used when none of the three symbol streams is in
/// RLE mode, so every code comes from its FSE decoder.
///
/// `target` is cleared and then filled with `section.num_sequences` entries.
///
/// # Errors
///
/// * [`DecodeSequenceError::UnsupportedOffset`] for an offset code above `MAX_OFFSET_CODE`
/// * [`DecodeSequenceError::ZeroOffset`] if a computed offset is zero
/// * [`DecodeSequenceError::NotEnoughBytesForNumSequences`] if the reader runs out of bits
/// * [`DecodeSequenceError::ExtraBits`] if bits remain after the last sequence
/// * any error from initialising the FSE decoder states
fn decode_sequences_without_rle(
    section: &SequencesHeader,
    br: &mut BitReaderReversed<'_>,
    scratch: &FSEScratch,
    target: &mut Vec<Sequence>,
) -> Result<(), DecodeSequenceError> {
    let expected_len = section.num_sequences as usize;

    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
    let mut of_dec = FSEDecoder::new(&scratch.offsets);

    // Initial states are stored in the order ll, of, ml.
    ll_dec.init_state(br)?;
    of_dec.init_state(br)?;
    ml_dec.init_state(br)?;

    target.clear();
    target.reserve(expected_len);

    for _ in 0..section.num_sequences {
        let ll_code = ll_dec.decode_symbol();
        let ml_code = ml_dec.decode_symbol();
        let of_code = of_dec.decode_symbol();

        let (ll_base, ll_extra_bits) = lookup_ll_code(ll_code);
        let (ml_base, ml_extra_bits) = lookup_ml_code(ml_code);

        if of_code > MAX_OFFSET_CODE {
            return Err(DecodeSequenceError::UnsupportedOffset {
                offset_code: of_code,
            });
        }

        // Extra bits are read in the order offset, match length, literal length.
        let (of_extra, ml_extra, ll_extra) =
            br.get_bits_triple(of_code, ml_extra_bits, ll_extra_bits);
        let offset = of_extra as u32 + (1u32 << of_code);
        if offset == 0 {
            return Err(DecodeSequenceError::ZeroOffset);
        }

        let sequence = Sequence {
            ll: ll_base + ll_extra as u32,
            ml: ml_base + ml_extra as u32,
            of: offset,
        };
        target.push(sequence);

        // The states are not advanced after the final sequence.
        let more_to_come = target.len() < expected_len;
        if more_to_come {
            ll_dec.update_state(br);
            ml_dec.update_state(br);
            of_dec.update_state(br);
        }

        if br.bits_remaining() < 0 {
            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
        }
    }

    let leftover = br.bits_remaining();
    if leftover > 0 {
        return Err(DecodeSequenceError::ExtraBits {
            bits_remaining: leftover,
        });
    }
    Ok(())
}
/// Maps a literal-length code to `(baseline, number_of_extra_bits)`.
///
/// Codes `0..=15` are their own baseline with no extra bits; codes `16..=35`
/// are looked up in a fixed table.
///
/// # Panics
///
/// Panics for any code above 35.
fn lookup_ll_code(code: u8) -> (u32, u8) {
    /// Entries for codes 16..=35, indexed by `code - 16`.
    const EXTENDED: [(u32, u8); 20] = [
        (16, 1),
        (18, 1),
        (20, 1),
        (22, 1),
        (24, 2),
        (28, 2),
        (32, 3),
        (40, 3),
        (48, 4),
        (64, 6),
        (128, 7),
        (256, 8),
        (512, 9),
        (1024, 10),
        (2048, 11),
        (4096, 12),
        (8192, 13),
        (16384, 14),
        (32768, 15),
        (65536, 16),
    ];

    if code <= 15 {
        return (u32::from(code), 0);
    }
    match EXTENDED.get(usize::from(code - 16)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal literal length code was: {}", code),
    }
}
/// Maps a match-length code to `(baseline, number_of_extra_bits)`.
///
/// Codes `0..=31` map to `code + 3` with no extra bits; codes `32..=52` are
/// looked up in a fixed table.
///
/// # Panics
///
/// Panics for any code above 52.
fn lookup_ml_code(code: u8) -> (u32, u8) {
    /// Entries for codes 32..=52, indexed by `code - 32`.
    const EXTENDED: [(u32, u8); 21] = [
        (35, 1),
        (37, 1),
        (39, 1),
        (41, 1),
        (43, 2),
        (47, 2),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 5),
        (131, 7),
        (259, 8),
        (515, 9),
        (1027, 10),
        (2051, 11),
        (4099, 12),
        (8195, 13),
        (16387, 14),
        (32771, 15),
        (65539, 16),
    ];

    if code <= 31 {
        return (u32::from(code) + 3, 0);
    }
    match EXTENDED.get(usize::from(code - 32)) {
        Some(&entry) => entry,
        None => unreachable!("Illegal match length code was: {}", code),
    }
}
/// Maximum log value passed to `build_decoder` when an FSE-compressed
/// literal-lengths table is read from the stream.
pub const LL_MAX_LOG: u8 = 9;
/// Maximum log value passed to `build_decoder` when an FSE-compressed
/// match-lengths table is read from the stream.
pub const ML_MAX_LOG: u8 = 9;
/// Maximum log value passed to `build_decoder` when an FSE-compressed
/// offsets table is read from the stream.
pub const OF_MAX_LOG: u8 = 8;
/// Brings the three decoding tables in `scratch` up to date according to the
/// compression modes in `section`, reading table data from the start of
/// `source` where a mode requires it.
///
/// The tables are processed in the order literal lengths, offsets, match
/// lengths, each one starting where the previous one stopped reading. Per mode:
///
/// * `FSECompressed`: a table description is parsed from the stream and the RLE
///   marker for that table is cleared.
/// * `RLE`: a single byte is read and stored as the fixed code for that table.
/// * `Predefined`: the table is rebuilt from the built-in default distribution
///   and the RLE marker is cleared.
/// * `Repeat`: nothing changes; the state from the previous block is reused.
///
/// Returns the number of bytes consumed from `source`.
///
/// # Errors
///
/// * [`DecodeSequenceError::MissingCompressionMode`] if `section.modes` is `None`.
/// * `MissingByteForRle{Ll,Of,Ml}Table` if the RLE byte for that table is missing.
/// * An RLE code outside the legal range is rejected. The enum has no dedicated
///   variant for this, so the closest existing one for the affected table is
///   used: `MissingByteForRleLlTable`, `UnsupportedOffset` and
///   `MissingByteForRleMlTable` respectively.
/// * Any [`FSETableError`] raised while building a table.
fn maybe_update_fse_tables(
    section: &SequencesHeader,
    source: &[u8],
    scratch: &mut FSEScratch,
) -> Result<usize, DecodeSequenceError> {
    let modes = section
        .modes
        .ok_or(DecodeSequenceError::MissingCompressionMode)?;

    let mut bytes_read = 0;

    match modes.ll_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
            bytes_read += bytes;

            vprintln!("Updating ll table");
            vprintln!("Used bytes: {}", bytes);
            scratch.ll_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE ll table");
            let code = *source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleLlTable)?;
            bytes_read += 1;
            if code > MAX_LITERAL_LENGTH_CODE {
                // Report against the literal-lengths table, not the match-lengths one.
                return Err(DecodeSequenceError::MissingByteForRleLlTable);
            }
            scratch.ll_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined ll table");
            scratch.literal_lengths.build_from_probabilities(
                LL_DEFAULT_ACC_LOG,
                &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.ll_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat ll table");
        }
    }

    let of_source = &source[bytes_read..];

    match modes.of_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
            vprintln!("Updating of table");
            vprintln!("Used bytes: {}", bytes);
            bytes_read += bytes;
            scratch.of_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE of table");
            let code = *of_source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleOfTable)?;
            bytes_read += 1;
            if code > MAX_OFFSET_CODE {
                // Same error the decode loops raise for an oversized offset code.
                return Err(DecodeSequenceError::UnsupportedOffset { offset_code: code });
            }
            scratch.of_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined of table");
            scratch.offsets.build_from_probabilities(
                OF_DEFAULT_ACC_LOG,
                &Vec::from(&OFFSET_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.of_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat of table");
        }
    }

    let ml_source = &source[bytes_read..];

    match modes.ml_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
            bytes_read += bytes;
            vprintln!("Updating ml table");
            vprintln!("Used bytes: {}", bytes);
            scratch.ml_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE ml table");
            let code = *ml_source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleMlTable)?;
            bytes_read += 1;
            if code > MAX_MATCH_LENGTH_CODE {
                return Err(DecodeSequenceError::MissingByteForRleMlTable);
            }
            scratch.ml_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined ml table");
            scratch.match_lengths.build_from_probabilities(
                ML_DEFAULT_ACC_LOG,
                &Vec::from(&MATCH_LENGTH_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.ml_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat ml table");
        }
    }

    Ok(bytes_read)
}
// Built-in tables used when a sequence stream selects `ModeType::Predefined`.
// NOTE(review): the values look like the default distributions from the
// Zstandard format specification (RFC 8878) — confirm against the spec before
// changing any of them.

/// Accuracy log passed to `build_from_probabilities` for the predefined
/// literal-lengths table.
const LL_DEFAULT_ACC_LOG: u8 = 6;
/// Predefined distribution for literal-length codes, one entry per code (0..=35).
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
];
/// Accuracy log passed to `build_from_probabilities` for the predefined
/// match-lengths table.
const ML_DEFAULT_ACC_LOG: u8 = 6;
/// Predefined distribution for match-length codes, one entry per code (0..=52).
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
/// Accuracy log passed to `build_from_probabilities` for the predefined
/// offsets table.
const OF_DEFAULT_ACC_LOG: u8 = 5;
/// Predefined distribution for offset codes (29 entries).
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
/// Builds the predefined literal-lengths table and spot-checks a handful of
/// decode entries against known-good values.
#[test]
fn test_ll_default() {
    let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
    table
        .build_from_probabilities(
            LL_DEFAULT_ACC_LOG,
            &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
        )
        .unwrap();

    // Dump the whole table to help diagnose a failing assertion below.
    #[cfg(feature = "std")]
    for (idx, entry) in table.decode.iter().enumerate() {
        std::println!(
            "{:3}: {:3} {:3} {:3}",
            idx,
            entry.symbol,
            entry.num_bits,
            entry.base_line
        );
    }

    assert_eq!(table.decode.len(), 64);

    // (state, symbol, num_bits, base_line)
    let expected = [
        (0usize, 0, 4, 0),
        (19, 27, 6, 0),
        (39, 25, 4, 16),
        (60, 35, 6, 0),
        (59, 24, 5, 32),
    ];
    for &(idx, symbol, num_bits, base_line) in expected.iter() {
        let entry = &table.decode[idx];
        assert_eq!(entry.symbol, symbol, "symbol at state {}", idx);
        assert_eq!(entry.num_bits, num_bits, "num_bits at state {}", idx);
        assert_eq!(entry.base_line, base_line, "base_line at state {}", idx);
    }
}