提交 6d77f5db 作者: Megvii Engine Team

fix(mge/traced_module): fix NormElemwisePass

GitOrigin-RevId: a92d19a013aba55fce4d5fd19798fc46d123fe97
上级 1d76bd5a
...@@ -128,10 +128,14 @@ class NormElemWise(BackwardPass): ...@@ -128,10 +128,14 @@ class NormElemWise(BackwardPass):
cofee, left_node, right_node = 1, None, None cofee, left_node, right_node = 1, None, None
if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]: if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
left_node = expr.inputs[0] left_node = expr.inputs[0]
right_node = expr.const_val[0][-1] named_args = (expr.named_args).values()
for v in named_args:
if not isinstance(v, TensorNode):
right_node = v
break
if target in ["__rsub__", "__rtruediv__"]: if target in ["__rsub__", "__rtruediv__"]:
cofee = -1 cofee = -1
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: if target in [F.sub, F.div] and left_node is not expr.named_args["x"]:
cofee = -1 cofee = -1
elif len(expr.inputs) == 2 and ( elif len(expr.inputs) == 2 and (
target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr) target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
...@@ -139,7 +143,7 @@ class NormElemWise(BackwardPass): ...@@ -139,7 +143,7 @@ class NormElemWise(BackwardPass):
left_node, right_node = expr.inputs left_node, right_node = expr.inputs
if target in ["__rsub__", "__rtruediv__"]: if target in ["__rsub__", "__rtruediv__"]:
left_node, right_node = right_node, left_node left_node, right_node = right_node, left_node
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: if target in [F.sub, F.div] and left_node is not expr.named_args["x"]:
left_node, right_node = right_node, left_node left_node, right_node = right_node, left_node
if is_constant(left_node.expr): if is_constant(left_node.expr):
left_node, right_node = right_node, left_node left_node, right_node = right_node, left_node
...@@ -152,30 +156,38 @@ class NormElemWise(BackwardPass): ...@@ -152,30 +156,38 @@ class NormElemWise(BackwardPass):
right_node = get_const_value(right_node.expr, right_node) right_node = get_const_value(right_node.expr, right_node)
graph = expr.top_graph graph = expr.top_graph
mul_f, add_f, sub_f, div_f = F.mul, F.add, F.sub, F.div
def map_f(value, func):
if isinstance(value, (list, tuple)):
return [func(v) for v in value]
return func(value)
with graph.insert_exprs(): with graph.insert_exprs():
if target in ["__mul__", "__imul__", "__rmul__", F.mul]: if target in ["__mul__", "__imul__", "__rmul__", mul_f]:
out_node = left_node * right_node out_node = left_node * right_node
elif target in ["__add__", "__iadd__", "__radd__", F.add]: elif target in ["__add__", "__iadd__", "__radd__", add_f]:
out_node = left_node + right_node out_node = left_node + right_node
elif target in ["__sub__", "__isub__", "__rsub__", F.sub]: elif target in ["__sub__", "__isub__", "__rsub__", sub_f]:
f_l, f_r = lambda v: v, lambda v: v
if cofee == -1: if cofee == -1:
left_node = F.neg(left_node) f_l = lambda v: F.neg(v)
else: else:
if isinstance(right_node, TensorNode): if isinstance(right_node, TensorNode):
right_node = F.neg(right_node) f_r = lambda v: F.neg(v)
else: else:
right_node = -1 * right_node f_r = lambda v: -1 * v
out_node = left_node + right_node out_node = map_f(left_node, f_l) + map_f(right_node, f_r)
elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]: elif target in ["__truediv__", "__itruediv__", "__rtruediv__", div_f]:
f_l, f_r = lambda v: v, lambda v: v
if cofee == -1: if cofee == -1:
left_node = F.pow(left_node, -1) f_l = lambda v: F.pow(v, -1)
else: else:
if isinstance(right_node, TensorNode): if isinstance(right_node, TensorNode):
right_node = F.pow(right_node, -1) f_r = lambda v: F.pow(v, -1)
else: else:
right_node = 1 / right_node f_r = lambda v: 1 / v
out_node = left_node * right_node out_node = map_f(left_node, f_l) * map_f(right_node, f_r)
graph.replace_node({expr.outputs[0]: out_node}) graph.replace_node({expr.outputs[0]: out_node})
graph.compile() graph.compile()
return out_node.expr return out_node.expr
...@@ -145,7 +145,7 @@ class PatternMatcher: ...@@ -145,7 +145,7 @@ class PatternMatcher:
def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
if not is_call_function(expr, pattern.target): if not is_call_function(expr, pattern.target):
return False return False
kwargs = expr.kwargs kwargs = expr.named_args
for key, target in pattern.params.items(): for key, target in pattern.params.items():
value = kwargs.get(key, None) value = kwargs.get(key, None)
if target != value: if target != value:
......
...@@ -36,7 +36,7 @@ class MyBlock(M.Module): ...@@ -36,7 +36,7 @@ class MyBlock(M.Module):
x2 = F.relu(x2) x2 = F.relu(x2)
x2 = x2 * self.scale[1] x2 = x2 * self.scale[1]
y = x1 + x2 y = x1 + x2
y = y + 4 y = F.add(y, 4)
y = self.scale[0] + y y = self.scale[0] + y
y = F.relu(y) * 3 y = F.relu(y) * 3
return y return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册