提交 05de8ba2 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix NormElemwisePass

GitOrigin-RevId: a92d19a013aba55fce4d5fd19798fc46d123fe97
上级 683b3e30
......@@ -128,10 +128,14 @@ class NormElemWise(BackwardPass):
cofee, left_node, right_node = 1, None, None
if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
left_node = expr.inputs[0]
right_node = expr.const_val[0][-1]
named_args = (expr.named_args).values()
for v in named_args:
if not isinstance(v, TensorNode):
right_node = v
break
if target in ["__rsub__", "__rtruediv__"]:
cofee = -1
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
if target in [F.sub, F.div] and left_node is not expr.named_args["x"]:
cofee = -1
elif len(expr.inputs) == 2 and (
target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
......@@ -139,7 +143,7 @@ class NormElemWise(BackwardPass):
left_node, right_node = expr.inputs
if target in ["__rsub__", "__rtruediv__"]:
left_node, right_node = right_node, left_node
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
if target in [F.sub, F.div] and left_node is not expr.named_args["x"]:
left_node, right_node = right_node, left_node
if is_constant(left_node.expr):
left_node, right_node = right_node, left_node
......@@ -152,30 +156,38 @@ class NormElemWise(BackwardPass):
right_node = get_const_value(right_node.expr, right_node)
graph = expr.top_graph
mul_f, add_f, sub_f, div_f = F.mul, F.add, F.sub, F.div
def map_f(value, func):
if isinstance(value, (list, tuple)):
return [func(v) for v in value]
return func(value)
with graph.insert_exprs():
if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
if target in ["__mul__", "__imul__", "__rmul__", mul_f]:
out_node = left_node * right_node
elif target in ["__add__", "__iadd__", "__radd__", F.add]:
elif target in ["__add__", "__iadd__", "__radd__", add_f]:
out_node = left_node + right_node
elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
elif target in ["__sub__", "__isub__", "__rsub__", sub_f]:
f_l, f_r = lambda v: v, lambda v: v
if cofee == -1:
left_node = F.neg(left_node)
f_l = lambda v: F.neg(v)
else:
if isinstance(right_node, TensorNode):
right_node = F.neg(right_node)
f_r = lambda v: F.neg(v)
else:
right_node = -1 * right_node
out_node = left_node + right_node
elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
f_r = lambda v: -1 * v
out_node = map_f(left_node, f_l) + map_f(right_node, f_r)
elif target in ["__truediv__", "__itruediv__", "__rtruediv__", div_f]:
f_l, f_r = lambda v: v, lambda v: v
if cofee == -1:
left_node = F.pow(left_node, -1)
f_l = lambda v: F.pow(v, -1)
else:
if isinstance(right_node, TensorNode):
right_node = F.pow(right_node, -1)
f_r = lambda v: F.pow(v, -1)
else:
right_node = 1 / right_node
out_node = left_node * right_node
f_r = lambda v: 1 / v
out_node = map_f(left_node, f_l) * map_f(right_node, f_r)
graph.replace_node({expr.outputs[0]: out_node})
graph.compile()
return out_node.expr
......@@ -145,7 +145,7 @@ class PatternMatcher:
def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
if not is_call_function(expr, pattern.target):
return False
kwargs = expr.kwargs
kwargs = expr.named_args
for key, target in pattern.params.items():
value = kwargs.get(key, None)
if target != value:
......
......@@ -36,7 +36,7 @@ class MyBlock(M.Module):
x2 = F.relu(x2)
x2 = x2 * self.scale[1]
y = x1 + x2
y = y + 4
y = F.add(y, 4)
y = self.scale[0] + y
y = F.relu(y) * 3
return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册