From 3ff5ca5ffed48e4c6a67bff49980d749fbce85f3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Dec 2021 18:53:16 +0800 Subject: [PATCH] feat(mge/traced_module): support to modify the name of Node during graph surgery GitOrigin-RevId: 9ecf6f2c5b700d4c91947def2cdc00cce4e0efc7 --- .../megengine/traced_module/module_tracer.py | 6 ++++- .../megengine/traced_module/traced_module.py | 23 ++++++++++++++++ .../unit/traced_module/test_modification.py | 27 +++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index db2bf0552..4cba9a35d 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [ "dtype", "grad", "item", - "name", "ndim", "numpy", "qparams", @@ -152,6 +151,11 @@ class module_tracer: return self._active_scopes[-1] return None + def top_scope(self): + if self._active_scopes: + return self._active_scopes[0] + return None + class NotExist: pass diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 058fb4a6a..0c860f93a 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -180,6 +180,25 @@ def _tensor_to_node(tensors): return nodes +def _name_setter(node: Node, new_name: str): + surgery_mode = _set_graph_surgery_mode(False) + graph = active_module_tracer().current_scope() + + if node.top_graph is not None: + top_graph = active_module_tracer().top_scope() + if node is top_graph._namespace.used_names.get(node._name, None): + graph = top_graph + else: + graph = node.top_graph + + assert ( + graph._namespace.used_names.get(new_name, None) is None + ), "The name(%s) is already in use. Please try a different one again." % (new_name) + graph._namespace.unassociate_name_with_obj(node) + node._name = graph._namespace.create_unique_name(new_name, node) + _set_graph_surgery_mode(surgery_mode) + + def _wrap_method_to_tensor_node(): def _any_method(name, func): def _any(*args, **kwargs): @@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node(): else: patch.set_func(_any_method(method, patch.origin_fn)) tensor_method_patch.append(patch) + + patch = PatchedFn(Node, "name") + patch.set_func(property(patch.origin_fn.fget, _name_setter)) + tensor_method_patch.append(patch) return tensor_method_patch diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 1a9c99f9f..036924a7c 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -377,6 +377,33 @@ def test_set_node_name(): rename("output") np.testing.assert_equal(str(graph.outputs[0]), "output") + def add_1(x): + x = x + 1 + x.name = "func_add_1" + return x + + class ModuleAdd_3(M.Module): + def forward(self, x): + x = x + 1 + x.name = "module_add_1" + x = x + 2 + return x + + setattr(traced_module, "add_3", ModuleAdd_3()) + + self = graph.inputs[0] + with graph.insert_exprs(): + x = output_node + 1 + x.name = "_add_1" + x = add_1(x) + x = self.add_3(x) + graph.replace_node({output_node: x}) + graph.compile() + + assert "_add_1" in graph._namespace.used_names + assert "func_add_1" in graph._namespace.used_names + assert "module_add_1" in traced_module.add_3.graph._namespace.used_names + def test_set_graph_name(): traced_module, x, expect = _init_module() -- GitLab