ruzstd/decoding/ringbuffer.rs

use alloc::alloc::{alloc, dealloc};
use core::{alloc::Layout, ptr::NonNull, slice};

pub struct RingBuffer {
    // Safety invariants:
    //
    // 1. `buf` must be a valid allocation of capacity `cap`,
    //    unless `cap = 0`, in which case it is dangling
    // 2. If tail ≥ head, `head..tail` must contain initialized memory.
    //    Else, `head..cap` and `..tail` must be initialized
    // 3. `head` and `tail` are in bounds (≥ 0 and < cap)
    // 4. `tail` is never `cap`; the value `0` is used instead. In other words, `tail` always points to the place
    //    where the next element would go (if there is space)
    buf: NonNull<u8>,
    cap: usize,
    head: usize,
    tail: usize,
}
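
// A worked example of the invariants above (added for illustration; the concrete
// numbers are not part of the original comments): with `cap = 5`, `head = 3` and
// `tail = 1`, the initialized bytes live at indices 3, 4 and 0, so `len()` is 3
// and `free()` is 1, because one slot is always left unused as a sentinel to
// distinguish a full buffer from an empty one.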

// SAFETY: RingBuffer does not hold any thread-specific values -> it can be sent to another thread -> RingBuffer is Send
unsafe impl Send for RingBuffer {}

// SAFETY: RingBuffer does not provide unsynchronized interior mutability, which makes &RingBuffer Send -> RingBuffer is Sync
unsafe impl Sync for RingBuffer {}

impl RingBuffer {
    pub fn new() -> Self {
        RingBuffer {
            // SAFETY: Upholds invariant 1a as stated
            buf: NonNull::dangling(),
            cap: 0,
            // SAFETY: Upholds invariant 2-4
            head: 0,
            tail: 0,
        }
    }

    /// Return the number of bytes in the buffer.
    pub fn len(&self) -> usize {
        let (x, y) = self.data_slice_lengths();
        x + y
    }

    /// Return the amount of available space (in bytes) in the buffer.
    /// One byte is always kept unused as a sentinel, so this is `cap - len() - 1` for a non-empty allocation.
    pub fn free(&self) -> usize {
        let (x, y) = self.free_slice_lengths();
        (x + y).saturating_sub(1)
    }

    /// Empty the buffer and reset the head and tail.
    pub fn clear(&mut self) {
        // SAFETY: Upholds invariant 2, trivially
        // SAFETY: Upholds invariant 3; 0 is always valid
        self.head = 0;
        self.tail = 0;
    }

    /// Whether the buffer is empty
    pub fn is_empty(&self) -> bool {
        self.head == self.tail
    }

    /// Ensure that there's space for `amount` elements in the buffer.
    pub fn reserve(&mut self, amount: usize) {
        let free = self.free();
        if free >= amount {
            return;
        }

        self.reserve_amortized(amount - free);
    }

    #[inline(never)]
    #[cold]
    fn reserve_amortized(&mut self, amount: usize) {
        // SAFETY: if we were successfully able to construct this layout when we allocated then it's also valid to do so now
        let current_layout = unsafe { Layout::array::<u8>(self.cap).unwrap_unchecked() };

        // Always have at least 1 unused element as the sentinel.
        let new_cap = usize::max(
            self.cap.next_power_of_two(),
            (self.cap + amount).next_power_of_two(),
        ) + 1;

        // Check that the capacity isn't bigger than isize::MAX, which is the maximum allowed by LLVM, or that
        // we are on a >= 64 bit system, which will never allow that much memory to be allocated
        #[allow(clippy::assertions_on_constants)]
        {
            debug_assert!(usize::BITS >= 64 || new_cap < isize::MAX as usize);
        }

        let new_layout = Layout::array::<u8>(new_cap)
            .unwrap_or_else(|_| panic!("Could not create layout for u8 array of size {}", new_cap));

        // alloc the new memory region and panic if alloc fails
        // TODO maybe rework this to generate an error?
        let new_buf = unsafe {
            let new_buf = alloc(new_layout);

            NonNull::new(new_buf).expect("Allocating new space for the ringbuffer failed")
        };

        // If we had data before, copy it over to the newly allocated memory region
        if self.cap > 0 {
            let ((s1_ptr, s1_len), (s2_ptr, s2_len)) = self.data_slice_parts();

            unsafe {
                // SAFETY: Upholds invariant 2, we end up populating (0..(s1_len + s2_len))
                new_buf.as_ptr().copy_from_nonoverlapping(s1_ptr, s1_len);
                new_buf
                    .as_ptr()
                    .add(s1_len)
                    .copy_from_nonoverlapping(s2_ptr, s2_len);
                dealloc(self.buf.as_ptr(), current_layout);
            }

            // SAFETY: Upholds invariant 3, head is 0 and in bounds, tail is only ever `cap` if the buffer
            // is entirely full
            self.tail = s1_len + s2_len;
            self.head = 0;
        }
        // SAFETY: Upholds invariant 1: the buffer was just allocated correctly
        self.buf = new_buf;
        self.cap = new_cap;
    }

    #[allow(dead_code)]
    pub fn push_back(&mut self, byte: u8) {
        self.reserve(1);

        // SAFETY: Upholds invariant 2 by writing initialized memory
        unsafe { self.buf.as_ptr().add(self.tail).write(byte) };
        // SAFETY: Upholds invariant 3 by wrapping `tail` around
        self.tail = (self.tail + 1) % self.cap;
    }

    /// Fetch the byte stored at the selected index from the buffer, returning it, or
    /// `None` if the index is out of bounds.
    #[allow(dead_code)]
    pub fn get(&self, idx: usize) -> Option<u8> {
        if idx < self.len() {
            // SAFETY: Establishes invariants on memory being initialized and the range being in-bounds
            // (Invariants 2 & 3)
            let idx = (self.head + idx) % self.cap;
            Some(unsafe { self.buf.as_ptr().add(idx).read() })
        } else {
            None
        }
    }
    /// Append the provided data to the end of `self`.
    pub fn extend(&mut self, data: &[u8]) {
        let len = data.len();
        let ptr = data.as_ptr();
        if len == 0 {
            return;
        }

        self.reserve(len);

        debug_assert!(self.len() + len <= self.cap - 1);
        debug_assert!(self.free() >= len, "free: {} len: {}", self.free(), len);

        let ((f1_ptr, f1_len), (f2_ptr, f2_len)) = self.free_slice_parts();
        debug_assert!(f1_len + f2_len >= len, "{} + {} < {}", f1_len, f2_len, len);

        let in_f1 = usize::min(len, f1_len);

        let in_f2 = len - in_f1;

        debug_assert!(in_f1 + in_f2 == len);

        unsafe {
            // SAFETY: `in_f1 + in_f2 == len`, so this writes `len` bytes total,
            // upholding invariant 2
            if in_f1 > 0 {
                f1_ptr.copy_from_nonoverlapping(ptr, in_f1);
            }
            if in_f2 > 0 {
                f2_ptr.copy_from_nonoverlapping(ptr.add(in_f1), in_f2);
            }
        }
        // SAFETY: Upholds invariant 3 by wrapping `tail` around.
        self.tail = (self.tail + len) % self.cap;
    }

    /// Advance head past `amount` elements, effectively removing
    /// them from the buffer.
    pub fn drop_first_n(&mut self, amount: usize) {
        debug_assert!(amount <= self.len());
        let amount = usize::min(amount, self.len());
        // SAFETY: we maintain invariant 2 here since this only ever shrinks the
        // initialized region, for amount ≤ len
        self.head = (self.head + amount) % self.cap;
    }

    /// Return the size of the two contiguous occupied sections of memory used
    /// by the buffer.
    // SAFETY: other code relies on this pointing to initialized halves of the buffer only
    fn data_slice_lengths(&self) -> (usize, usize) {
        let len_after_head;
        let len_to_tail;

        // TODO can we do this branchless?
        if self.tail >= self.head {
            len_after_head = self.tail - self.head;
            len_to_tail = 0;
        } else {
            len_after_head = self.cap - self.head;
            len_to_tail = self.tail;
        }
        (len_after_head, len_to_tail)
    }

    // SAFETY: other code relies on this pointing to initialized halves of the buffer only
    /// Return pointers to the head and tail, and the length of each section.
    fn data_slice_parts(&self) -> ((*const u8, usize), (*const u8, usize)) {
        let (len_after_head, len_to_tail) = self.data_slice_lengths();

        (
            (unsafe { self.buf.as_ptr().add(self.head) }, len_after_head),
            (self.buf.as_ptr(), len_to_tail),
        )
    }

    /// Return references to each part of the ring buffer.
    pub fn as_slices(&self) -> (&[u8], &[u8]) {
        let (s1, s2) = self.data_slice_parts();
        unsafe {
            // SAFETY: relies on the behavior of data_slice_parts for producing initialized memory
            let s1 = slice::from_raw_parts(s1.0, s1.1);
            let s2 = slice::from_raw_parts(s2.0, s2.1);
            (s1, s2)
        }
    }

    // SAFETY: other code relies on this producing the lengths of free zones
    // at the beginning/end of the buffer. Everything else must be initialized
    /// Returns the size of the two unoccupied sections of memory used by the buffer.
    fn free_slice_lengths(&self) -> (usize, usize) {
        let len_to_head;
        let len_after_tail;

        // TODO can we do this branchless?
        if self.tail < self.head {
            len_after_tail = self.head - self.tail;
            len_to_head = 0;
        } else {
            len_after_tail = self.cap - self.tail;
            len_to_head = self.head;
        }
        (len_to_head, len_after_tail)
    }

    /// Returns pointers to the available space and the size of that available space,
    /// for the two free sections in the buffer.
    // SAFETY: Other code relies on this pointing to the free zones, data after the first and before the second must
    // be valid
    fn free_slice_parts(&self) -> ((*mut u8, usize), (*mut u8, usize)) {
        let (len_to_head, len_after_tail) = self.free_slice_lengths();

        (
            (unsafe { self.buf.as_ptr().add(self.tail) }, len_after_tail),
            (self.buf.as_ptr(), len_to_head),
        )
    }

    /// Copies elements from the provided range to the end of the buffer.
    #[allow(dead_code)]
    pub fn extend_from_within(&mut self, start: usize, len: usize) {
        if start + len > self.len() {
            panic!(
                "Calls to this function must respect start ({}) + len ({}) <= self.len() ({})!",
                start,
                len,
                self.len()
            );
        }

        self.reserve(len);

        // SAFETY: Requirements checked:
        // 1. explicitly checked above, resulting in a panic if it does not hold
        // 2. explicitly reserved enough memory
        unsafe { self.extend_from_within_unchecked(start, len) }
    }

    /// Copies data from the provided range to the end of the buffer, without
    /// first verifying that the unoccupied capacity is available.
    ///
    /// SAFETY:
    /// For this to be safe two requirements need to hold:
    /// 1. start + len <= self.len() so we do not copy uninitialized memory
    /// 2. At least len bytes of free space so we do not write out-of-bounds
    #[warn(unsafe_op_in_unsafe_fn)]
    pub unsafe fn extend_from_within_unchecked(&mut self, start: usize, len: usize) {
        debug_assert!(start + len <= self.len());
        debug_assert!(self.free() >= len);

        if self.head < self.tail {
            // Continuous source section and possibly non-continuous write section:
            //
            //            H           T
            // Read:  ____XXXXSSSSXXXX________
            // Write: ________________DDDD____
            //
            // H: Head position (first readable byte)
            // T: Tail position (first writable byte)
            // X: Uninvolved bytes in the readable section
            // S: Source bytes, to be copied to D bytes
            // D: Destination bytes, going to be copied from S bytes
            // _: Uninvolved bytes in the writable section
            let after_tail = usize::min(len, self.cap - self.tail);

            let src = (
                // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                unsafe { self.buf.as_ptr().add(self.head + start) }.cast_const(),
                // Src length (see above diagram)
                self.tail - self.head - start,
            );

            let dst = (
                // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                unsafe { self.buf.as_ptr().add(self.tail) },
                // Dst length (see above diagram)
                self.cap - self.tail,
            );

            // SAFETY: `src` points at initialized data, `dst` points to writable memory
            // and does not overlap `src`.
            unsafe { copy_bytes_overshooting(src, dst, after_tail) }

            if after_tail < len {
                // The write section was not continuous:
                //
                //            H           T
                // Read:  ____XXXXSSSSXXXX__
                // Write: DD______________DD
                //
                // H: Head position (first readable byte)
                // T: Tail position (first writable byte)
                // X: Uninvolved bytes in the readable section
                // S: Source bytes, to be copied to D bytes
                // D: Destination bytes, going to be copied from S bytes
                // _: Uninvolved bytes in the writable section

                let src = (
                    // SAFETY: we are still within the memory range of `buf`
                    unsafe { src.0.add(after_tail) },
                    // Src length (see above diagram)
                    src.1 - after_tail,
                );
                let dst = (
                    self.buf.as_ptr(),
                    // Dst length overflowing (see above diagram)
                    self.head,
                );

                // SAFETY: `src` points at initialized data, `dst` points to writable memory
                // and does not overlap `src`.
                unsafe { copy_bytes_overshooting(src, dst, len - after_tail) }
            }
        } else {
            if self.head + start > self.cap {
                // Continuous read section and destination section:
                //
                //                  T           H
                // Read:  XXSSSSXXXX____________XX
                // Write: __________DDDD__________
                //
                // H: Head position (first readable byte)
                // T: Tail position (first writable byte)
                // X: Uninvolved bytes in the readable section
                // S: Source bytes, to be copied to D bytes
                // D: Destination bytes, going to be copied from S bytes
                // _: Uninvolved bytes in the writable section

                let start = (self.head + start) % self.cap;

                let src = (
                    // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                    unsafe { self.buf.as_ptr().add(start) }.cast_const(),
                    // Src length (see above diagram)
                    self.tail - start,
                );

                let dst = (
                    // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                    unsafe { self.buf.as_ptr().add(self.tail) },
                    // Dst length (see above diagram)
                    self.head - self.tail,
                );

                // SAFETY: `src` points at initialized data, `dst` points to writable memory
                // and does not overlap `src`.
                unsafe { copy_bytes_overshooting(src, dst, len) }
            } else {
                // Possibly non-continuous read section and continuous destination section:
                //
                //            T           H
                // Read:  XXXX____________XXSSSSXX
                // Write: ____DDDD________________
                //
                // H: Head position (first readable byte)
                // T: Tail position (first writable byte)
                // X: Uninvolved bytes in the readable section
                // S: Source bytes, to be copied to D bytes
                // D: Destination bytes, going to be copied from S bytes
                // _: Uninvolved bytes in the writable section

                let after_start = usize::min(len, self.cap - self.head - start);

                let src = (
                    // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                    unsafe { self.buf.as_ptr().add(self.head + start) }.cast_const(),
                    // Src length - chunk 1 (see above diagram on the right)
                    self.cap - self.head - start,
                );

                let dst = (
                    // SAFETY: `len <= isize::MAX` and fits the memory range of `buf`
                    unsafe { self.buf.as_ptr().add(self.tail) },
                    // Dst length (see above diagram)
                    self.head - self.tail,
                );

                // SAFETY: `src` points at initialized data, `dst` points to writable memory
                // and does not overlap `src`.
                unsafe { copy_bytes_overshooting(src, dst, after_start) }

                if after_start < len {
                    // The read section was not continuous:
                    //
                    //                T           H
                    // Read:  SSXXXXXX____________XXSS
                    // Write: ________DDDD____________
                    //
                    // H: Head position (first readable byte)
                    // T: Tail position (first writable byte)
                    // X: Uninvolved bytes in the readable section
                    // S: Source bytes, to be copied to D bytes
                    // D: Destination bytes, going to be copied from S bytes
                    // _: Uninvolved bytes in the writable section

                    let src = (
                        self.buf.as_ptr().cast_const(),
                        // Src length - chunk 2 (see above diagram on the left)
                        self.tail,
                    );

                    let dst = (
                        // SAFETY: we are still within the memory range of `buf`
                        unsafe { dst.0.add(after_start) },
                        // Dst length (see above diagram)
                        dst.1 - after_start,
                    );

                    // SAFETY: `src` points at initialized data, `dst` points to writable memory
                    // and does not overlap `src`.
                    unsafe { copy_bytes_overshooting(src, dst, len - after_start) }
                }
            }
        }

        self.tail = (self.tail + len) % self.cap;
    }

    #[allow(dead_code)]
    /// This function is functionally the same as [RingBuffer::extend_from_within_unchecked],
    /// but it does not contain any branching operations.
    ///
    /// SAFETY:
    /// Needs start + len <= self.len()
    /// And at least len bytes of free space
    pub unsafe fn extend_from_within_unchecked_branchless(&mut self, start: usize, len: usize) {
        // data slices in raw parts
        let ((s1_ptr, s1_len), (s2_ptr, s2_len)) = self.data_slice_parts();

        debug_assert!(len <= s1_len + s2_len, "{} > {} + {}", len, s1_len, s2_len);

        // calculate the actually wanted slices in raw parts
        let start_in_s1 = usize::min(s1_len, start);
        let end_in_s1 = usize::min(s1_len, start + len);
        let m1_ptr = s1_ptr.add(start_in_s1);
        let m1_len = end_in_s1 - start_in_s1;

        debug_assert!(end_in_s1 <= s1_len);
        debug_assert!(start_in_s1 <= s1_len);

        let start_in_s2 = start.saturating_sub(s1_len);
        let end_in_s2 = start_in_s2 + (len - m1_len);
        let m2_ptr = s2_ptr.add(start_in_s2);
        let m2_len = end_in_s2 - start_in_s2;

        debug_assert!(start_in_s2 <= s2_len);
        debug_assert!(end_in_s2 <= s2_len);

        debug_assert_eq!(len, m1_len + m2_len);

        // the free slices, must hold: f1_len + f2_len >= m1_len + m2_len
        let ((f1_ptr, f1_len), (f2_ptr, f2_len)) = self.free_slice_parts();

        debug_assert!(f1_len + f2_len >= m1_len + m2_len);

        // calculate how many bytes of which source chunk go into which free chunk
        let m1_in_f1 = usize::min(m1_len, f1_len);
        let m1_in_f2 = m1_len - m1_in_f1;
        let m2_in_f1 = usize::min(f1_len - m1_in_f1, m2_len);
        let m2_in_f2 = m2_len - m2_in_f1;

        debug_assert_eq!(m1_len, m1_in_f1 + m1_in_f2);
        debug_assert_eq!(m2_len, m2_in_f1 + m2_in_f2);
        debug_assert!(f1_len >= m1_in_f1 + m2_in_f1);
        debug_assert!(f2_len >= m1_in_f2 + m2_in_f2);
        debug_assert_eq!(len, m1_in_f1 + m2_in_f1 + m1_in_f2 + m2_in_f2);

        debug_assert!(self.buf.as_ptr().add(self.cap) > f1_ptr.add(m1_in_f1 + m2_in_f1));
        debug_assert!(self.buf.as_ptr().add(self.cap) > f2_ptr.add(m1_in_f2 + m2_in_f2));

        debug_assert!((m1_in_f2 > 0) ^ (m2_in_f1 > 0) || (m1_in_f2 == 0 && m2_in_f1 == 0));

        copy_with_checks(
            m1_ptr, m2_ptr, f1_ptr, f2_ptr, m1_in_f1, m2_in_f1, m1_in_f2, m2_in_f2,
        );
        self.tail = (self.tail + len) % self.cap;
    }
}

impl Drop for RingBuffer {
    fn drop(&mut self) {
        if self.cap == 0 {
            return;
        }

        // SAFETY: if we were successfully able to construct this layout when we allocated then it's also valid to do so now
        // Relies on / establishes invariant 1
        let current_layout = unsafe { Layout::array::<u8>(self.cap).unwrap_unchecked() };

        unsafe {
            dealloc(self.buf.as_ptr(), current_layout);
        }
    }
}

/// Similar to ptr::copy_nonoverlapping
///
/// But it might overshoot the desired copy length if deemed useful
///
/// src and dst specify the entire length they are eligible for reading/writing respectively
/// in addition to the desired copy length.
///
/// This function will then copy in chunks and might copy up to chunk size - 1 more bytes from src to dst,
/// as long as that does not read or write memory outside of src/dst.
///
/// The chunk size is not part of the contract and may change depending on the target platform.
///
/// If that isn't possible we just fall back to ptr::copy_nonoverlapping
#[inline(always)]
unsafe fn copy_bytes_overshooting(
    src: (*const u8, usize),
    dst: (*mut u8, usize),
    copy_at_least: usize,
) {
    // By default use usize as the copy size
    #[cfg(all(not(target_feature = "sse2"), not(target_feature = "neon")))]
    type CopyType = usize;

    // Use u128 if we detect a SIMD feature
    #[cfg(target_feature = "neon")]
    type CopyType = u128;
    #[cfg(target_feature = "sse2")]
    type CopyType = u128;

    const COPY_AT_ONCE_SIZE: usize = core::mem::size_of::<CopyType>();
    let min_buffer_size = usize::min(src.1, dst.1);

    // Can copy in just one read+write, very common case
    if min_buffer_size >= COPY_AT_ONCE_SIZE && copy_at_least <= COPY_AT_ONCE_SIZE {
        dst.0
            .cast::<CopyType>()
            .write_unaligned(src.0.cast::<CopyType>().read_unaligned())
    } else {
        let copy_multiple = copy_at_least.next_multiple_of(COPY_AT_ONCE_SIZE);
        // Can copy in multiple simple instructions
        if min_buffer_size >= copy_multiple {
            let mut src_ptr = src.0.cast::<CopyType>();
            let src_ptr_end = src.0.add(copy_multiple).cast::<CopyType>();
            let mut dst_ptr = dst.0.cast::<CopyType>();

            while src_ptr < src_ptr_end {
                dst_ptr.write_unaligned(src_ptr.read_unaligned());
                src_ptr = src_ptr.add(1);
                dst_ptr = dst_ptr.add(1);
            }
        } else {
            // Fall back to a standard memcpy
            dst.0.copy_from_nonoverlapping(src.0, copy_at_least);
        }
    }

    debug_assert_eq!(
        slice::from_raw_parts(src.0, copy_at_least),
        slice::from_raw_parts(dst.0, copy_at_least)
    );
}
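
// Illustrative numbers (an informal sketch, not part of the original docs): with
// an 8-byte chunk size, `copy_at_least = 3`, `src.1 = 8` and `dst.1 = 8`, the
// single read+write fast path above copies a full 8-byte chunk, so bytes 3..8 of
// the destination are overwritten as well. Callers therefore pass the full
// writable length in `dst.1`, not just `copy_at_least`.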

#[allow(dead_code)]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn copy_without_checks(
    m1_ptr: *const u8,
    m2_ptr: *const u8,
    f1_ptr: *mut u8,
    f2_ptr: *mut u8,
    m1_in_f1: usize,
    m2_in_f1: usize,
    m1_in_f2: usize,
    m2_in_f2: usize,
) {
    f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
    f1_ptr
        .add(m1_in_f1)
        .copy_from_nonoverlapping(m2_ptr, m2_in_f1);

    f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2);
    f2_ptr
        .add(m1_in_f2)
        .copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2);
}

#[allow(dead_code)]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn copy_with_checks(
    m1_ptr: *const u8,
    m2_ptr: *const u8,
    f1_ptr: *mut u8,
    f2_ptr: *mut u8,
    m1_in_f1: usize,
    m2_in_f1: usize,
    m1_in_f2: usize,
    m2_in_f2: usize,
) {
    if m1_in_f1 != 0 {
        f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
    }
    if m2_in_f1 != 0 {
        f1_ptr
            .add(m1_in_f1)
            .copy_from_nonoverlapping(m2_ptr, m2_in_f1);
    }

    if m1_in_f2 != 0 {
        f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2);
    }
    if m2_in_f2 != 0 {
        f2_ptr
            .add(m1_in_f2)
            .copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2);
    }
}

#[allow(dead_code)]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn copy_with_nobranch_check(
    m1_ptr: *const u8,
    m2_ptr: *const u8,
    f1_ptr: *mut u8,
    f2_ptr: *mut u8,
    m1_in_f1: usize,
    m2_in_f1: usize,
    m1_in_f2: usize,
    m2_in_f2: usize,
) {
    let case = (m1_in_f1 > 0) as usize
        | (((m2_in_f1 > 0) as usize) << 1)
        | (((m1_in_f2 > 0) as usize) << 2)
        | (((m2_in_f2 > 0) as usize) << 3);

    match case {
        0 => {}

        // one bit set
        1 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
        }
        2 => {
            f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1);
        }
        4 => {
            f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2);
        }
        8 => {
            f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2);
        }

        // two bits set
        3 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
            f1_ptr
                .add(m1_in_f1)
                .copy_from_nonoverlapping(m2_ptr, m2_in_f1);
        }
        5 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
            f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2);
        }
        6 => core::hint::unreachable_unchecked(),
        7 => core::hint::unreachable_unchecked(),
        9 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
            f2_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f2);
        }
        10 => {
            f1_ptr.copy_from_nonoverlapping(m2_ptr, m2_in_f1);
            f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2);
        }
        12 => {
            f2_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f2);
            f2_ptr
                .add(m1_in_f2)
                .copy_from_nonoverlapping(m2_ptr, m2_in_f2);
        }

        // three bits set
        11 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
            f1_ptr
                .add(m1_in_f1)
                .copy_from_nonoverlapping(m2_ptr, m2_in_f1);
            f2_ptr.copy_from_nonoverlapping(m2_ptr.add(m2_in_f1), m2_in_f2);
        }
        13 => {
            f1_ptr.copy_from_nonoverlapping(m1_ptr, m1_in_f1);
            f2_ptr.copy_from_nonoverlapping(m1_ptr.add(m1_in_f1), m1_in_f2);
            f2_ptr
                .add(m1_in_f2)
                .copy_from_nonoverlapping(m2_ptr, m2_in_f2);
        }
        14 => core::hint::unreachable_unchecked(),
        15 => core::hint::unreachable_unchecked(),
        _ => core::hint::unreachable_unchecked(),
    }
}

#[cfg(test)]
mod tests {
    use super::RingBuffer;

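    // A minimal usage sketch added for illustration (not part of the original
    // test suite): `push_back` grows the buffer on demand and `get` indexes
    // relative to the current head.
    #[test]
    fn push_back_and_get() {
        let mut rb = RingBuffer::new();
        for byte in 0..4u8 {
            rb.push_back(byte);
        }
        assert_eq!(rb.len(), 4);
        assert_eq!(rb.get(0), Some(0));
        assert_eq!(rb.get(3), Some(3));
        assert_eq!(rb.get(4), None);

        // After dropping the first two bytes, index 0 refers to the old index 2.
        rb.drop_first_n(2);
        assert_eq!(rb.get(0), Some(2));
    }
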
    #[test]
    fn smoke() {
        let mut rb = RingBuffer::new();

        rb.reserve(15);
        assert_eq!(17, rb.cap);

        rb.extend(b"0123456789");
        assert_eq!(rb.len(), 10);
        assert_eq!(rb.as_slices().0, b"0123456789");
        assert_eq!(rb.as_slices().1, b"");

        rb.drop_first_n(5);
        assert_eq!(rb.len(), 5);
        assert_eq!(rb.as_slices().0, b"56789");
        assert_eq!(rb.as_slices().1, b"");

        rb.extend_from_within(2, 3);
        assert_eq!(rb.len(), 8);
        assert_eq!(rb.as_slices().0, b"56789789");
        assert_eq!(rb.as_slices().1, b"");

        rb.extend_from_within(0, 3);
        assert_eq!(rb.len(), 11);
        assert_eq!(rb.as_slices().0, b"56789789567");
        assert_eq!(rb.as_slices().1, b"");

        rb.extend_from_within(0, 2);
        assert_eq!(rb.len(), 13);
        assert_eq!(rb.as_slices().0, b"567897895675");
        assert_eq!(rb.as_slices().1, b"6");

        rb.drop_first_n(11);
        assert_eq!(rb.len(), 2);
        assert_eq!(rb.as_slices().0, b"5");
        assert_eq!(rb.as_slices().1, b"6");

        rb.extend(b"0123456789");
        assert_eq!(rb.len(), 12);
        assert_eq!(rb.as_slices().0, b"5");
        assert_eq!(rb.as_slices().1, b"60123456789");

        rb.drop_first_n(11);
        assert_eq!(rb.len(), 1);
        assert_eq!(rb.as_slices().0, b"9");
        assert_eq!(rb.as_slices().1, b"");

        rb.extend(b"0123456789");
        assert_eq!(rb.len(), 11);
        assert_eq!(rb.as_slices().0, b"9012345");
        assert_eq!(rb.as_slices().1, b"6789");
    }
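
    // Added example (not from the original test suite): `clear` empties the
    // buffer but keeps the existing allocation around.
    #[test]
    fn clear_keeps_capacity() {
        let mut rb = RingBuffer::new();
        rb.extend(b"abc");
        assert!(!rb.is_empty());
        let cap_before = rb.cap;

        rb.clear();
        assert!(rb.is_empty());
        assert_eq!(rb.len(), 0);
        assert_eq!(rb.cap, cap_before);
    }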

    #[test]
    fn edge_cases() {
        // Fill exactly, then empty then fill again
        let mut rb = RingBuffer::new();
        rb.reserve(16);
        assert_eq!(17, rb.cap);
        rb.extend(b"0123456789012345");
        assert_eq!(17, rb.cap);
        assert_eq!(16, rb.len());
        assert_eq!(0, rb.free());
        rb.drop_first_n(16);
        assert_eq!(0, rb.len());
        assert_eq!(16, rb.free());
        rb.extend(b"0123456789012345");
        assert_eq!(16, rb.len());
        assert_eq!(0, rb.free());
        assert_eq!(17, rb.cap);
        assert_eq!(1, rb.as_slices().0.len());
        assert_eq!(15, rb.as_slices().1.len());

        rb.clear();

        // data in both slices and then reserve
        rb.extend(b"0123456789012345");
        rb.drop_first_n(8);
        rb.extend(b"67890123");
        assert_eq!(16, rb.len());
        assert_eq!(0, rb.free());
        assert_eq!(17, rb.cap);
        assert_eq!(9, rb.as_slices().0.len());
        assert_eq!(7, rb.as_slices().1.len());
        rb.reserve(1);
        assert_eq!(16, rb.len());
        assert_eq!(16, rb.free());
        assert_eq!(33, rb.cap);
        assert_eq!(16, rb.as_slices().0.len());
        assert_eq!(0, rb.as_slices().1.len());

        rb.clear();

        // fill exactly, then extend from within
        rb.extend(b"0123456789012345");
        rb.extend_from_within(0, 16);
        assert_eq!(32, rb.len());
        assert_eq!(0, rb.free());
        assert_eq!(33, rb.cap);
        assert_eq!(32, rb.as_slices().0.len());
        assert_eq!(0, rb.as_slices().1.len());

        // extend from within cases
        let mut rb = RingBuffer::new();
        rb.reserve(8);
        rb.extend(b"01234567");
        rb.drop_first_n(5);
        rb.extend_from_within(0, 3);
        assert_eq!(4, rb.as_slices().0.len());
        assert_eq!(2, rb.as_slices().1.len());

        rb.drop_first_n(2);
        assert_eq!(2, rb.as_slices().0.len());
        assert_eq!(2, rb.as_slices().1.len());
        rb.extend_from_within(0, 4);
        assert_eq!(2, rb.as_slices().0.len());
        assert_eq!(6, rb.as_slices().1.len());

        rb.drop_first_n(2);
        assert_eq!(6, rb.as_slices().0.len());
        assert_eq!(0, rb.as_slices().1.len());
        rb.drop_first_n(2);
        assert_eq!(4, rb.as_slices().0.len());
        assert_eq!(0, rb.as_slices().1.len());
        rb.extend_from_within(0, 4);
        assert_eq!(7, rb.as_slices().0.len());
        assert_eq!(1, rb.as_slices().1.len());

        let mut rb = RingBuffer::new();
        rb.reserve(8);
        rb.extend(b"11111111");
        rb.drop_first_n(7);
        rb.extend(b"111");
        assert_eq!(2, rb.as_slices().0.len());
        assert_eq!(2, rb.as_slices().1.len());
        rb.extend_from_within(0, 4);
        assert_eq!(b"11", rb.as_slices().0);
        assert_eq!(b"111111", rb.as_slices().1);
    }
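
    // Hedged sketch of the branchless variant, added for illustration and relying
    // only on the documented requirements (start + len <= len() and enough free
    // space); it is not part of the original test suite.
    #[test]
    fn extend_from_within_branchless() {
        let mut rb = RingBuffer::new();
        rb.reserve(8);
        rb.extend(b"abcd");
        assert!(rb.free() >= 2);

        // SAFETY: 1 + 2 <= rb.len() and at least 2 bytes of free space are available.
        unsafe { rb.extend_from_within_unchecked_branchless(1, 2) };
        assert_eq!(rb.len(), 6);
        assert_eq!(rb.as_slices().0, b"abcdbc");

        // One byte of capacity always stays unused as the sentinel slot.
        assert_eq!(rb.free(), rb.cap - rb.len() - 1);
    }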
}