提交 1a987b7b 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix infinite loop caused by BackwardFoldScale

GitOrigin-RevId: 898d3ed9c36653d03523870bbe97216639dc977f
上级 63032170
......@@ -165,15 +165,21 @@ class BackwardFoldScale(BackwardPass):
def reset_expr_message_to_none(
self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr],
):
if expr in skip_exprs:
return
scale_message[expr] = None
if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
return
for out_node in expr.outputs:
for user in out_node.users:
if user in scale_message:
self.reset_expr_message_to_none(user, scale_message, skip_exprs)
visited_expr = set()
def _forward_trave(expr):
if expr in skip_exprs or expr in visited_expr:
return
visited_expr.add(expr)
scale_message[expr] = None
if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
return
for out_node in expr.outputs:
for user in out_node.users:
if user in scale_message:
_forward_trave(user)
_forward_trave(expr)
def before_visit_graph(self, graph: InternalGraph):
var = is_var().check_users(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册