refactor(transformer): add wrapper around NonNull (#6115)

Introduce a wrapper around `NonNull` which enables methods which exist on `std::ptr::NonNull` but are not yet stable in our MSRV. These methods remove a lot of boilerplate code from `Stack` and `NonEmptyStack` and make them easier to understand - which is important, since they contain so much unsafe code.
This commit is contained in:
overlookmotel 2024-09-27 16:48:37 +00:00
parent c50500ec42
commit 9ac80bd2d2
5 changed files with 134 additions and 31 deletions

View file

@ -3,12 +3,12 @@
use std::{ use std::{
alloc::{self, Layout}, alloc::{self, Layout},
mem::{align_of, size_of}, mem::{align_of, size_of},
ptr::{self, NonNull}, ptr,
}; };
use assert_unchecked::assert_unchecked; use assert_unchecked::assert_unchecked;
use super::StackCapacity; use super::{NonNull, StackCapacity};
pub trait StackCommon<T>: StackCapacity<T> { pub trait StackCommon<T>: StackCapacity<T> {
// Getter setter methods defined by implementer // Getter setter methods defined by implementer
@ -34,11 +34,8 @@ pub trait StackCommon<T>: StackCapacity<T> {
let layout = Self::layout_for(capacity_bytes); let layout = Self::layout_for(capacity_bytes);
let (start, end) = allocate(layout); let (start, end) = allocate(layout);
// SAFETY: `start` and `end` are `NonNull` - just casting them // SAFETY: `layout_for` produces a layout with `T`'s alignment, so pointers are aligned for `T`
let start = NonNull::new_unchecked(start.as_ptr().cast::<T>()); (start.cast::<T>(), end.cast::<T>())
let end = NonNull::new_unchecked(end.as_ptr().cast::<T>());
(start, end)
} }
/// Grow allocation. /// Grow allocation.
@ -64,7 +61,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
// `MAX_CAPACITY_BYTES` is also a multiple of `size_of::<T>()`. // `MAX_CAPACITY_BYTES` is also a multiple of `size_of::<T>()`.
// So new capacity in bytes must be a multiple of `size_of::<T>()`. // So new capacity in bytes must be a multiple of `size_of::<T>()`.
// `MAX_CAPACITY_BYTES <= isize::MAX`. // `MAX_CAPACITY_BYTES <= isize::MAX`.
let old_start_ptr = NonNull::new_unchecked(self.start().as_ptr().cast::<u8>()); let old_start_ptr = self.start().cast::<u8>();
let old_layout = Self::layout_for(self.capacity_bytes()); let old_layout = Self::layout_for(self.capacity_bytes());
let (start, end, current) = grow(old_start_ptr, old_layout, Self::MAX_CAPACITY_BYTES); let (start, end, current) = grow(old_start_ptr, old_layout, Self::MAX_CAPACITY_BYTES);
@ -73,9 +70,9 @@ pub trait StackCommon<T>: StackCapacity<T> {
// All pointers returned from `grow` are aligned for `T`. // All pointers returned from `grow` are aligned for `T`.
// Old capacity and new capacity in bytes are both multiples of `size_of::<T>()`, // Old capacity and new capacity in bytes are both multiples of `size_of::<T>()`,
// so distances `end - start` and `current - start` are both multiples of `size_of::<T>()`. // so distances `end - start` and `current - start` are both multiples of `size_of::<T>()`.
self.set_start(NonNull::new_unchecked(start.as_ptr().cast::<T>())); self.set_start(start.cast::<T>());
self.set_end(NonNull::new_unchecked(end.as_ptr().cast::<T>())); self.set_end(end.cast::<T>());
self.set_cursor(NonNull::new_unchecked(current.as_ptr().cast::<T>())); self.set_cursor(current.cast::<T>());
} }
/// Deallocate stack memory. /// Deallocate stack memory.
@ -149,7 +146,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)] #[expect(clippy::cast_sign_loss)]
unsafe { unsafe {
assert_unchecked!(self.cursor() >= self.start()); assert_unchecked!(self.cursor() >= self.start());
self.cursor().as_ptr().offset_from(self.start().as_ptr()) as usize self.cursor().offset_from(self.start()) as usize
} }
} }
@ -166,7 +163,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)] #[expect(clippy::cast_sign_loss)]
unsafe { unsafe {
assert_unchecked!(self.end() >= self.start()); assert_unchecked!(self.end() >= self.start());
self.end().as_ptr().offset_from(self.start().as_ptr()) as usize self.end().offset_from(self.start()) as usize
} }
} }
@ -183,7 +180,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)] #[expect(clippy::cast_sign_loss)]
unsafe { unsafe {
assert_unchecked!(self.end() >= self.start()); assert_unchecked!(self.end() >= self.start());
self.end().as_ptr().byte_offset_from(self.start().as_ptr()) as usize self.end().byte_offset_from(self.start()) as usize
} }
} }
} }
@ -209,7 +206,7 @@ unsafe fn allocate(layout: Layout) -> (/* start */ NonNull<u8>, /* end */ NonNul
// SAFETY: We checked `ptr` is non-null // SAFETY: We checked `ptr` is non-null
let start = NonNull::new_unchecked(ptr); let start = NonNull::new_unchecked(ptr);
// SAFETY: We allocated `layout.size()` bytes, so `end` is end of allocation // SAFETY: We allocated `layout.size()` bytes, so `end` is end of allocation
let end = NonNull::new_unchecked(ptr.add(layout.size())); let end = start.add(layout.size());
(start, end) (start, end)
} }
@ -249,7 +246,7 @@ unsafe fn grow(
// So `new_capacity_bytes` must be a multiple of `size_of::<T>()`. // So `new_capacity_bytes` must be a multiple of `size_of::<T>()`.
// `new_capacity_bytes` is `<= MAX_CAPACITY_BYTES`, so is a legal allocation size. // `new_capacity_bytes` is `<= MAX_CAPACITY_BYTES`, so is a legal allocation size.
// `layout_for` produces a layout with `T`'s alignment, so `new_ptr` is aligned for `T`. // `layout_for` produces a layout with `T`'s alignment, so `new_ptr` is aligned for `T`.
let new_ptr = unsafe { let new_start = unsafe {
let old_ptr = old_start.as_ptr(); let old_ptr = old_start.as_ptr();
let new_ptr = alloc::realloc(old_ptr, old_layout, new_capacity_bytes); let new_ptr = alloc::realloc(old_ptr, old_layout, new_capacity_bytes);
if new_ptr.is_null() { if new_ptr.is_null() {
@ -257,7 +254,7 @@ unsafe fn grow(
Layout::from_size_align_unchecked(old_capacity_bytes, old_layout.align()); Layout::from_size_align_unchecked(old_capacity_bytes, old_layout.align());
alloc::handle_alloc_error(new_layout); alloc::handle_alloc_error(new_layout);
} }
new_ptr NonNull::new_unchecked(new_ptr)
}; };
// Update pointers. // Update pointers.
@ -273,9 +270,8 @@ unsafe fn grow(
// //
// SAFETY: We checked that `new_ptr` is non-null. // SAFETY: We checked that `new_ptr` is non-null.
// `old_capacity_bytes < new_capacity_bytes` (ensured above), so `new_cursor` must be in bounds. // `old_capacity_bytes < new_capacity_bytes` (ensured above), so `new_cursor` must be in bounds.
let new_start = NonNull::new_unchecked(new_ptr); let new_end = new_start.add(new_capacity_bytes);
let new_end = NonNull::new_unchecked(new_ptr.add(new_capacity_bytes)); let new_cursor = new_start.add(old_capacity_bytes);
let new_cursor = NonNull::new_unchecked(new_ptr.add(old_capacity_bytes));
(new_start, new_end, new_cursor) (new_start, new_end, new_cursor)
} }

View file

@ -1,11 +1,13 @@
mod capacity; mod capacity;
mod common; mod common;
mod non_empty; mod non_empty;
mod non_null;
mod sparse; mod sparse;
mod standard; mod standard;
use capacity::StackCapacity; use capacity::StackCapacity;
use common::StackCommon; use common::StackCommon;
pub use non_empty::NonEmptyStack; pub use non_empty::NonEmptyStack;
use non_null::NonNull;
pub use sparse::SparseStack; pub use sparse::SparseStack;
pub use standard::Stack; pub use standard::Stack;

View file

@ -1,8 +1,8 @@
#![expect(clippy::unnecessary_safety_comment)] #![expect(clippy::unnecessary_safety_comment)]
use std::{mem::size_of, ptr::NonNull}; use std::mem::size_of;
use super::{StackCapacity, StackCommon}; use super::{NonNull, StackCapacity, StackCommon};
/// A stack which can never be empty. /// A stack which can never be empty.
/// ///
@ -199,7 +199,7 @@ impl<T> NonEmptyStack<T> {
// of allocation. So advancing by a `T` cannot be out of bounds. // of allocation. So advancing by a `T` cannot be out of bounds.
// The distance between `self.cursor` and `self.end` is always a multiple of `size_of::<T>()`, // The distance between `self.cursor` and `self.end` is always a multiple of `size_of::<T>()`,
// so `==` check is sufficient to detect when full to capacity. // so `==` check is sufficient to detect when full to capacity.
let new_cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) }; let new_cursor = unsafe { self.cursor.add(1) };
if new_cursor == self.end { if new_cursor == self.end {
// Needs to grow // Needs to grow
// SAFETY: Stack is full to capacity // SAFETY: Stack is full to capacity
@ -264,7 +264,7 @@ impl<T> NonEmptyStack<T> {
let value = self.cursor.as_ptr().read(); let value = self.cursor.as_ptr().read();
// SAFETY: Caller guarantees there's at least 2 entries on stack, so subtracting 1 // SAFETY: Caller guarantees there's at least 2 entries on stack, so subtracting 1
// cannot be out of bounds // cannot be out of bounds
self.cursor = NonNull::new_unchecked(self.cursor.as_ptr().sub(1)); self.cursor = self.cursor.sub(1);
value value
} }

View file

@ -0,0 +1,105 @@
use std::{cmp::Ordering, ptr::NonNull as NativeNonNull};
/// Wrapper around `NonNull<T>`, which adds methods `add`, `sub`, `offset_from` and `byte_offset_from`.
/// These methods exist on `std::ptr::NonNull`, and became stable in Rust 1.80.0, but are not yet
/// stable in our MSRV.
///
/// These methods are much cleaner than the workarounds required in older Rust versions,
/// and make code using them easier to understand.
///
/// Once we bump MSRV and these methods are natively supported, this type can be removed.
/// `#[expect(clippy::incompatible_msrv)]` on `non_null_add_is_not_stable` below will trigger
/// a lint warning when that happens.
/// Then this module can be deleted, and all uses of this type can be switched to `std::ptr::NonNull`.
#[derive(Debug)]
pub struct NonNull<T>(NativeNonNull<T>);
#[expect(dead_code, clippy::incompatible_msrv)]
unsafe fn non_null_add_is_not_stable(ptr: NativeNonNull<u8>) -> NativeNonNull<u8> {
ptr.add(1)
}
impl<T> NonNull<T> {
#[inline]
pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
Self(NativeNonNull::new_unchecked(ptr))
}
#[inline]
pub const fn dangling() -> Self {
Self(NativeNonNull::dangling())
}
#[inline]
pub const fn as_ptr(self) -> *mut T {
self.0.as_ptr()
}
#[inline]
pub const fn cast<U>(self) -> NonNull<U> {
// SAFETY: `self` is non-null, so it's still non-null after casting
unsafe { NonNull::new_unchecked(self.as_ptr().cast()) }
}
#[inline]
pub const unsafe fn add(self, count: usize) -> Self {
NonNull(NativeNonNull::new_unchecked(self.as_ptr().add(count)))
}
#[inline]
pub const unsafe fn sub(self, count: usize) -> Self {
NonNull(NativeNonNull::new_unchecked(self.as_ptr().sub(count)))
}
#[inline]
pub const unsafe fn offset_from(self, origin: Self) -> isize {
self.as_ptr().offset_from(origin.as_ptr())
}
#[inline]
pub const unsafe fn byte_offset_from(self, origin: Self) -> isize {
self.as_ptr().byte_offset_from(origin.as_ptr())
}
#[inline]
pub const unsafe fn as_ref<'t>(self) -> &'t T {
self.0.as_ref()
}
#[inline]
pub unsafe fn as_mut<'t>(mut self) -> &'t mut T {
self.0.as_mut()
}
}
impl<T> Copy for NonNull<T> {}
impl<T> Clone for NonNull<T> {
#[inline]
fn clone(&self) -> Self {
*self
}
}
impl<T> Eq for NonNull<T> {}
impl<T> PartialEq for NonNull<T> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<T> Ord for NonNull<T> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
self.as_ptr().cmp(&other.as_ptr())
}
}
impl<T> PartialOrd for NonNull<T> {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

View file

@ -1,8 +1,8 @@
#![expect(clippy::unnecessary_safety_comment)] #![expect(clippy::unnecessary_safety_comment)]
use std::{mem::size_of, ptr::NonNull}; use std::mem::size_of;
use super::{StackCapacity, StackCommon}; use super::{NonNull, StackCapacity, StackCommon};
/// A simple stack. /// A simple stack.
/// ///
@ -177,7 +177,7 @@ impl<T> Stack<T> {
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`, // SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty. // and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack is not empty. // Caller guarantees stack is not empty.
NonNull::new_unchecked(self.cursor.as_ptr().sub(1)).as_ref() self.cursor.sub(1).as_ref()
} }
/// Get mutable reference to last value on stack. /// Get mutable reference to last value on stack.
@ -205,7 +205,7 @@ impl<T> Stack<T> {
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`, // SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty. // and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack is not empty. // Caller guarantees stack is not empty.
NonNull::new_unchecked(self.cursor.as_ptr().sub(1)).as_mut() self.cursor.sub(1).as_mut()
} }
/// Push value to stack. /// Push value to stack.
@ -224,7 +224,7 @@ impl<T> Stack<T> {
// SAFETY: Cursor is not at end, so `self.cursor` is in bounds for writing // SAFETY: Cursor is not at end, so `self.cursor` is in bounds for writing
unsafe { self.cursor.as_ptr().write(value) }; unsafe { self.cursor.as_ptr().write(value) };
// SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds // SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds
self.cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) }; self.cursor = unsafe { self.cursor.add(1) };
} }
} }
@ -260,7 +260,7 @@ impl<T> Stack<T> {
// `self.cursor` is aligned for `T`. // `self.cursor` is aligned for `T`.
unsafe { self.cursor.as_ptr().write(value) } unsafe { self.cursor.as_ptr().write(value) }
// SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds // SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds
self.cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) }; self.cursor = unsafe { self.cursor.add(1) };
} }
/// Pop value from stack. /// Pop value from stack.
@ -286,7 +286,7 @@ impl<T> Stack<T> {
debug_assert!(self.cursor > self.start); debug_assert!(self.cursor > self.start);
debug_assert!(self.cursor <= self.end); debug_assert!(self.cursor <= self.end);
// SAFETY: Caller guarantees stack is not empty, so subtracting 1 cannot be out of bounds // SAFETY: Caller guarantees stack is not empty, so subtracting 1 cannot be out of bounds
self.cursor = NonNull::new_unchecked(self.cursor.as_ptr().sub(1)); self.cursor = self.cursor.sub(1);
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`, // SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and points to a valid initialized `T`, if stack is not empty. // and points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack was not empty. // Caller guarantees stack was not empty.