diff --git a/crates/oxc_linter/src/rules/jest/expect_expect.rs b/crates/oxc_linter/src/rules/jest/expect_expect.rs index c13b29e10..3b9c51b74 100644 --- a/crates/oxc_linter/src/rules/jest/expect_expect.rs +++ b/crates/oxc_linter/src/rules/jest/expect_expect.rs @@ -7,6 +7,7 @@ use oxc_diagnostics::OxcDiagnostic; use oxc_macros::declare_oxc_lint; use oxc_span::{GetSpan, Span}; use regex::Regex; +use rustc_hash::FxHashSet; use crate::{ ast_util::get_declaration_of_variable, @@ -127,8 +128,11 @@ fn run<'a>( } } + // Record visited nodes to avoid infinite loop. + let mut visited: FxHashSet = FxHashSet::default(); + let has_assert_function = - check_arguments(call_expr, &rule.assert_function_names, None, ctx); + check_arguments(call_expr, &rule.assert_function_names, &mut visited, ctx); if !has_assert_function { ctx.diagnostic(expect_expect_diagnostic(call_expr.callee.span())); @@ -140,12 +144,12 @@ fn run<'a>( fn check_arguments<'a>( call_expr: &'a CallExpression<'a>, assert_function_names: &[String], - fn_expr_name: Option<&'a str>, + visited: &mut FxHashSet, ctx: &LintContext<'a>, ) -> bool { for argument in &call_expr.arguments { if let Some(expr) = argument.as_expression() { - if check_assert_function_used(expr, assert_function_names, fn_expr_name, ctx) { + if check_assert_function_used(expr, assert_function_names, visited, ctx) { return true; } } @@ -156,25 +160,36 @@ fn check_arguments<'a>( fn check_assert_function_used<'a>( expr: &'a Expression<'a>, assert_function_names: &[String], - fn_expr_name: Option<&'a str>, + visited: &mut FxHashSet, ctx: &LintContext<'a>, ) -> bool { + // If we have visited this node before and didn't find any assert function, we can return + // `false` to avoid infinite loop. + // + // ```javascript + // test("should fail", () => { + // function foo() { + // if (condition) { + // foo() + // } + // } + // foo() + // }) + // ``` + if !visited.insert(expr.span()) { + return false; + } + match expr { Expression::FunctionExpression(fn_expr) => { let body = &fn_expr.body; if let Some(body) = body { - let fn_expr_name = fn_expr.id.as_ref().map(|id| id.name.as_str()); - return check_statements( - &body.statements, - assert_function_names, - fn_expr_name, - ctx, - ); + return check_statements(&body.statements, assert_function_names, visited, ctx); } } Expression::ArrowFunctionExpression(arrow_expr) => { let body = &arrow_expr.body; - return check_statements(&body.statements, assert_function_names, fn_expr_name, ctx); + return check_statements(&body.statements, assert_function_names, visited, ctx); } Expression::CallExpression(call_expr) => { let name = get_node_name(&call_expr.callee); @@ -182,8 +197,13 @@ fn check_assert_function_used<'a>( return true; } + // If CallExpression is not an assert function, we need to check its arguments, it may trigger + // another assert function. + // ```javascript + // it('should pass', () => somePromise().then(() => expect(true).toBeDefined())) + // ``` let has_assert_function = - check_arguments(call_expr, assert_function_names, fn_expr_name, ctx); + check_arguments(call_expr, assert_function_names, visited, ctx); return has_assert_function; } @@ -194,24 +214,13 @@ fn check_assert_function_used<'a>( let AstKind::Function(function) = node.kind() else { return false; }; - // Stop recursing into self - if let Some(name) = fn_expr_name { - if function.id.as_ref().is_some_and(|id| id.name.as_str() == name) { - return false; - } - } let Some(body) = &function.body else { return false; }; - return check_statements(&body.statements, assert_function_names, fn_expr_name, ctx); + return check_statements(&body.statements, assert_function_names, visited, ctx); } Expression::AwaitExpression(expr) => { - return check_assert_function_used( - &expr.argument, - assert_function_names, - fn_expr_name, - ctx, - ); + return check_assert_function_used(&expr.argument, assert_function_names, visited, ctx); } _ => {} }; @@ -222,7 +231,7 @@ fn check_assert_function_used<'a>( fn check_statements<'a>( statements: &'a oxc_allocator::Vec>, assert_function_names: &[String], - fn_expr_name: Option<&'a str>, + visited: &mut FxHashSet, ctx: &LintContext<'a>, ) -> bool { statements.iter().any(|statement| { @@ -230,7 +239,7 @@ fn check_statements<'a>( return check_assert_function_used( &expr_stmt.expression, assert_function_names, - fn_expr_name, + visited, ctx, ); }