fix(transformer): remove AstBuilder::copy from arrow functions transform (#5825)

Remove `AstBuilder::copy` calls by taking ownership of the `ArrowFunctionExpression` at start of transform.
This commit is contained in:
overlookmotel 2024-09-18 02:23:32 +00:00
parent c6d97e9366
commit d74c7fa0c9

View file

@ -180,8 +180,14 @@ impl<'a> Traverse<'a> for ArrowFunctions<'a> {
fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) {
match expr {
Expression::ArrowFunctionExpression(arrow_function_expr) => {
*expr = self.transform_arrow_function_expression(arrow_function_expr, ctx);
Expression::ArrowFunctionExpression(_) => {
let Expression::ArrowFunctionExpression(arrow_function_expr) =
ctx.ast.move_expression(expr)
else {
unreachable!()
};
*expr = self.transform_arrow_function_expression(arrow_function_expr.unbox(), ctx);
self.stacks.pop();
}
Expression::FunctionExpression(_) => {
@ -253,22 +259,18 @@ impl<'a> ArrowFunctions<'a> {
fn transform_arrow_function_expression(
&mut self,
arrow_function_expr: &mut ArrowFunctionExpression<'a>,
arrow_function_expr: ArrowFunctionExpression<'a>,
ctx: &mut TraverseCtx<'a>,
) -> Expression<'a> {
// SAFETY: `ast.copy` is unsound! We need to fix.
let mut body = unsafe { self.ctx.ast.copy(&arrow_function_expr.body) };
let mut body = arrow_function_expr.body;
if arrow_function_expr.expression {
let first_stmt = body.statements.remove(0);
if let Statement::ExpressionStatement(stmt) = first_stmt {
let return_statement = self.ctx.ast.statement_return(
stmt.span,
// SAFETY: `ast.copy` is unsound! We need to fix.
Some(unsafe { self.ctx.ast.copy(&stmt.expression) }),
);
body.statements.push(return_statement);
}
assert!(body.statements.len() == 1);
let stmt = body.statements.pop().unwrap();
let Statement::ExpressionStatement(stmt) = stmt else { unreachable!() };
let stmt = stmt.unbox();
let return_statement = self.ctx.ast.statement_return(stmt.span, Some(stmt.expression));
body.statements.push(return_statement);
}
// There shouldn't need to be a conditional here. Every arrow function should have a scope ID.
@ -295,13 +297,10 @@ impl<'a> ArrowFunctions<'a> {
false,
arrow_function_expr.r#async,
false,
// SAFETY: `ast.copy` is unsound! We need to fix.
unsafe { self.ctx.ast.copy(&arrow_function_expr.type_parameters) },
arrow_function_expr.type_parameters,
None::<TSThisParameter<'a>>,
// SAFETY: `ast.copy` is unsound! We need to fix.
unsafe { self.ctx.ast.copy(&arrow_function_expr.params) },
// SAFETY: `ast.copy` is unsound! We need to fix.
unsafe { self.ctx.ast.copy(&arrow_function_expr.return_type) },
arrow_function_expr.params,
arrow_function_expr.return_type,
Some(body),
);
new_function.scope_id.set(scope_id);