diff --git a/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py b/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py index 878b9a35f266ac70cb4fa9d4e3306f96055231d2..11fd4fa0080e24bfe3ded03d41111773f38838d5 100644 --- a/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py +++ b/imperative/python/megengine/traced_module/_passes/fold_scale_pass.py @@ -165,15 +165,21 @@ class BackwardFoldScale(BackwardPass): def reset_expr_message_to_none( self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr], ): - if expr in skip_exprs: - return - scale_message[expr] = None - if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d): - return - for out_node in expr.outputs: - for user in out_node.users: - if user in scale_message: - self.reset_expr_message_to_none(user, scale_message, skip_exprs) + visited_expr = set() + + def _forward_trave(expr): + if expr in skip_exprs or expr in visited_expr: + return + visited_expr.add(expr) + scale_message[expr] = None + if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d): + return + for out_node in expr.outputs: + for user in out_node.users: + if user in scale_message: + _forward_trave(user) + + _forward_trave(expr) def before_visit_graph(self, graph: InternalGraph): var = is_var().check_users(False)