提交 3ff5ca5f 编写于 作者: M Megvii Engine Team

feat(mge/traced_module): support to modify the name of Node during graph surgery

GitOrigin-RevId: 9ecf6f2c5b700d4c91947def2cdc00cce4e0efc7
上级 3a219209
......@@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [
"dtype",
"grad",
"item",
"name",
"ndim",
"numpy",
"qparams",
......@@ -152,6 +151,11 @@ class module_tracer:
return self._active_scopes[-1]
return None
def top_scope(self):
if self._active_scopes:
return self._active_scopes[0]
return None
class NotExist:
pass
......
......@@ -180,6 +180,25 @@ def _tensor_to_node(tensors):
return nodes
def _name_setter(node: Node, new_name: str):
surgery_mode = _set_graph_surgery_mode(False)
graph = active_module_tracer().current_scope()
if node.top_graph is not None:
top_graph = active_module_tracer().top_scope()
if node is top_graph._namespace.used_names.get(node._name, None):
graph = top_graph
else:
graph = node.top_graph
assert (
graph._namespace.used_names.get(new_name, None) is None
), "The name(%s) is already in use. Please try a different one again." % (new_name)
graph._namespace.unassociate_name_with_obj(node)
node._name = graph._namespace.create_unique_name(new_name, node)
_set_graph_surgery_mode(surgery_mode)
def _wrap_method_to_tensor_node():
def _any_method(name, func):
def _any(*args, **kwargs):
......@@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node():
else:
patch.set_func(_any_method(method, patch.origin_fn))
tensor_method_patch.append(patch)
patch = PatchedFn(Node, "name")
patch.set_func(property(patch.origin_fn.fget, _name_setter))
tensor_method_patch.append(patch)
return tensor_method_patch
......
......@@ -377,6 +377,33 @@ def test_set_node_name():
rename("output")
np.testing.assert_equal(str(graph.outputs[0]), "output")
def add_1(x):
x = x + 1
x.name = "func_add_1"
return x
class ModuleAdd_3(M.Module):
def forward(self, x):
x = x + 1
x.name = "module_add_1"
x = x + 2
return x
setattr(traced_module, "add_3", ModuleAdd_3())
self = graph.inputs[0]
with graph.insert_exprs():
x = output_node + 1
x.name = "_add_1"
x = add_1(x)
x = self.add_3(x)
graph.replace_node({output_node: x})
graph.compile()
assert "_add_1" in graph._namespace.used_names
assert "func_add_1" in graph._namespace.used_names
assert "module_add_1" in traced_module.add_3.graph._namespace.used_names
def test_set_graph_name():
traced_module, x, expect = _init_module()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册