refactor(transformer): add wrapper around NonNull (#6115)

Introduce a wrapper around `NonNull` which enables methods which exist on `std::ptr::NonNull` but are not yet stable in our MSRV. These methods remove a lot of boilerplate code from `Stack` and `NonEmptyStack` and make them easier to understand - which is important, since they contain so much unsafe code.
This commit is contained in:
overlookmotel 2024-09-27 16:48:37 +00:00
parent c50500ec42
commit 9ac80bd2d2
5 changed files with 134 additions and 31 deletions

View file

@ -3,12 +3,12 @@
use std::{
alloc::{self, Layout},
mem::{align_of, size_of},
ptr::{self, NonNull},
ptr,
};
use assert_unchecked::assert_unchecked;
use super::StackCapacity;
use super::{NonNull, StackCapacity};
pub trait StackCommon<T>: StackCapacity<T> {
// Getter setter methods defined by implementer
@ -34,11 +34,8 @@ pub trait StackCommon<T>: StackCapacity<T> {
let layout = Self::layout_for(capacity_bytes);
let (start, end) = allocate(layout);
// SAFETY: `start` and `end` are `NonNull` - just casting them
let start = NonNull::new_unchecked(start.as_ptr().cast::<T>());
let end = NonNull::new_unchecked(end.as_ptr().cast::<T>());
(start, end)
// SAFETY: `layout_for` produces a layout with `T`'s alignment, so pointers are aligned for `T`
(start.cast::<T>(), end.cast::<T>())
}
/// Grow allocation.
@ -64,7 +61,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
// `MAX_CAPACITY_BYTES` is also a multiple of `size_of::<T>()`.
// So new capacity in bytes must be a multiple of `size_of::<T>()`.
// `MAX_CAPACITY_BYTES <= isize::MAX`.
let old_start_ptr = NonNull::new_unchecked(self.start().as_ptr().cast::<u8>());
let old_start_ptr = self.start().cast::<u8>();
let old_layout = Self::layout_for(self.capacity_bytes());
let (start, end, current) = grow(old_start_ptr, old_layout, Self::MAX_CAPACITY_BYTES);
@ -73,9 +70,9 @@ pub trait StackCommon<T>: StackCapacity<T> {
// All pointers returned from `grow` are aligned for `T`.
// Old capacity and new capacity in bytes are both multiples of `size_of::<T>()`,
// so distances `end - start` and `current - start` are both multiples of `size_of::<T>()`.
self.set_start(NonNull::new_unchecked(start.as_ptr().cast::<T>()));
self.set_end(NonNull::new_unchecked(end.as_ptr().cast::<T>()));
self.set_cursor(NonNull::new_unchecked(current.as_ptr().cast::<T>()));
self.set_start(start.cast::<T>());
self.set_end(end.cast::<T>());
self.set_cursor(current.cast::<T>());
}
/// Deallocate stack memory.
@ -149,7 +146,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)]
unsafe {
assert_unchecked!(self.cursor() >= self.start());
self.cursor().as_ptr().offset_from(self.start().as_ptr()) as usize
self.cursor().offset_from(self.start()) as usize
}
}
@ -166,7 +163,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)]
unsafe {
assert_unchecked!(self.end() >= self.start());
self.end().as_ptr().offset_from(self.start().as_ptr()) as usize
self.end().offset_from(self.start()) as usize
}
}
@ -183,7 +180,7 @@ pub trait StackCommon<T>: StackCapacity<T> {
#[expect(clippy::cast_sign_loss)]
unsafe {
assert_unchecked!(self.end() >= self.start());
self.end().as_ptr().byte_offset_from(self.start().as_ptr()) as usize
self.end().byte_offset_from(self.start()) as usize
}
}
}
@ -209,7 +206,7 @@ unsafe fn allocate(layout: Layout) -> (/* start */ NonNull<u8>, /* end */ NonNul
// SAFETY: We checked `ptr` is non-null
let start = NonNull::new_unchecked(ptr);
// SAFETY: We allocated `layout.size()` bytes, so `end` is end of allocation
let end = NonNull::new_unchecked(ptr.add(layout.size()));
let end = start.add(layout.size());
(start, end)
}
@ -249,7 +246,7 @@ unsafe fn grow(
// So `new_capacity_bytes` must be a multiple of `size_of::<T>()`.
// `new_capacity_bytes` is `<= MAX_CAPACITY_BYTES`, so is a legal allocation size.
// `layout_for` produces a layout with `T`'s alignment, so `new_ptr` is aligned for `T`.
let new_ptr = unsafe {
let new_start = unsafe {
let old_ptr = old_start.as_ptr();
let new_ptr = alloc::realloc(old_ptr, old_layout, new_capacity_bytes);
if new_ptr.is_null() {
@ -257,7 +254,7 @@ unsafe fn grow(
Layout::from_size_align_unchecked(old_capacity_bytes, old_layout.align());
alloc::handle_alloc_error(new_layout);
}
new_ptr
NonNull::new_unchecked(new_ptr)
};
// Update pointers.
@ -273,9 +270,8 @@ unsafe fn grow(
//
// SAFETY: We checked that `new_ptr` is non-null.
// `old_capacity_bytes < new_capacity_bytes` (ensured above), so `new_cursor` must be in bounds.
let new_start = NonNull::new_unchecked(new_ptr);
let new_end = NonNull::new_unchecked(new_ptr.add(new_capacity_bytes));
let new_cursor = NonNull::new_unchecked(new_ptr.add(old_capacity_bytes));
let new_end = new_start.add(new_capacity_bytes);
let new_cursor = new_start.add(old_capacity_bytes);
(new_start, new_end, new_cursor)
}

View file

@ -1,11 +1,13 @@
mod capacity;
mod common;
mod non_empty;
mod non_null;
mod sparse;
mod standard;
use capacity::StackCapacity;
use common::StackCommon;
pub use non_empty::NonEmptyStack;
use non_null::NonNull;
pub use sparse::SparseStack;
pub use standard::Stack;

View file

@ -1,8 +1,8 @@
#![expect(clippy::unnecessary_safety_comment)]
use std::{mem::size_of, ptr::NonNull};
use std::mem::size_of;
use super::{StackCapacity, StackCommon};
use super::{NonNull, StackCapacity, StackCommon};
/// A stack which can never be empty.
///
@ -199,7 +199,7 @@ impl<T> NonEmptyStack<T> {
// of allocation. So advancing by a `T` cannot be out of bounds.
// The distance between `self.cursor` and `self.end` is always a multiple of `size_of::<T>()`,
// so `==` check is sufficient to detect when full to capacity.
let new_cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) };
let new_cursor = unsafe { self.cursor.add(1) };
if new_cursor == self.end {
// Needs to grow
// SAFETY: Stack is full to capacity
@ -264,7 +264,7 @@ impl<T> NonEmptyStack<T> {
let value = self.cursor.as_ptr().read();
// SAFETY: Caller guarantees there's at least 2 entries on stack, so subtracting 1
// cannot be out of bounds
self.cursor = NonNull::new_unchecked(self.cursor.as_ptr().sub(1));
self.cursor = self.cursor.sub(1);
value
}

View file

@ -0,0 +1,105 @@
use std::{cmp::Ordering, ptr::NonNull as NativeNonNull};
/// Wrapper around `NonNull<T>`, which adds methods `add`, `sub`, `offset_from` and `byte_offset_from`.
/// These methods exist on `std::ptr::NonNull`, and became stable in Rust 1.80.0, but are not yet
/// stable in our MSRV.
///
/// These methods are much cleaner than the workarounds required in older Rust versions,
/// and make code using them easier to understand.
///
/// Once we bump MSRV and these methods are natively supported, this type can be removed.
/// `#[expect(clippy::incompatible_msrv)]` on `non_null_add_is_not_stable` below will trigger
/// a lint warning when that happens.
/// Then this module can be deleted, and all uses of this type can be switched to `std::ptr::NonNull`.
#[derive(Debug)]
pub struct NonNull<T>(NativeNonNull<T>);
#[expect(dead_code, clippy::incompatible_msrv)]
unsafe fn non_null_add_is_not_stable(ptr: NativeNonNull<u8>) -> NativeNonNull<u8> {
ptr.add(1)
}
impl<T> NonNull<T> {
#[inline]
pub const unsafe fn new_unchecked(ptr: *mut T) -> Self {
Self(NativeNonNull::new_unchecked(ptr))
}
#[inline]
pub const fn dangling() -> Self {
Self(NativeNonNull::dangling())
}
#[inline]
pub const fn as_ptr(self) -> *mut T {
self.0.as_ptr()
}
#[inline]
pub const fn cast<U>(self) -> NonNull<U> {
// SAFETY: `self` is non-null, so it's still non-null after casting
unsafe { NonNull::new_unchecked(self.as_ptr().cast()) }
}
#[inline]
pub const unsafe fn add(self, count: usize) -> Self {
NonNull(NativeNonNull::new_unchecked(self.as_ptr().add(count)))
}
#[inline]
pub const unsafe fn sub(self, count: usize) -> Self {
NonNull(NativeNonNull::new_unchecked(self.as_ptr().sub(count)))
}
#[inline]
pub const unsafe fn offset_from(self, origin: Self) -> isize {
self.as_ptr().offset_from(origin.as_ptr())
}
#[inline]
pub const unsafe fn byte_offset_from(self, origin: Self) -> isize {
self.as_ptr().byte_offset_from(origin.as_ptr())
}
#[inline]
pub const unsafe fn as_ref<'t>(self) -> &'t T {
self.0.as_ref()
}
#[inline]
pub unsafe fn as_mut<'t>(mut self) -> &'t mut T {
self.0.as_mut()
}
}
impl<T> Copy for NonNull<T> {}
impl<T> Clone for NonNull<T> {
#[inline]
fn clone(&self) -> Self {
*self
}
}
impl<T> Eq for NonNull<T> {}
impl<T> PartialEq for NonNull<T> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<T> Ord for NonNull<T> {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
self.as_ptr().cmp(&other.as_ptr())
}
}
impl<T> PartialOrd for NonNull<T> {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

View file

@ -1,8 +1,8 @@
#![expect(clippy::unnecessary_safety_comment)]
use std::{mem::size_of, ptr::NonNull};
use std::mem::size_of;
use super::{StackCapacity, StackCommon};
use super::{NonNull, StackCapacity, StackCommon};
/// A simple stack.
///
@ -177,7 +177,7 @@ impl<T> Stack<T> {
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack is not empty.
NonNull::new_unchecked(self.cursor.as_ptr().sub(1)).as_ref()
self.cursor.sub(1).as_ref()
}
/// Get mutable reference to last value on stack.
@ -205,7 +205,7 @@ impl<T> Stack<T> {
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and `self.current.sub(1)` points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack is not empty.
NonNull::new_unchecked(self.cursor.as_ptr().sub(1)).as_mut()
self.cursor.sub(1).as_mut()
}
/// Push value to stack.
@ -224,7 +224,7 @@ impl<T> Stack<T> {
// SAFETY: Cursor is not at end, so `self.cursor` is in bounds for writing
unsafe { self.cursor.as_ptr().write(value) };
// SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds
self.cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) };
self.cursor = unsafe { self.cursor.add(1) };
}
}
@ -260,7 +260,7 @@ impl<T> Stack<T> {
// `self.cursor` is aligned for `T`.
unsafe { self.cursor.as_ptr().write(value) }
// SAFETY: Cursor is not at end, so advancing by a `T` cannot be out of bounds
self.cursor = unsafe { NonNull::new_unchecked(self.cursor.as_ptr().add(1)) };
self.cursor = unsafe { self.cursor.add(1) };
}
/// Pop value from stack.
@ -286,7 +286,7 @@ impl<T> Stack<T> {
debug_assert!(self.cursor > self.start);
debug_assert!(self.cursor <= self.end);
// SAFETY: Caller guarantees stack is not empty, so subtracting 1 cannot be out of bounds
self.cursor = NonNull::new_unchecked(self.cursor.as_ptr().sub(1));
self.cursor = self.cursor.sub(1);
// SAFETY: All methods ensure `self.cursor` is always in bounds, is aligned for `T`,
// and points to a valid initialized `T`, if stack is not empty.
// Caller guarantees stack was not empty.