From 46c02aee615bfefa457c9e917c0b704c3230cc7f Mon Sep 17 00:00:00 2001 From: overlookmotel <557937+overlookmotel@users.noreply.github.com> Date: Sat, 11 May 2024 04:39:42 +0000 Subject: [PATCH] feat(traverse): add scope flags to `TraverseCtx` (#3229) Add scope flags to `TraverseCtx`. Closes #3189. `walk_*` functions build a stack of `ScopeFlags` as the AST is traversed, and they can be queried from within visitors with `ctx.scope()`, `ctx.ancestor_scope()` and `ctx.find_scope()`. The codegen which generates `walk_*` functions gets the info about which AST types have scopes, and how to check for strict mode, from the `#[visited_node]` attrs on AST type definitions in `oxc_ast`. A few notes: Each scope inherits the strict mode flag from the level before it in the stack, so if you need to know "am I in strict mode context here?", `ctx.scope().is_strict_mode()` will tell you - no need to travel back up the stack to find out. Scopes do *not* inherit any other flags from the level before them. So `ctx.scope()` in a block nested in a function will return `ScopeFlags::empty()`, not `ScopeFlags::Function`. I had to add an extra flag `ScopeFlags::Method`. The reason for this is to deal with the case where a `Function` is actually the value of a `MethodDefinition`, and to avoid creating 2 scopes in this case. The principle I'm trying to follow is to encode as little logic in the codegen as possible, as it's rather hidden away. Instead the codegen follows a standard logic for every node, guided by attributes which are visible next to the types in `oxc_ast`. This hopefully makes how `Traverse`'s visitors are generated less mysterious, and easier to change. The case of `Function` within `MethodDefinition` is a weird one: without this extra `ScopeFlags::Method` variant, it could not be implemented without encoding a magic "special case" within the codegen. Its existence does not alter the operation of any other code in Oxc which uses `ScopeFlags`. 
In my view `ScopeFlags` might benefit from a little bit of an overhaul anyway. I believe we could pack more information into the bits and make it more useful. --- crates/oxc_ast/src/ast/js.rs | 49 +++++++++--- crates/oxc_ast/src/ast/ts.rs | 6 +- crates/oxc_syntax/src/scope.rs | 5 +- crates/oxc_traverse/scripts/lib/parse.mjs | 80 +++++++++++++++++-- crates/oxc_traverse/scripts/lib/walk.mjs | 58 +++++++++++++- crates/oxc_traverse/src/context.rs | 79 ++++++++++++++++++- crates/oxc_traverse/src/lib.rs | 1 + crates/oxc_traverse/src/walk.rs | 94 ++++++++++++++++++++++- 8 files changed, 344 insertions(+), 28 deletions(-) diff --git a/crates/oxc_ast/src/ast/js.rs b/crates/oxc_ast/src/ast/js.rs index 930b2f954..c7446f8cc 100644 --- a/crates/oxc_ast/src/ast/js.rs +++ b/crates/oxc_ast/src/ast/js.rs @@ -14,6 +14,7 @@ use oxc_syntax::{ AssignmentOperator, BinaryOperator, LogicalOperator, UnaryOperator, UpdateOperator, }, reference::{ReferenceFlag, ReferenceId}, + scope::ScopeFlags, symbol::SymbolId, }; #[cfg(feature = "serialize")] @@ -41,7 +42,10 @@ export interface FormalParameterRest extends Span { } "#; -#[visited_node] +#[visited_node( + scope(ScopeFlags::Top), + strict_if(self.source_type.is_strict() || self.directives.iter().any(Directive::is_use_strict)) +)] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type", rename_all = "camelCase"))] @@ -1508,7 +1512,7 @@ pub struct Hashbang<'a> { } /// Block Statement -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -1735,7 +1739,10 @@ pub struct WhileStatement<'a> { } /// For Statement -#[visited_node] +#[visited_node( + scope(ScopeFlags::empty()), + scope_if(self.init.as_ref().is_some_and(ForStatementInit::is_lexical_declaration)) +)] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", 
derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -1776,7 +1783,7 @@ impl<'a> ForStatementInit<'a> { } /// For-In Statement -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()), scope_if(self.left.is_lexical_declaration()))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -1789,7 +1796,7 @@ pub struct ForInStatement<'a> { } /// For-Of Statement -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()), scope_if(self.left.is_lexical_declaration()))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -1875,7 +1882,7 @@ pub struct WithStatement<'a> { } /// Switch Statement -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()), enter_scope_before(cases))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -1939,7 +1946,7 @@ pub struct TryStatement<'a> { pub finalizer: Option>>, } -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()), scope_if(self.param.is_some()))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -2124,7 +2131,12 @@ pub struct BindingRestElement<'a> { } /// Function Definitions -#[visited_node] +#[visited_node( + scope(ScopeFlags::Function), + // Don't create a 2nd scope if `MethodDefinition` already created one + scope_if((ctx.scope() & ScopeFlags::Modifiers).is_empty()), + strict_if(self.body.as_ref().is_some_and(|body| body.has_use_strict_directive())) +)] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(rename_all = "camelCase"))] @@ -2292,7 +2304,7 @@ impl<'a> FunctionBody<'a> { } /// Arrow Function Definitions -#[visited_node] 
+#[visited_node(scope(ScopeFlags::Function | ScopeFlags::Arrow))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type", rename_all = "camelCase"))] @@ -2335,7 +2347,7 @@ pub struct YieldExpression<'a> { } /// Class Definitions -#[visited_node] +#[visited_node(scope(ScopeFlags::StrictMode), enter_scope_before(id))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(rename_all = "camelCase"))] @@ -2491,7 +2503,11 @@ impl<'a> ClassElement<'a> { } } -#[visited_node] +#[visited_node( + scope(self.kind.scope_flags()), + strict_if(self.value.is_strict()), + enter_scope_before(value) +)] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(rename_all = "camelCase"))] @@ -2566,6 +2582,15 @@ impl MethodDefinitionKind { pub fn is_set(&self) -> bool { matches!(self, Self::Set) } + + pub fn scope_flags(self) -> ScopeFlags { + match self { + Self::Constructor => ScopeFlags::Constructor, + Self::Method => ScopeFlags::Method, + Self::Get => ScopeFlags::GetAccessor, + Self::Set => ScopeFlags::SetAccessor, + } + } } #[visited_node] @@ -2584,7 +2609,7 @@ impl<'a> PrivateIdentifier<'a> { } } -#[visited_node] +#[visited_node(scope(ScopeFlags::ClassStaticBlock))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] diff --git a/crates/oxc_ast/src/ast/ts.rs b/crates/oxc_ast/src/ast/ts.rs index a181c0276..339f8948e 100644 --- a/crates/oxc_ast/src/ast/ts.rs +++ b/crates/oxc_ast/src/ast/ts.rs @@ -43,7 +43,7 @@ pub struct TSThisParameter<'a> { /// Enum Declaration /// /// `const_opt` enum `BindingIdentifier` { `EnumBody_opt` } -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()), enter_scope_before(members))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", 
derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type"))] @@ -597,7 +597,7 @@ pub struct TSTypeParameterInstantiation<'a> { pub params: Vec<'a, TSType<'a>>, } -#[visited_node] +#[visited_node(scope(ScopeFlags::empty()))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type", rename_all = "camelCase"))] @@ -878,7 +878,7 @@ pub enum TSModuleDeclarationBody<'a> { TSModuleBlock(Box<'a, TSModuleBlock<'a>>), } -#[visited_node] +#[visited_node(scope(ScopeFlags::TsModuleBlock))] #[derive(Debug, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize, Tsify))] #[cfg_attr(feature = "serialize", serde(tag = "type", rename_all = "camelCase"))] diff --git a/crates/oxc_syntax/src/scope.rs b/crates/oxc_syntax/src/scope.rs index d6b9e0f48..7aa3f0ea9 100644 --- a/crates/oxc_syntax/src/scope.rs +++ b/crates/oxc_syntax/src/scope.rs @@ -23,8 +23,11 @@ bitflags! { const Constructor = 1 << 6; const GetAccessor = 1 << 7; const SetAccessor = 1 << 8; + // Only used in `Traverse` + const Method = 1 << 9; const Var = Self::Top.bits() | Self::Function.bits() | Self::ClassStaticBlock.bits() | Self::TsModuleBlock.bits(); - const Modifiers = Self::Constructor.bits() | Self::GetAccessor.bits() | Self::SetAccessor.bits(); + const Modifiers = Self::Constructor.bits() | Self::GetAccessor.bits() + | Self::SetAccessor.bits() | Self::Method.bits(); } } diff --git a/crates/oxc_traverse/scripts/lib/parse.mjs b/crates/oxc_traverse/scripts/lib/parse.mjs index 72a4f565c..0f9f95d91 100644 --- a/crates/oxc_traverse/scripts/lib/parse.mjs +++ b/crates/oxc_traverse/scripts/lib/parse.mjs @@ -21,9 +21,26 @@ export default async function getTypesFromCode() { } function parseFile(code, filename, types) { - const lines = code.split(/\r?\n/); + const lines = code.split(/\r?\n/).map( + line => line.replace(/\s+/g, ' ').replace(/ ?\/\/.*$/, '') + ); for (let lineIndex = 0; lineIndex < lines.length; lineIndex++) 
{ - if (lines[lineIndex] !== '#[visited_node]') continue; + const lineMatch = lines[lineIndex].match(/^#\[visited_node ?([\]\(])/); + if (!lineMatch) continue; + + let scopeArgs = null; + if (lineMatch[1] === '(') { + let line = lines[lineIndex].slice(lineMatch[0].length), + scopeArgsStr = ''; + while (!line.endsWith(')]')) { + scopeArgsStr += ` ${line}`; + line = lines[++lineIndex]; + } + scopeArgsStr += ` ${line.slice(0, -2)}`; + scopeArgsStr = scopeArgsStr.trim().replace(/ +/g, ' '); + + scopeArgs = parseScopeArgs(scopeArgsStr, filename, lineIndex); + } let match; while (true) { @@ -36,20 +53,20 @@ function parseFile(code, filename, types) { const itemLines = []; while (true) { - const line = lines[++lineIndex].replace(/\/\/.*$/, '').replace(/\s+/g, ' ').trim(); + const line = lines[++lineIndex].trim(); if (line === '}') break; if (line !== '') itemLines.push(line); } if (kind === 'struct') { - types[name] = parseStruct(name, hasLifetime, itemLines, filename, startLineIndex); + types[name] = parseStruct(name, hasLifetime, itemLines, scopeArgs, filename, startLineIndex); } else { types[name] = parseEnum(name, hasLifetime, itemLines, filename, startLineIndex); } } } -function parseStruct(name, hasLifetime, lines, filename, startLineIndex) { +function parseStruct(name, hasLifetime, lines, scopeArgs, filename, startLineIndex) { const fields = []; for (let i = 0; i < lines.length; i++) { const line = lines[i]; @@ -71,7 +88,7 @@ function parseStruct(name, hasLifetime, lines, filename, startLineIndex) { fields.push({name, typeName, rawName, rawTypeName, innerTypeName, wrappers}); } - return {kind: 'struct', name, hasLifetime, fields}; + return {kind: 'struct', name, hasLifetime, fields, scopeArgs}; } function parseEnum(name, hasLifetime, lines, filename, startLineIndex) { @@ -81,7 +98,7 @@ function parseEnum(name, hasLifetime, lines, filename, startLineIndex) { const match = line.match(/^(.+?)\((.+?)\)(?: ?= ?(\d+))?,$/); if (match) { const [, name, rawTypeName, 
discriminantStr] = match, - typeName = rawTypeName.replace(/<'a>/g, '').replace(/<'a,\s*/g, '<'), + typeName = rawTypeName.replace(/<'a>/g, '').replace(/<'a, ?/g, '<'), {name: innerTypeName, wrappers} = typeAndWrappers(typeName), discriminant = discriminantStr ? +discriminantStr : null; variants.push({name, typeName, rawTypeName, innerTypeName, wrappers, discriminant}); @@ -96,3 +113,52 @@ function parseEnum(name, hasLifetime, lines, filename, startLineIndex) { } return {kind: 'enum', name, hasLifetime, variants, inherits}; } + +function parseScopeArgs(argsStr, filename, lineIndex) { + if (!argsStr) return null; + + const matchAndConsume = (regex) => { + const match = argsStr.match(regex); + assert(match); + argsStr = argsStr.slice(match[0].length); + return match.slice(1); + }; + + const args = {}; + try { + while (true) { + const [key] = matchAndConsume(/^([a-z_]+)\(/); + assert( + ['scope', 'scope_if', 'strict_if', 'enter_scope_before'].includes(key), + `Unexpected visited_node macro arg: ${key}` + ); + + let bracketCount = 1, + index = 0; + for (; index < argsStr.length; index++) { + const char = argsStr[index]; + if (char === '(') { + bracketCount++; + } else if (char === ')') { + bracketCount--; + if (bracketCount === 0) break; + } + } + assert(bracketCount === 0); + + args[key] = argsStr.slice(0, index).trim(); + argsStr = argsStr.slice(index + 1); + if (argsStr === '') break; + + matchAndConsume(/^ ?, ?/); + } + + assert(args.scope, 'Missing key `scope`'); + } catch (err) { + throw new Error( + `Cannot parse visited_node args: ${argsStr} in ${filename}:${lineIndex}\n${err?.message}` + ); + } + + return args; +} diff --git a/crates/oxc_traverse/scripts/lib/walk.mjs b/crates/oxc_traverse/scripts/lib/walk.mjs index fd84e1321..ec717dad4 100644 --- a/crates/oxc_traverse/scripts/lib/walk.mjs +++ b/crates/oxc_traverse/scripts/lib/walk.mjs @@ -20,12 +20,15 @@ export default function generateWalkFunctionsCode(types) { clippy::semicolon_if_nothing_returned, 
clippy::ptr_as_ptr, clippy::borrow_as_ptr, - clippy::cast_ptr_alignment + clippy::cast_ptr_alignment, + clippy::needless_borrow )] use oxc_allocator::Vec; #[allow(clippy::wildcard_imports)] use oxc_ast::ast::*; + use oxc_span::SourceType; + use oxc_syntax::scope::ScopeFlags; use crate::{ancestor::{self, AncestorType}, Ancestor, Traverse, TraverseCtx}; @@ -48,13 +51,55 @@ export default function generateWalkFunctionsCode(types) { function generateWalkForStruct(type, types) { const visitedFields = type.fields.filter(field => field.innerTypeName in types); + const {scopeArgs} = type; + let scopeEnterField, enterScopeCode, exitScopeCode; + if (scopeArgs) { + // Get field to enter scope before + const enterFieldName = scopeArgs.enter_scope_before; + if (enterFieldName) { + scopeEnterField = visitedFields.find(field => field.name === enterFieldName); + assert( + scopeEnterField, + `\`visited_node\` attr says to enter scope before field '${enterFieldName}' ` + + `in '${type.name}', but that field is not visited` + ); + } else { + scopeEnterField = visitedFields[0]; + } + + const convertExpressionToUsePointers = arg => arg.replace( + /(^|[^a-zA-Z0-9_])self\.(?:r#)?([A-Za-z0-9_]+)/g, + (_, before, fieldName) => { + const field = type.fields.find(field => field.name === fieldName); + assert(`Cannot parse conditional in visited_node args: '${arg}' for ${type.name}`); + return `${before}(&*(${makeFieldCode(field)}))`; + } + ); + + let scopeType = convertExpressionToUsePointers(scopeArgs.scope); + if (scopeArgs.strict_if) { + scopeType += `.with_strict_mode(${convertExpressionToUsePointers(scopeArgs.strict_if)})`; + } + + enterScopeCode = `ctx.push_scope_stack(${scopeType});`; + exitScopeCode = `ctx.pop_scope_stack();`; + if (scopeArgs.scope_if) { + enterScopeCode = ` + let has_scope = ${convertExpressionToUsePointers(scopeArgs.scope_if)}; + if has_scope { ${enterScopeCode} } + `; + exitScopeCode = `if has_scope { ${exitScopeCode} }`; + } + } + const fieldsCodes = 
visitedFields.map((field, index) => { const fieldWalkName = `walk_${camelToSnake(field.innerTypeName)}`; const retagCode = index === 0 ? '' : `ctx.retag_stack(AncestorType::${type.name}${snakeToCamel(field.name)});`; - const fieldCode = `(node as *mut u8).add(ancestor::${field.offsetVarName}) as *mut ${field.typeName}`; + const fieldCode = makeFieldCode(field); + const scopeCode = field === scopeEnterField ? enterScopeCode : ''; if (field.wrappers[0] === 'Option') { let walkCode; @@ -77,6 +122,7 @@ function generateWalkForStruct(type, types) { } return ` + ${scopeCode} if let Some(field) = &mut *(${fieldCode}) { ${retagCode} ${walkCode} @@ -108,6 +154,7 @@ function generateWalkForStruct(type, types) { } return ` + ${scopeCode} ${retagCode} ${walkVecCode} `; @@ -115,6 +162,7 @@ function generateWalkForStruct(type, types) { if (field.wrappers.length === 1 && field.wrappers[0] === 'Box') { return ` + ${scopeCode} ${retagCode} ${fieldWalkName}(traverser, (&mut **(${fieldCode})) as *mut _, ctx); `; @@ -123,6 +171,7 @@ function generateWalkForStruct(type, types) { assert(field.wrappers.length === 0, `Cannot handle struct field with type: ${field.type}`); return ` + ${scopeCode} ${retagCode} ${fieldWalkName}(traverser, ${fieldCode}, ctx); `; @@ -138,6 +187,7 @@ function generateWalkForStruct(type, types) { ) ); `); + if (exitScopeCode) fieldsCodes.push(exitScopeCode); fieldsCodes.push('ctx.pop_stack();'); } @@ -155,6 +205,10 @@ function generateWalkForStruct(type, types) { `.replace(/\n\s*\n+/g, '\n'); } +function makeFieldCode(field) { + return `(node as *mut u8).add(ancestor::${field.offsetVarName}) as *mut ${field.typeName}`; +} + function generateWalkForEnum(type, types) { const variantCodes = type.variants.map((variant) => { const variantType = types[variant.innerTypeName]; diff --git a/crates/oxc_traverse/src/context.rs b/crates/oxc_traverse/src/context.rs index 2bb2d7223..140eca181 100644 --- a/crates/oxc_traverse/src/context.rs +++ 
b/crates/oxc_traverse/src/context.rs @@ -1,9 +1,11 @@ use oxc_allocator::{Allocator, Box}; use oxc_ast::AstBuilder; +use oxc_syntax::scope::ScopeFlags; use crate::ancestor::{Ancestor, AncestorType}; -const INITIAL_STACK_CAPACITY: usize = 64; +const INITIAL_STACK_CAPACITY: usize = 64; // 64 entries = 1 KiB +const INITIAL_SCOPE_STACK_CAPACITY: usize = 32; // 32 entries = 64 bytes /// Traverse context. /// @@ -11,16 +13,21 @@ const INITIAL_STACK_CAPACITY: usize = 64; /// /// Provides ability to: /// * Query parent/ancestor of current node via [`parent`], [`ancestor`], [`find_ancestor`]. +/// * Get type of current scope via [`scope`], [`ancestor_scope`], [`find_scope`]. /// * Create AST nodes via AST builder [`ast`]. /// * Allocate into arena via [`alloc`]. /// /// [`parent`]: `TraverseCtx::parent` /// [`ancestor`]: `TraverseCtx::ancestor` /// [`find_ancestor`]: `TraverseCtx::find_ancestor` +/// [`scope`]: `TraverseCtx::scope` +/// [`ancestor_scope`]: `TraverseCtx::ancestor_scope` +/// [`find_scope`]: `TraverseCtx::find_scope` /// [`ast`]: `TraverseCtx::ast` /// [`alloc`]: `TraverseCtx::alloc` pub struct TraverseCtx<'a> { stack: Vec>, + scope_stack: Vec, pub ast: AstBuilder<'a>, } @@ -37,7 +44,11 @@ impl<'a> TraverseCtx<'a> { pub(crate) fn new(allocator: &'a Allocator) -> Self { let mut stack = Vec::with_capacity(INITIAL_STACK_CAPACITY); stack.push(Ancestor::None); - Self { stack, ast: AstBuilder::new(allocator) } + + let mut scope_stack = Vec::with_capacity(INITIAL_SCOPE_STACK_CAPACITY); + scope_stack.push(ScopeFlags::empty()); + + Self { stack, scope_stack, ast: AstBuilder::new(allocator) } } /// Allocate a node in the arena. @@ -91,6 +102,52 @@ impl<'a> TraverseCtx<'a> { pub fn ancestors_depth(&self) -> usize { self.stack.len() } + + /// Get current scope info. + #[inline] + #[allow(unsafe_code)] + pub fn scope(&self) -> ScopeFlags { + // SAFETY: Scope stack contains 1 entry initially. Entries are pushed as traverse down the AST, + // and popped as go back up. 
So even when visiting `Program`, the initial entry is in the stack. + unsafe { *self.scope_stack.last().unwrap_unchecked() } + } + + /// Get scope ancestor. + /// `level` is the number of scopes above. Returns `None` if it is 0 or exceeds scope depth. + /// `ancestor_scope(1).unwrap()` is equivalent to `scope()`. + #[inline] + pub fn ancestor_scope(&self, level: usize) -> Option { + self.scope_stack.get(self.scope_stack.len().wrapping_sub(level)).copied() + } + + /// Walk up trail of scopes to find a scope. + /// + /// `finder` should return: + /// * `FinderRet::Found(value)` to stop walking and return `Some(value)`. + /// * `FinderRet::Stop` to stop walking and return `None`. + /// * `FinderRet::Continue` to continue walking up. + pub fn find_scope(&self, finder: F) -> Option + where + F: Fn(ScopeFlags) -> FinderRet, + { + for flags in self.scope_stack.iter().rev().copied() { + match finder(flags) { + FinderRet::Found(res) => return Some(res), + FinderRet::Stop => return None, + FinderRet::Continue => {} + } + } + None + } + + /// Get depth of scopes. + /// + /// Count includes global scope. + /// i.e. in `Program`, depth is 2 (global scope + program top level scope). + #[inline] + pub fn scopes_depth(&self) -> usize { + self.scope_stack.len() + } } // Methods used internally within crate @@ -134,4 +191,22 @@ impl<'a> TraverseCtx<'a> { pub(crate) unsafe fn retag_stack(&mut self, ty: AncestorType) { *(self.stack.last_mut().unwrap_unchecked() as *mut _ as *mut AncestorType) = ty; } + + /// Push scope flags onto scope stack. + /// + /// `StrictMode` flag is inherited from parent. + #[inline] + pub(crate) fn push_scope_stack(&mut self, flags: ScopeFlags) { + self.scope_stack.push(flags | (self.scope() & ScopeFlags::StrictMode)); + } + + /// Pop last item off scope stack. + /// # SAFETY + /// * Stack must not be empty. + /// * Each `pop_scope_stack` call must correspond to an earlier `push_scope_stack` call. 
+ #[inline] + #[allow(unsafe_code)] + pub(crate) unsafe fn pop_scope_stack(&mut self) { + self.scope_stack.pop().unwrap_unchecked(); + } } diff --git a/crates/oxc_traverse/src/lib.rs b/crates/oxc_traverse/src/lib.rs index ae32c71f6..5540e15c3 100644 --- a/crates/oxc_traverse/src/lib.rs +++ b/crates/oxc_traverse/src/lib.rs @@ -146,4 +146,5 @@ pub fn traverse_mut<'a, Tr: Traverse<'a>>( // SAFETY: Walk functions are constructed to avoid unsoundness unsafe { walk::walk_program(traverser, program as *mut Program, &mut ctx) }; debug_assert!(ctx.ancestors_depth() == 1); + debug_assert!(ctx.scopes_depth() == 1); } diff --git a/crates/oxc_traverse/src/walk.rs b/crates/oxc_traverse/src/walk.rs index 94c5e20d0..c1690a292 100644 --- a/crates/oxc_traverse/src/walk.rs +++ b/crates/oxc_traverse/src/walk.rs @@ -8,12 +8,15 @@ clippy::semicolon_if_nothing_returned, clippy::ptr_as_ptr, clippy::borrow_as_ptr, - clippy::cast_ptr_alignment + clippy::cast_ptr_alignment, + clippy::needless_borrow )] use oxc_allocator::Vec; #[allow(clippy::wildcard_imports)] use oxc_ast::ast::*; +use oxc_span::SourceType; +use oxc_syntax::scope::ScopeFlags; use crate::{ ancestor::{self, AncestorType}, @@ -27,6 +30,16 @@ pub(crate) unsafe fn walk_program<'a, Tr: Traverse<'a>>( ) { traverser.enter_program(&mut *node, ctx); ctx.push_stack(Ancestor::ProgramDirectives(ancestor::ProgramWithoutDirectives(node))); + ctx.push_scope_stack( + ScopeFlags::Top.with_strict_mode( + (&*((node as *mut u8).add(ancestor::OFFSET_PROGRAM_SOURCE_TYPE) as *mut SourceType)) + .is_strict() + || (&*((node as *mut u8).add(ancestor::OFFSET_PROGRAM_DIRECTIVES) + as *mut Vec)) + .iter() + .any(Directive::is_use_strict), + ), + ); for item in (*((node as *mut u8).add(ancestor::OFFSET_PROGRAM_DIRECTIVES) as *mut Vec)) .iter_mut() @@ -45,6 +58,7 @@ pub(crate) unsafe fn walk_program<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_PROGRAM_BODY) as *mut Vec, ctx, ); + ctx.pop_scope_stack(); ctx.pop_stack(); 
traverser.exit_program(&mut *node, ctx); } @@ -1389,11 +1403,13 @@ pub(crate) unsafe fn walk_block_statement<'a, Tr: Traverse<'a>>( ) { traverser.enter_block_statement(&mut *node, ctx); ctx.push_stack(Ancestor::BlockStatementBody(ancestor::BlockStatementWithoutBody(node))); + ctx.push_scope_stack(ScopeFlags::empty()); walk_statements( traverser, (node as *mut u8).add(ancestor::OFFSET_BLOCK_STATEMENT_BODY) as *mut Vec, ctx, ); + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_block_statement(&mut *node, ctx); } @@ -1600,6 +1616,13 @@ pub(crate) unsafe fn walk_for_statement<'a, Tr: Traverse<'a>>( ) { traverser.enter_for_statement(&mut *node, ctx); ctx.push_stack(Ancestor::ForStatementInit(ancestor::ForStatementWithoutInit(node))); + let has_scope = (&*((node as *mut u8).add(ancestor::OFFSET_FOR_STATEMENT_INIT) + as *mut Option)) + .as_ref() + .is_some_and(ForStatementInit::is_lexical_declaration); + if has_scope { + ctx.push_scope_stack(ScopeFlags::empty()); + } if let Some(field) = &mut *((node as *mut u8).add(ancestor::OFFSET_FOR_STATEMENT_INIT) as *mut Option) { @@ -1623,6 +1646,9 @@ pub(crate) unsafe fn walk_for_statement<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_FOR_STATEMENT_BODY) as *mut Statement, ctx, ); + if has_scope { + ctx.pop_scope_stack(); + } ctx.pop_stack(); traverser.exit_for_statement(&mut *node, ctx); } @@ -1695,6 +1721,12 @@ pub(crate) unsafe fn walk_for_in_statement<'a, Tr: Traverse<'a>>( ) { traverser.enter_for_in_statement(&mut *node, ctx); ctx.push_stack(Ancestor::ForInStatementLeft(ancestor::ForInStatementWithoutLeft(node))); + let has_scope = (&*((node as *mut u8).add(ancestor::OFFSET_FOR_IN_STATEMENT_LEFT) + as *mut ForStatementLeft)) + .is_lexical_declaration(); + if has_scope { + ctx.push_scope_stack(ScopeFlags::empty()); + } walk_for_statement_left( traverser, (node as *mut u8).add(ancestor::OFFSET_FOR_IN_STATEMENT_LEFT) as *mut ForStatementLeft, @@ -1712,6 +1744,9 @@ pub(crate) unsafe fn 
walk_for_in_statement<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_FOR_IN_STATEMENT_BODY) as *mut Statement, ctx, ); + if has_scope { + ctx.pop_scope_stack(); + } ctx.pop_stack(); traverser.exit_for_in_statement(&mut *node, ctx); } @@ -1723,6 +1758,12 @@ pub(crate) unsafe fn walk_for_of_statement<'a, Tr: Traverse<'a>>( ) { traverser.enter_for_of_statement(&mut *node, ctx); ctx.push_stack(Ancestor::ForOfStatementLeft(ancestor::ForOfStatementWithoutLeft(node))); + let has_scope = (&*((node as *mut u8).add(ancestor::OFFSET_FOR_OF_STATEMENT_LEFT) + as *mut ForStatementLeft)) + .is_lexical_declaration(); + if has_scope { + ctx.push_scope_stack(ScopeFlags::empty()); + } walk_for_statement_left( traverser, (node as *mut u8).add(ancestor::OFFSET_FOR_OF_STATEMENT_LEFT) as *mut ForStatementLeft, @@ -1740,6 +1781,9 @@ pub(crate) unsafe fn walk_for_of_statement<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_FOR_OF_STATEMENT_BODY) as *mut Statement, ctx, ); + if has_scope { + ctx.pop_scope_stack(); + } ctx.pop_stack(); traverser.exit_for_of_statement(&mut *node, ctx); } @@ -1860,6 +1904,7 @@ pub(crate) unsafe fn walk_switch_statement<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_SWITCH_STATEMENT_DISCRIMINANT) as *mut Expression, ctx, ); + ctx.push_scope_stack(ScopeFlags::empty()); ctx.retag_stack(AncestorType::SwitchStatementCases); for item in (*((node as *mut u8).add(ancestor::OFFSET_SWITCH_STATEMENT_CASES) as *mut Vec)) @@ -1867,6 +1912,7 @@ pub(crate) unsafe fn walk_switch_statement<'a, Tr: Traverse<'a>>( { walk_switch_case(traverser, item as *mut _, ctx); } + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_switch_statement(&mut *node, ctx); } @@ -1967,6 +2013,12 @@ pub(crate) unsafe fn walk_catch_clause<'a, Tr: Traverse<'a>>( ) { traverser.enter_catch_clause(&mut *node, ctx); ctx.push_stack(Ancestor::CatchClauseParam(ancestor::CatchClauseWithoutParam(node))); + let has_scope = (&*((node as *mut 
u8).add(ancestor::OFFSET_CATCH_CLAUSE_PARAM) + as *mut Option)) + .is_some(); + if has_scope { + ctx.push_scope_stack(ScopeFlags::empty()); + } if let Some(field) = &mut *((node as *mut u8).add(ancestor::OFFSET_CATCH_CLAUSE_PARAM) as *mut Option) { @@ -1979,6 +2031,9 @@ pub(crate) unsafe fn walk_catch_clause<'a, Tr: Traverse<'a>>( as *mut Box)) as *mut _, ctx, ); + if has_scope { + ctx.pop_scope_stack(); + } ctx.pop_stack(); traverser.exit_catch_clause(&mut *node, ctx); } @@ -2173,6 +2228,17 @@ pub(crate) unsafe fn walk_function<'a, Tr: Traverse<'a>>( ) { traverser.enter_function(&mut *node, ctx); ctx.push_stack(Ancestor::FunctionId(ancestor::FunctionWithoutId(node))); + let has_scope = (ctx.scope() & ScopeFlags::Modifiers).is_empty(); + if has_scope { + ctx.push_scope_stack( + ScopeFlags::Function.with_strict_mode( + (&*((node as *mut u8).add(ancestor::OFFSET_FUNCTION_BODY) + as *mut Option>)) + .as_ref() + .is_some_and(|body| body.has_use_strict_directive()), + ), + ); + } if let Some(field) = &mut *((node as *mut u8).add(ancestor::OFFSET_FUNCTION_ID) as *mut Option) { @@ -2209,6 +2275,9 @@ pub(crate) unsafe fn walk_function<'a, Tr: Traverse<'a>>( ctx.retag_stack(AncestorType::FunctionReturnType); walk_ts_type_annotation(traverser, (&mut **field) as *mut _, ctx); } + if has_scope { + ctx.pop_scope_stack(); + } ctx.pop_stack(); traverser.exit_function(&mut *node, ctx); } @@ -2291,6 +2360,7 @@ pub(crate) unsafe fn walk_arrow_function_expression<'a, Tr: Traverse<'a>>( ctx.push_stack(Ancestor::ArrowFunctionExpressionParams( ancestor::ArrowFunctionExpressionWithoutParams(node), )); + ctx.push_scope_stack(ScopeFlags::Function | ScopeFlags::Arrow); walk_formal_parameters( traverser, (&mut **((node as *mut u8).add(ancestor::OFFSET_ARROW_FUNCTION_EXPRESSION_PARAMS) @@ -2318,6 +2388,7 @@ pub(crate) unsafe fn walk_arrow_function_expression<'a, Tr: Traverse<'a>>( ctx.retag_stack(AncestorType::ArrowFunctionExpressionReturnType); walk_ts_type_annotation(traverser, (&mut 
**field) as *mut _, ctx); } + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_arrow_function_expression(&mut *node, ctx); } @@ -2352,6 +2423,7 @@ pub(crate) unsafe fn walk_class<'a, Tr: Traverse<'a>>( { walk_decorator(traverser, item as *mut _, ctx); } + ctx.push_scope_stack(ScopeFlags::StrictMode); if let Some(field) = &mut *((node as *mut u8).add(ancestor::OFFSET_CLASS_ID) as *mut Option) { @@ -2391,6 +2463,7 @@ pub(crate) unsafe fn walk_class<'a, Tr: Traverse<'a>>( walk_ts_class_implements(traverser, item as *mut _, ctx); } } + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_class(&mut *node, ctx); } @@ -2459,6 +2532,16 @@ pub(crate) unsafe fn walk_method_definition<'a, Tr: Traverse<'a>>( (node as *mut u8).add(ancestor::OFFSET_METHOD_DEFINITION_KEY) as *mut PropertyKey, ctx, ); + ctx.push_scope_stack( + (&*((node as *mut u8).add(ancestor::OFFSET_METHOD_DEFINITION_KIND) + as *mut MethodDefinitionKind)) + .scope_flags() + .with_strict_mode( + (&*((node as *mut u8).add(ancestor::OFFSET_METHOD_DEFINITION_VALUE) + as *mut Box)) + .is_strict(), + ), + ); ctx.retag_stack(AncestorType::MethodDefinitionValue); walk_function( traverser, @@ -2466,6 +2549,7 @@ pub(crate) unsafe fn walk_method_definition<'a, Tr: Traverse<'a>>( as *mut Box)) as *mut _, ctx, ); + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_method_definition(&mut *node, ctx); } @@ -2522,11 +2606,13 @@ pub(crate) unsafe fn walk_static_block<'a, Tr: Traverse<'a>>( ) { traverser.enter_static_block(&mut *node, ctx); ctx.push_stack(Ancestor::StaticBlockBody(ancestor::StaticBlockWithoutBody(node))); + ctx.push_scope_stack(ScopeFlags::ClassStaticBlock); walk_statements( traverser, (node as *mut u8).add(ancestor::OFFSET_STATIC_BLOCK_BODY) as *mut Vec, ctx, ); + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_static_block(&mut *node, ctx); } @@ -3503,6 +3589,7 @@ pub(crate) unsafe fn walk_ts_enum_declaration<'a, Tr: Traverse<'a>>( (node as *mut 
u8).add(ancestor::OFFSET_TS_ENUM_DECLARATION_ID) as *mut BindingIdentifier, ctx, ); + ctx.push_scope_stack(ScopeFlags::empty()); ctx.retag_stack(AncestorType::TSEnumDeclarationMembers); for item in (*((node as *mut u8).add(ancestor::OFFSET_TS_ENUM_DECLARATION_MEMBERS) as *mut Vec)) @@ -3510,6 +3597,7 @@ pub(crate) unsafe fn walk_ts_enum_declaration<'a, Tr: Traverse<'a>>( { walk_ts_enum_member(traverser, item as *mut _, ctx); } + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_ts_enum_declaration(&mut *node, ctx); } @@ -4232,6 +4320,7 @@ pub(crate) unsafe fn walk_ts_type_parameter<'a, Tr: Traverse<'a>>( ) { traverser.enter_ts_type_parameter(&mut *node, ctx); ctx.push_stack(Ancestor::TSTypeParameterName(ancestor::TSTypeParameterWithoutName(node))); + ctx.push_scope_stack(ScopeFlags::empty()); walk_binding_identifier( traverser, (node as *mut u8).add(ancestor::OFFSET_TS_TYPE_PARAMETER_NAME) as *mut BindingIdentifier, @@ -4249,6 +4338,7 @@ pub(crate) unsafe fn walk_ts_type_parameter<'a, Tr: Traverse<'a>>( ctx.retag_stack(AncestorType::TSTypeParameterDefault); walk_ts_type(traverser, field as *mut _, ctx); } + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_ts_type_parameter(&mut *node, ctx); } @@ -4730,11 +4820,13 @@ pub(crate) unsafe fn walk_ts_module_block<'a, Tr: Traverse<'a>>( ) { traverser.enter_ts_module_block(&mut *node, ctx); ctx.push_stack(Ancestor::TSModuleBlockBody(ancestor::TSModuleBlockWithoutBody(node))); + ctx.push_scope_stack(ScopeFlags::TsModuleBlock); walk_statements( traverser, (node as *mut u8).add(ancestor::OFFSET_TS_MODULE_BLOCK_BODY) as *mut Vec, ctx, ); + ctx.pop_scope_stack(); ctx.pop_stack(); traverser.exit_ts_module_block(&mut *node, ctx); }