diff --git a/imperative/python/megengine/traced_module/_passes/const_pass.py b/imperative/python/megengine/traced_module/_passes/const_pass.py index 280874a07cd13785cfc6ce1920fb4a97b8cdfcdc..f21b95cb67349ebd888bb343ad60aa0e9a615ba3 100644 --- a/imperative/python/megengine/traced_module/_passes/const_pass.py +++ b/imperative/python/megengine/traced_module/_passes/const_pass.py @@ -128,10 +128,14 @@ class NormElemWise(BackwardPass): cofee, left_node, right_node = 1, None, None if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]: left_node = expr.inputs[0] - right_node = expr.const_val[0][-1] + named_args = (expr.named_args).values() + for v in named_args: + if not isinstance(v, TensorNode): + right_node = v + break if target in ["__rsub__", "__rtruediv__"]: cofee = -1 - if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: + if target in [F.sub, F.div] and left_node is not expr.named_args["x"]: cofee = -1 elif len(expr.inputs) == 2 and ( target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr) @@ -139,7 +143,7 @@ class NormElemWise(BackwardPass): left_node, right_node = expr.inputs if target in ["__rsub__", "__rtruediv__"]: left_node, right_node = right_node, left_node - if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]: + if target in [F.sub, F.div] and left_node is not expr.named_args["x"]: left_node, right_node = right_node, left_node if is_constant(left_node.expr): left_node, right_node = right_node, left_node @@ -152,30 +156,38 @@ class NormElemWise(BackwardPass): right_node = get_const_value(right_node.expr, right_node) graph = expr.top_graph + mul_f, add_f, sub_f, div_f = F.mul, F.add, F.sub, F.div + + def map_f(value, func): + if isinstance(value, (list, tuple)): + return [func(v) for v in value] + return func(value) + with graph.insert_exprs(): - if target in ["__mul__", "__imul__", "__rmul__", F.mul]: + if target in ["__mul__", "__imul__", "__rmul__", mul_f]: out_node = left_node * right_node - elif target in ["__add__", "__iadd__", "__radd__", F.add]: + elif target in ["__add__", "__iadd__", "__radd__", add_f]: out_node = left_node + right_node - elif target in ["__sub__", "__isub__", "__rsub__", F.sub]: + elif target in ["__sub__", "__isub__", "__rsub__", sub_f]: + f_l, f_r = lambda v: v, lambda v: v if cofee == -1: - left_node = F.neg(left_node) + f_l = lambda v: F.neg(v) else: if isinstance(right_node, TensorNode): - right_node = F.neg(right_node) + f_r = lambda v: F.neg(v) else: - right_node = -1 * right_node - out_node = left_node + right_node - elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]: + f_r = lambda v: -1 * v + out_node = map_f(left_node, f_l) + map_f(right_node, f_r) + elif target in ["__truediv__", "__itruediv__", "__rtruediv__", div_f]: + f_l, f_r = lambda v: v, lambda v: v if cofee == -1: - left_node = F.pow(left_node, -1) + f_l = lambda v: F.pow(v, -1) else: if isinstance(right_node, TensorNode): - right_node = F.pow(right_node, -1) + f_r = lambda v: F.pow(v, -1) else: - right_node = 1 / right_node - out_node = left_node * right_node - + f_r = lambda v: 1 / v + out_node = map_f(left_node, f_l) * map_f(right_node, f_r) graph.replace_node({expr.outputs[0]: out_node}) graph.compile() return out_node.expr diff --git a/imperative/python/megengine/traced_module/_passes/matcher.py b/imperative/python/megengine/traced_module/_passes/matcher.py index 0a90bf0d3055dd48e13875e331cd77479a03a801..a5fa3c6b68b718d194ec16c20bc2868ef0dec1eb 100644 --- a/imperative/python/megengine/traced_module/_passes/matcher.py +++ b/imperative/python/megengine/traced_module/_passes/matcher.py @@ -145,7 +145,7 @@ class PatternMatcher: def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool: if not is_call_function(expr, pattern.target): return False - kwargs = expr.kwargs + kwargs = expr.named_args for key, target in pattern.params.items(): value = kwargs.get(key, None) if target != value: diff --git a/imperative/python/test/unit/traced_module/test_passes.py b/imperative/python/test/unit/traced_module/test_passes.py index 56743598f42f93af26a682d219d9a0ea155d841e..863e48a0199c270f31c8bc7e764806d44d7c270d 100644 --- a/imperative/python/test/unit/traced_module/test_passes.py +++ b/imperative/python/test/unit/traced_module/test_passes.py @@ -36,7 +36,7 @@ class MyBlock(M.Module): x2 = F.relu(x2) x2 = x2 * self.scale[1] y = x1 + x2 - y = y + 4 + y = F.add(y, 4) y = self.scale[0] + y y = F.relu(y) * 3 return y