use crate::io::{Error, Read, Write};
use alloc::vec::Vec;
#[cfg(feature = "hash")]
use core::hash::Hasher;
use super::ringbuffer::RingBuffer;
pub struct DecodeBuffer {
buffer: RingBuffer,
pub dict_content: Vec<u8>,
pub window_size: usize,
total_output_counter: u64,
#[cfg(feature = "hash")]
pub hash: twox_hash::XxHash64,
}
#[derive(Debug)]
#[non_exhaustive]
pub enum DecodeBufferError {
NotEnoughBytesInDictionary { got: usize, need: usize },
OffsetTooBig { offset: usize, buf_len: usize },
}
#[cfg(feature = "std")]
impl std::error::Error for DecodeBufferError {}
impl core::fmt::Display for DecodeBufferError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
DecodeBufferError::NotEnoughBytesInDictionary { got, need } => {
write!(
f,
"Need {} bytes from the dictionary but it is only {} bytes long",
need, got,
)
}
DecodeBufferError::OffsetTooBig { offset, buf_len } => {
write!(f, "offset: {} bigger than buffer: {}", offset, buf_len,)
}
}
}
}
impl Read for DecodeBuffer {
fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
let max_amount = self.can_drain_to_window_size().unwrap_or(0);
let amount = max_amount.min(target.len());
let mut written = 0;
self.drain_to(amount, |buf| {
target[written..][..buf.len()].copy_from_slice(buf);
written += buf.len();
(buf.len(), Ok(()))
})?;
Ok(amount)
}
}
impl DecodeBuffer {
pub fn new(window_size: usize) -> DecodeBuffer {
DecodeBuffer {
buffer: RingBuffer::new(),
dict_content: Vec::new(),
window_size,
total_output_counter: 0,
#[cfg(feature = "hash")]
hash: twox_hash::XxHash64::with_seed(0),
}
}
pub fn reset(&mut self, window_size: usize) {
self.window_size = window_size;
self.buffer.clear();
self.buffer.reserve(self.window_size);
self.dict_content.clear();
self.total_output_counter = 0;
#[cfg(feature = "hash")]
{
self.hash = twox_hash::XxHash64::with_seed(0);
}
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn push(&mut self, data: &[u8]) {
self.buffer.extend(data);
self.total_output_counter += data.len() as u64;
}
pub fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), DecodeBufferError> {
if offset > self.buffer.len() {
self.repeat_from_dict(offset, match_length)
} else {
let buf_len = self.buffer.len();
let start_idx = buf_len - offset;
let end_idx = start_idx + match_length;
self.buffer.reserve(match_length);
if end_idx > buf_len {
self.repeat_in_chunks(offset, match_length, start_idx);
} else {
unsafe {
self.buffer
.extend_from_within_unchecked(start_idx, match_length)
};
}
self.total_output_counter += match_length as u64;
Ok(())
}
}
fn repeat_in_chunks(&mut self, offset: usize, match_length: usize, start_idx: usize) {
let mut start_idx = start_idx;
let mut copied_counter_left = match_length;
while copied_counter_left > 0 {
let chunksize = usize::min(offset, copied_counter_left);
unsafe {
self.buffer
.extend_from_within_unchecked(start_idx, chunksize)
};
copied_counter_left -= chunksize;
start_idx += chunksize;
}
}
#[cold]
fn repeat_from_dict(
&mut self,
offset: usize,
match_length: usize,
) -> Result<(), DecodeBufferError> {
if self.total_output_counter <= self.window_size as u64 {
let bytes_from_dict = offset - self.buffer.len();
if bytes_from_dict > self.dict_content.len() {
return Err(DecodeBufferError::NotEnoughBytesInDictionary {
got: self.dict_content.len(),
need: bytes_from_dict,
});
}
if bytes_from_dict < match_length {
let dict_slice = &self.dict_content[self.dict_content.len() - bytes_from_dict..];
self.buffer.extend(dict_slice);
self.total_output_counter += bytes_from_dict as u64;
return self.repeat(self.buffer.len(), match_length - bytes_from_dict);
} else {
let low = self.dict_content.len() - bytes_from_dict;
let high = low + match_length;
let dict_slice = &self.dict_content[low..high];
self.buffer.extend(dict_slice);
}
Ok(())
} else {
Err(DecodeBufferError::OffsetTooBig {
offset,
buf_len: self.buffer.len(),
})
}
}
pub fn can_drain_to_window_size(&self) -> Option<usize> {
if self.buffer.len() > self.window_size {
Some(self.buffer.len() - self.window_size)
} else {
None
}
}
pub fn can_drain(&self) -> usize {
self.buffer.len()
}
pub fn drain_to_window_size(&mut self) -> Option<Vec<u8>> {
match self.can_drain_to_window_size() {
None => None,
Some(can_drain) => {
let mut vec = Vec::with_capacity(can_drain);
self.drain_to(can_drain, |buf| {
vec.extend_from_slice(buf);
(buf.len(), Ok(()))
})
.ok()?;
Some(vec)
}
}
}
pub fn drain_to_window_size_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
match self.can_drain_to_window_size() {
None => Ok(0),
Some(can_drain) => self.drain_to(can_drain, |buf| write_all_bytes(&mut sink, buf)),
}
}
pub fn drain(&mut self) -> Vec<u8> {
let (slice1, slice2) = self.buffer.as_slices();
#[cfg(feature = "hash")]
{
self.hash.write(slice1);
self.hash.write(slice2);
}
let mut vec = Vec::with_capacity(slice1.len() + slice2.len());
vec.extend_from_slice(slice1);
vec.extend_from_slice(slice2);
self.buffer.clear();
vec
}
pub fn drain_to_writer(&mut self, mut sink: impl Write) -> Result<usize, Error> {
let write_limit = self.buffer.len();
self.drain_to(write_limit, |buf| write_all_bytes(&mut sink, buf))
}
pub fn read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
let amount = self.buffer.len().min(target.len());
let mut written = 0;
self.drain_to(amount, |buf| {
target[written..][..buf.len()].copy_from_slice(buf);
written += buf.len();
(buf.len(), Ok(()))
})?;
Ok(amount)
}
fn drain_to(
&mut self,
amount: usize,
mut write_bytes: impl FnMut(&[u8]) -> (usize, Result<(), Error>),
) -> Result<usize, Error> {
if amount == 0 {
return Ok(0);
}
struct DrainGuard<'a> {
buffer: &'a mut RingBuffer,
amount: usize,
}
impl<'a> Drop for DrainGuard<'a> {
fn drop(&mut self) {
if self.amount != 0 {
self.buffer.drop_first_n(self.amount);
}
}
}
let mut drain_guard = DrainGuard {
buffer: &mut self.buffer,
amount: 0,
};
let (slice1, slice2) = drain_guard.buffer.as_slices();
let n1 = slice1.len().min(amount);
let n2 = slice2.len().min(amount - n1);
if n1 != 0 {
let (written1, res1) = write_bytes(&slice1[..n1]);
#[cfg(feature = "hash")]
self.hash.write(&slice1[..written1]);
drain_guard.amount += written1;
res1?;
if written1 == n1 && n2 != 0 {
let (written2, res2) = write_bytes(&slice2[..n2]);
#[cfg(feature = "hash")]
self.hash.write(&slice2[..written2]);
drain_guard.amount += written2;
res2?;
}
}
let amount_written = drain_guard.amount;
drop(drain_guard);
Ok(amount_written)
}
}
fn write_all_bytes(mut sink: impl Write, buf: &[u8]) -> (usize, Result<(), Error>) {
let mut written = 0;
while written < buf.len() {
match sink.write(&buf[written..]) {
Ok(0) => return (written, Ok(())),
Ok(w) => written += w,
Err(e) => return (written, Err(e)),
}
}
(written, Ok(()))
}
#[cfg(test)]
mod tests {
use super::DecodeBuffer;
use crate::io::{Error, ErrorKind, Write};
extern crate std;
use alloc::vec;
use alloc::vec::Vec;
#[test]
fn short_writer() {
struct ShortWriter {
buf: Vec<u8>,
write_len: usize,
}
impl Write for ShortWriter {
fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
if buf.len() > self.write_len {
self.buf.extend_from_slice(&buf[..self.write_len]);
Ok(self.write_len)
} else {
self.buf.extend_from_slice(buf);
Ok(buf.len())
}
}
fn flush(&mut self) -> std::result::Result<(), Error> {
Ok(())
}
}
let mut short_writer = ShortWriter {
buf: vec![],
write_len: 10,
};
let mut decode_buf = DecodeBuffer::new(100);
decode_buf.push(b"0123456789");
decode_buf.repeat(10, 90).unwrap();
let repeats = 1000;
for _ in 0..repeats {
assert_eq!(decode_buf.len(), 100);
decode_buf.repeat(10, 50).unwrap();
assert_eq!(decode_buf.len(), 150);
decode_buf
.drain_to_window_size_writer(&mut short_writer)
.unwrap();
assert_eq!(decode_buf.len(), 100);
}
assert_eq!(short_writer.buf.len(), repeats * 50);
decode_buf.drain_to_writer(&mut short_writer).unwrap();
assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
}
#[test]
fn wouldblock_writer() {
struct WouldblockWriter {
buf: Vec<u8>,
last_blocked: usize,
block_every: usize,
}
impl Write for WouldblockWriter {
fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, Error> {
if self.last_blocked < self.block_every {
self.buf.extend_from_slice(buf);
self.last_blocked += 1;
Ok(buf.len())
} else {
self.last_blocked = 0;
Err(Error::from(ErrorKind::WouldBlock))
}
}
fn flush(&mut self) -> std::result::Result<(), Error> {
Ok(())
}
}
let mut short_writer = WouldblockWriter {
buf: vec![],
last_blocked: 0,
block_every: 5,
};
let mut decode_buf = DecodeBuffer::new(100);
decode_buf.push(b"0123456789");
decode_buf.repeat(10, 90).unwrap();
let repeats = 1000;
for _ in 0..repeats {
assert_eq!(decode_buf.len(), 100);
decode_buf.repeat(10, 50).unwrap();
assert_eq!(decode_buf.len(), 150);
loop {
match decode_buf.drain_to_window_size_writer(&mut short_writer) {
Ok(written) => {
if written == 0 {
break;
}
}
Err(e) => {
if e.kind() == ErrorKind::WouldBlock {
continue;
} else {
panic!("Unexpected error {:?}", e);
}
}
}
}
assert_eq!(decode_buf.len(), 100);
}
assert_eq!(short_writer.buf.len(), repeats * 50);
loop {
match decode_buf.drain_to_writer(&mut short_writer) {
Ok(written) => {
if written == 0 {
break;
}
}
Err(e) => {
if e.kind() == ErrorKind::WouldBlock {
continue;
} else {
panic!("Unexpected error {:?}", e);
}
}
}
}
assert_eq!(short_writer.buf.len(), repeats * 50 + 100);
}
}