From e6c271ae46f30c91e5001d3def63ef123e6149f2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Nov 2021 13:29:29 +0800 Subject: [PATCH] fix(mge/traced_module): fix some bugs for graph surgery GitOrigin-RevId: 6328a84cbc8554c847530ce990331145ca82e043 --- .../megengine/traced_module/traced_module.py | 187 +++++++++++------- .../unit/traced_module/test_modification.py | 29 ++- 2 files changed, 139 insertions(+), 77 deletions(-) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 824150ed4..7f6bcb268 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -122,18 +122,18 @@ def _is_leaf(node): return isinstance(node, RawTensor) -_enable_node_to_tensor = False +_enable_graph_surgery_mode = False -def _convert_node_flag(): - return _enable_node_to_tensor +def _graph_surgery_mode(): + return _enable_graph_surgery_mode -def _set_convert_node_flag(flag: bool = False): - global _enable_node_to_tensor - pre_flag = _enable_node_to_tensor - _enable_node_to_tensor = flag - return pre_flag +def _set_graph_surgery_mode(mode: bool): + global _enable_graph_surgery_mode + pre_mode = _enable_graph_surgery_mode + _enable_graph_surgery_mode = mode + return pre_mode def _node_to_tensor(*args, **kwargs): @@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs): active_module_tracer().current_scope()._add_input(n) value = n.value if value is None: - flag = _set_convert_node_flag(False) + flag = _set_graph_surgery_mode(False) unset_module_tracing() value = F.zeros(shape=n._shape, dtype=n._dtype) set_module_tracing() - _set_convert_node_flag(flag) + _set_graph_surgery_mode(flag) orig_n = NodeMixin.get(value, None) if orig_n is None or "setitem" not in orig_n._name: NodeMixin.wrap_safe(value, n) @@ -180,17 +180,25 @@ def _tensor_to_node(tensors): def _wrap_method_to_tensor_node(): - def _any_method(name): + def _any_method(name, func): def _any(*args, **kwargs): - args, kwargs = _node_to_tensor(*args, **kwargs) - attr = getattr(args[0], name) - outs = attr - if callable(attr): - outs = attr(*(args[1:]), **kwargs) - if name == "__setitem__": - _node_to_tensor(outs) - return None - outs = _tensor_to_node(outs) + if is_tracing_module() and _graph_surgery_mode(): + args, kwargs = _node_to_tensor(*args, **kwargs) + attr = getattr(args[0], name) + outs = attr + if callable(attr): + outs = attr(*(args[1:]), **kwargs) + if name == "__setitem__": + _node_to_tensor(outs) + return None + outs = _tensor_to_node(outs) + return outs + else: + outs = func + if callable(func): + outs = func(*args, **kwargs) + if isinstance(func, property): + outs = func.__get__(*args, **kwargs) return outs return _any @@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node(): for method in get_tensor_wrapable_method(): patch = PatchedFn(TensorNode, method) if type(getattr(Tensor, method)) == property: - patch.set_func(property(_any_method(method))) + patch.set_func(property(_any_method(method, patch.origin_fn))) else: - patch.set_func(_any_method(method)) + patch.set_func(_any_method(method, patch.origin_fn)) tensor_method_patch.append(patch) return tensor_method_patch @@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node(): def _convert_node_and_tensor(orig_func): @functools.wraps(orig_func) def _convert(*args, **kwargs): - if _convert_node_flag() and is_tracing_module(): + if is_tracing_module() and _graph_surgery_mode(): args, kwargs = _node_to_tensor(*args, **kwargs) rst = orig_func(*args, **kwargs, method_func=_convert) rst = _tensor_to_node(rst) @@ -224,31 +232,35 @@ def _convert_node_and_tensor(orig_func): def _wrap_mnode_getattr(orig_getattr): @functools.wraps(orig_getattr) def wraped_fn(self, name): - obj = self.owner - current_graph = active_module_tracer().current_scope() - if self.top_graph is not None: - current_graph._add_input(self) - attr = getattr(obj, name) - node = attr - if not isinstance(attr, TracedModuleBuilder): - if isinstance(attr, Module): - attr = TracedModuleBuilder(attr) - setattr(obj, name, attr) - + if is_tracing_module() and _graph_surgery_mode(): + obj = self.owner + current_graph = active_module_tracer().current_scope() + if self.top_graph is not None: + current_graph._add_input(self) + attr = getattr(obj, name) + node = attr + if not isinstance(attr, TracedModuleBuilder): + if isinstance(attr, Module): + attr = TracedModuleBuilder(attr) + setattr(obj, name, attr) + + if isinstance(attr, (NodeMixin, RawTensor)): + NodeMixin.wrap( + attr, + lambda: GetAttr.make( + self, + type=NodeMixin.get_wrapped_type(attr), + attr_name=name, + name="", + ), + ) if isinstance(attr, (NodeMixin, RawTensor)): - NodeMixin.wrap( - attr, - lambda: GetAttr.make( - self, - type=NodeMixin.get_wrapped_type(attr), - attr_name=name, - name="", - ), - ) - if isinstance(attr, (NodeMixin, RawTensor)): - node = NodeMixin.get(attr) - if isinstance(node, ModuleNode): - node._owner = weakref.ref(attr) + node = NodeMixin.get(attr) + if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)): + node._owner = weakref.ref(attr) + return node + else: + node = object.__getattribute__(self, name) return node return wraped_fn @@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr): def _wrap_mnode_call(orig_call): @functools.wraps(orig_call) def wraped_fn(self, *args, **kwargs): - obj = self.owner - if self.top_graph is not None: - active_module_tracer().current_scope()._add_input(self) - rst = obj(*args, **kwargs) + if is_tracing_module() and _graph_surgery_mode(): + obj = self.owner + if self.top_graph is not None: + active_module_tracer().current_scope()._add_input(self) + rst = obj(*args, **kwargs) + else: + raise TypeError("'ModuleNode' object is not callable") return rst return wraped_fn @@ -284,7 +299,7 @@ class _InsertExprs: Node._set_next_id(node_id) Expr._set_next_id(expr_id) set_module_tracing() - _set_convert_node_flag(True) + _set_graph_surgery_mode(True) assert active_module_tracer() is None set_active_module_tracer( module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x))) @@ -303,20 +318,30 @@ class _InsertExprs: if va is not None: return False active_module_tracer().patcher.__exit__(ty, va, tr) - _set_convert_node_flag(False) while self._tensor_method_patch: pf = self._tensor_method_patch.pop() pf.set_func(pf.origin_fn) + # delete ModuleNode.__call__ to avoid entering the + # ModuleNode.__init__ method when call a ModuleNode object. + delattr(ModuleNode, "__call__") + module = self.graph.inputs[0].owner - for k, v in module.__dict__.items(): - if isinstance(v, TracedModuleBuilder): - v = v.build() - setattr(module, k, v) + def build_traced_module( + module: TracedModuleBuilder, target_module: TracedModule + ): + for k, v in module.__dict__.items(): + if isinstance(v, TracedModuleBuilder): + traced_v = v.build() + build_traced_module(v, traced_v) + setattr(target_module, k, traced_v) + + build_traced_module(module, module) set_symbolic_shape(self.use_sym_shape) + _set_graph_surgery_mode(False) set_active_module_tracer(None) unset_module_tracing() @@ -435,7 +460,7 @@ class NameSpace: def unassociate_name_with_obj(self, node: Node): assert node.name in self.used_names - assert self.used_names[node.name] is node + # assert self.used_names[node.name] is node self._used_names[node.name] = None @property @@ -1364,6 +1389,8 @@ class TracedModuleBuilder(NodeMixin): for node in self.nodes: node.module_type = mod_type + return self._mod + elif isinstance(self._mod, TracedModule) and _graph_surgery_mode(): return self._mod else: is_qat = isinstance(self._mod, QATModule) or ( @@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin): def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) + is_graph_surgery_mode = _graph_surgery_mode() + if isinstance(self._mod, TracedModule) and is_graph_surgery_mode: + _set_graph_surgery_mode(False) + # prepare args and kwargs for inner graph if "method_func" in kwargs: kwargs.pop("method_func") @@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin): ) rst = type(self._mod).forward(*args, **kwargs) - if _convert_node_flag(): + if _graph_surgery_mode(): rst = _node_to_tensor(rst)[0][0] outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) @@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin): callnode.add_outputs(outputs) self._argdef_graph_map[callnode.arg_def] = self._body self._argdef_outdef_map[callnode.arg_def] = out_def + _set_graph_surgery_mode(is_graph_surgery_mode) return rst def __setattr__(self, name, value): @@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin): return active_module_tracer().patcher.wrap_fn(attr) if isinstance(attr, (List, Dict)): - flag = _set_convert_node_flag(False) + flag = _set_graph_surgery_mode(False) unset_module_tracing() has_module, m_container = replace_container_with_module_container(attr) if m_container: @@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin): " Module and Non-Module objects." ) set_module_tracing() - _set_convert_node_flag(flag) + _set_graph_surgery_mode(flag) if isinstance(attr, Module): attr = TracedModuleBuilder(attr) @@ -1628,20 +1660,25 @@ class _expr_iter: self._visited_graph = set() def __iter__(self): - for inp_node in self.graph.inputs: - yield inp_node.expr - for expr in self.graph._exprs: - if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): - yield expr - if ( - self.recursive - and expr.graph is not None - and id(expr.graph) not in self._visited_graph - ): - yield from expr.graph.exprs(self.recursive) - self._visited_graph.add(id(expr.graph)) - else: - yield expr + yield from self._gen_expr(self.graph) + + def _gen_expr(self, graph: InternalGraph): + visit_inp = set() + for inp_node in graph.inputs: + if inp_node not in visit_inp: + yield inp_node.expr + visit_inp.add(inp_node) + + for expr in graph._exprs: + yield expr + if ( + self.recursive + and hasattr(expr, "graph") + and expr.graph is not None + and id(expr.graph) not in self._visited_graph + ): + self._visited_graph.add(id(expr.graph)) + yield from self._gen_expr(expr.graph) class _node_iter: diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 4797a8daf..1a9c99f9f 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -15,7 +15,7 @@ import megengine.functional as F import megengine.module as M import megengine.module.qat as qat from megengine.module.identity import Identity -from megengine.traced_module import trace_module +from megengine.traced_module import TracedModule, trace_module from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input from megengine.traced_module.node import ModuleNode, Node, TensorNode @@ -182,7 +182,6 @@ def test_insert_module(): setattr(traced_module, "neg", Neg(name="neg")) setattr(traced_module, "neg2", Neg(name="neg")) setattr(traced_module, "param", F.zeros((1,))) - with graph.insert_exprs(): neg_out = self.neg(relu_out) neg_out = self.neg2(relu_out) @@ -199,6 +198,32 @@ def test_insert_module(): if isinstance(n, TensorNode): assert n.value is None + traced_module, x, expect = _init_module() + setattr(traced_module.block0, "neg", Neg(name=None)) + graph = traced_module.graph + self = graph.inputs[0] + out_node = graph.outputs[0] + with graph.insert_exprs(): + neg_out = self.block0.neg(out_node) + graph.replace_node({out_node: neg_out}) + graph.compile() + np.testing.assert_allclose(expect, -traced_module(x), atol=1e-6) + assert isinstance(traced_module.block0.neg, TracedModule) + assert traced_module.block0.neg.graph is not None + + setattr(traced_module.block0.neg, "neg", Neg(name=None)) + setattr(traced_module.block0.neg.neg, "relu", M.ReLU()) + out_node = graph.outputs[0] + with graph.insert_exprs(): + neg_out = self.block0.neg.neg(out_node) + neg_out = self.block0.neg.neg(neg_out) + relu_out = self.block0.neg.neg.relu(neg_out) + graph.replace_node({out_node: relu_out}) + graph.compile() + np.testing.assert_allclose(F.relu(-expect), traced_module(x), atol=1e-6) + assert isinstance(traced_module.block0.neg.neg, TracedModule) + assert traced_module.block0.neg.neg.graph is not None + def test_insert_qat_module(): class concat(qat.Concat): -- GitLab