refactor(parser): make Source::set_position safe (#2341)

Make `Source::set_position` a safe function.

This addresses a shortcoming of #2288.

Instead of requiring the caller of `Source::set_position` to guarantee that the `SourcePosition` was created from this `Source`, the preceding PRs enforce this guarantee at the type level.

`Source::set_position` is going to be a central API for transitioning the lexer to processing the source as bytes, rather than `char`s (along with the anticipated speed-ups that will produce). So making this method safe will remove the need for a *lot* of unsafe code blocks and boilerplate comments promising "SAFETY: There's only one `Source`", when this is blindingly obvious to the developer anyway.

So, while splitting the parser into `Parser` and `ParserImpl` (#2339) is an annoying change to have to make, I believe the benefit of this PR justifies it.
This commit is contained in:
overlookmotel 2024-02-08 06:56:26 +00:00 committed by GitHub
parent aef593fb50
commit f3470163d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 17 additions and 28 deletions

View file

@ -255,9 +255,7 @@ impl<'a> ParserImpl<'a> {
let ParserCheckpoint { lexer, cur_token, prev_span_end, errors_pos: errors_lens } = let ParserCheckpoint { lexer, cur_token, prev_span_end, errors_pos: errors_lens } =
checkpoint; checkpoint;
// SAFETY: Parser only ever creates a single `Lexer`, self.lexer.rewind(lexer);
// therefore all checkpoints must be created from it.
unsafe { self.lexer.rewind(lexer) };
self.token = cur_token; self.token = cur_token;
self.prev_token_end = prev_span_end; self.prev_token_end = prev_span_end;
self.errors.truncate(errors_lens); self.errors.truncate(errors_lens);

View file

@ -153,14 +153,8 @@ impl<'a> Lexer<'a> {
} }
/// Rewinds the lexer to the same state as when the passed in `checkpoint` was created. /// Rewinds the lexer to the same state as when the passed in `checkpoint` was created.
/// pub fn rewind(&mut self, checkpoint: LexerCheckpoint<'a>) {
/// # SAFETY
/// `checkpoint` must have been created from this `Lexer`.
#[allow(clippy::missing_safety_doc)] // Clippy is wrong!
pub unsafe fn rewind(&mut self, checkpoint: LexerCheckpoint<'a>) {
self.errors.truncate(checkpoint.errors_pos); self.errors.truncate(checkpoint.errors_pos);
// SAFETY: Caller guarantees `checkpoint` was created from this `Lexer`,
// and therefore `checkpoint.position` was created from `self.source`.
self.source.set_position(checkpoint.position); self.source.set_position(checkpoint.position);
self.token = checkpoint.token; self.token = checkpoint.token;
self.lookahead.clear(); self.lookahead.clear();
@ -178,10 +172,7 @@ impl<'a> Lexer<'a> {
let position = self.source.position(); let position = self.source.position();
if let Some(lookahead) = self.lookahead.back() { if let Some(lookahead) = self.lookahead.back() {
// SAFETY: `self.lookahead` only contains lookaheads created by this `Lexer`. self.source.set_position(lookahead.position);
// `self.source` never changes, so `lookahead.position` must have been created
// from `self.source`.
unsafe { self.source.set_position(lookahead.position) };
} }
for _i in self.lookahead.len()..n { for _i in self.lookahead.len()..n {
@ -197,8 +188,7 @@ impl<'a> Lexer<'a> {
// read, so that's not possible. So no need to restore `self.token` here. // read, so that's not possible. So no need to restore `self.token` here.
// It's already in same state as it was at start of this function. // It's already in same state as it was at start of this function.
// SAFETY: `position` was created above from `self.source`. `self.source` never changes. self.source.set_position(position);
unsafe { self.source.set_position(position) };
self.lookahead[n - 1].token self.lookahead[n - 1].token
} }
@ -211,10 +201,7 @@ impl<'a> Lexer<'a> {
/// Main entry point /// Main entry point
pub fn next_token(&mut self) -> Token { pub fn next_token(&mut self) -> Token {
if let Some(lookahead) = self.lookahead.pop_front() { if let Some(lookahead) = self.lookahead.pop_front() {
// SAFETY: `self.lookahead` only contains lookaheads created by this `Lexer`. self.source.set_position(lookahead.position);
// `self.source` never changes, so `lookahead.position` must have been created
// from `self.source`.
unsafe { self.source.set_position(lookahead.position) };
return lookahead.token; return lookahead.token;
} }
let kind = self.read_next_token(); let kind = self.read_next_token();

View file

@ -140,14 +140,17 @@ impl<'a> Source<'a> {
} }
/// Move current position. /// Move current position.
///
/// # SAFETY
/// `pos` must be created from this `Source`, not another `Source`.
/// If this is the case, the invariants of `Source` are guaranteed to be upheld.
#[inline] #[inline]
pub(super) unsafe fn set_position(&mut self, pos: SourcePosition) { pub(super) fn set_position(&mut self, pos: SourcePosition) {
// `SourcePosition` always upholds the invariants of `Source`, // `SourcePosition` always upholds the invariants of `Source`, as long as it's created
// as long as it's created from this `Source`. // from this `Source`. `SourcePosition`s can only be created from a `Source`.
// `Source::new` takes a `UniquePromise`, which guarantees that it's the only `Source`
// in existence on this thread. `Source` is not `Sync` or `Send`, so no possibility another
// `Source` originated on another thread can "jump" onto this one.
// This is sufficient to guarantee that any `SourcePosition` that parser/lexer holds must be
// from this `Source`.
// This guarantee is what allows this function to be safe.
// SAFETY: `read_u8`'s contract is upheld by: // SAFETY: `read_u8`'s contract is upheld by:
// * The preceding checks that `pos.ptr` >= `self.start` and < `self.end`. // * The preceding checks that `pos.ptr` >= `self.start` and < `self.end`.
// * `Source`'s invariants guarantee that `self.start` - `self.end` contains allocated memory. // * `Source`'s invariants guarantee that `self.start` - `self.end` contains allocated memory.
@ -157,7 +160,8 @@ impl<'a> Source<'a> {
debug_assert!( debug_assert!(
pos.ptr >= self.start pos.ptr >= self.start
&& pos.ptr <= self.end && pos.ptr <= self.end
&& (pos.ptr == self.end || !is_utf8_cont_byte(read_u8(pos.ptr))) // SAFETY: See above
&& (pos.ptr == self.end || !is_utf8_cont_byte(unsafe { read_u8(pos.ptr) }))
); );
self.ptr = pos.ptr; self.ptr = pos.ptr;
} }