提交 e6c271ae 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix some bugs for graph surgery

GitOrigin-RevId: 6328a84cbc8554c847530ce990331145ca82e043
上级 2d54ad18
......@@ -122,18 +122,18 @@ def _is_leaf(node):
return isinstance(node, RawTensor)
_enable_node_to_tensor = False
_enable_graph_surgery_mode = False
def _convert_node_flag():
    """Return the current value of the module-level ``_enable_node_to_tensor`` flag."""
    return _enable_node_to_tensor
def _graph_surgery_mode():
    """Return the current value of the module-level ``_enable_graph_surgery_mode`` flag."""
    return _enable_graph_surgery_mode
def _set_convert_node_flag(flag: bool = False):
    """Overwrite the module-level ``_enable_node_to_tensor`` flag.

    Returns the value the flag held before the call, so a caller can
    temporarily change it and later pass the returned value back in to
    restore the earlier state.
    """
    global _enable_node_to_tensor
    # Right-hand side is evaluated first, so ``previous`` captures the old value.
    previous, _enable_node_to_tensor = _enable_node_to_tensor, flag
    return previous
def _set_graph_surgery_mode(mode: bool):
    """Overwrite the module-level ``_enable_graph_surgery_mode`` flag.

    Returns the value the flag held before the call, so a caller can
    temporarily change it and later pass the returned value back in to
    restore the earlier state.
    """
    global _enable_graph_surgery_mode
    # Right-hand side is evaluated first, so ``previous`` captures the old value.
    previous, _enable_graph_surgery_mode = _enable_graph_surgery_mode, mode
    return previous
def _node_to_tensor(*args, **kwargs):
......@@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs):
active_module_tracer().current_scope()._add_input(n)
value = n.value
if value is None:
flag = _set_convert_node_flag(False)
flag = _set_graph_surgery_mode(False)
unset_module_tracing()
value = F.zeros(shape=n._shape, dtype=n._dtype)
set_module_tracing()
_set_convert_node_flag(flag)
_set_graph_surgery_mode(flag)
orig_n = NodeMixin.get(value, None)
if orig_n is None or "setitem" not in orig_n._name:
NodeMixin.wrap_safe(value, n)
......@@ -180,17 +180,25 @@ def _tensor_to_node(tensors):
def _wrap_method_to_tensor_node():
def _any_method(name):
def _any_method(name, func):
def _any(*args, **kwargs):
args, kwargs = _node_to_tensor(*args, **kwargs)
attr = getattr(args[0], name)
outs = attr
if callable(attr):
outs = attr(*(args[1:]), **kwargs)
if name == "__setitem__":
_node_to_tensor(outs)
return None
outs = _tensor_to_node(outs)
if is_tracing_module() and _graph_surgery_mode():
args, kwargs = _node_to_tensor(*args, **kwargs)
attr = getattr(args[0], name)
outs = attr
if callable(attr):
outs = attr(*(args[1:]), **kwargs)
if name == "__setitem__":
_node_to_tensor(outs)
return None
outs = _tensor_to_node(outs)
return outs
else:
outs = func
if callable(func):
outs = func(*args, **kwargs)
if isinstance(func, property):
outs = func.__get__(*args, **kwargs)
return outs
return _any
......@@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node():
for method in get_tensor_wrapable_method():
patch = PatchedFn(TensorNode, method)
if type(getattr(Tensor, method)) == property:
patch.set_func(property(_any_method(method)))
patch.set_func(property(_any_method(method, patch.origin_fn)))
else:
patch.set_func(_any_method(method))
patch.set_func(_any_method(method, patch.origin_fn))
tensor_method_patch.append(patch)
return tensor_method_patch
......@@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node():
def _convert_node_and_tensor(orig_func):
@functools.wraps(orig_func)
def _convert(*args, **kwargs):
if _convert_node_flag() and is_tracing_module():
if is_tracing_module() and _graph_surgery_mode():
args, kwargs = _node_to_tensor(*args, **kwargs)
rst = orig_func(*args, **kwargs, method_func=_convert)
rst = _tensor_to_node(rst)
......@@ -224,31 +232,35 @@ def _convert_node_and_tensor(orig_func):
def _wrap_mnode_getattr(orig_getattr):
@functools.wraps(orig_getattr)
def wraped_fn(self, name):
obj = self.owner
current_graph = active_module_tracer().current_scope()
if self.top_graph is not None:
current_graph._add_input(self)
attr = getattr(obj, name)
node = attr
if not isinstance(attr, TracedModuleBuilder):
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(obj, name, attr)
if is_tracing_module() and _graph_surgery_mode():
obj = self.owner
current_graph = active_module_tracer().current_scope()
if self.top_graph is not None:
current_graph._add_input(self)
attr = getattr(obj, name)
node = attr
if not isinstance(attr, TracedModuleBuilder):
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
setattr(obj, name, attr)
if isinstance(attr, (NodeMixin, RawTensor)):
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
self,
type=NodeMixin.get_wrapped_type(attr),
attr_name=name,
name="",
),
)
if isinstance(attr, (NodeMixin, RawTensor)):
NodeMixin.wrap(
attr,
lambda: GetAttr.make(
self,
type=NodeMixin.get_wrapped_type(attr),
attr_name=name,
name="",
),
)
if isinstance(attr, (NodeMixin, RawTensor)):
node = NodeMixin.get(attr)
if isinstance(node, ModuleNode):
node._owner = weakref.ref(attr)
node = NodeMixin.get(attr)
if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)):
node._owner = weakref.ref(attr)
return node
else:
node = object.__getattribute__(self, name)
return node
return wraped_fn
......@@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr):
def _wrap_mnode_call(orig_call):
@functools.wraps(orig_call)
def wraped_fn(self, *args, **kwargs):
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
rst = obj(*args, **kwargs)
if is_tracing_module() and _graph_surgery_mode():
obj = self.owner
if self.top_graph is not None:
active_module_tracer().current_scope()._add_input(self)
rst = obj(*args, **kwargs)
else:
raise TypeError("'ModuleNode' object is not callable")
return rst
return wraped_fn
......@@ -284,7 +299,7 @@ class _InsertExprs:
Node._set_next_id(node_id)
Expr._set_next_id(expr_id)
set_module_tracing()
_set_convert_node_flag(True)
_set_graph_surgery_mode(True)
assert active_module_tracer() is None
set_active_module_tracer(
module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
......@@ -303,20 +318,30 @@ class _InsertExprs:
if va is not None:
return False
active_module_tracer().patcher.__exit__(ty, va, tr)
_set_convert_node_flag(False)
while self._tensor_method_patch:
pf = self._tensor_method_patch.pop()
pf.set_func(pf.origin_fn)
# delete ModuleNode.__call__ to avoid entering the
# ModuleNode.__init__ method when calling a ModuleNode object.
delattr(ModuleNode, "__call__")
module = self.graph.inputs[0].owner
for k, v in module.__dict__.items():
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(module, k, v)
def build_traced_module(
    module: TracedModuleBuilder, target_module: TracedModule
):
    """Build every ``TracedModuleBuilder`` attribute of ``module`` onto ``target_module``.

    Each attribute in ``module.__dict__`` that is a ``TracedModuleBuilder``
    is built, the same replacement is applied recursively with the builder
    as source and its build result as target, and the result is stored on
    ``target_module`` under the same attribute name. Other attributes are
    left untouched.
    """
    for attr_name, child in module.__dict__.items():
        if not isinstance(child, TracedModuleBuilder):
            continue
        built_child = child.build()
        # Handle the builder's own nested builders before attaching the result.
        build_traced_module(child, built_child)
        setattr(target_module, attr_name, built_child)
build_traced_module(module, module)
set_symbolic_shape(self.use_sym_shape)
_set_graph_surgery_mode(False)
set_active_module_tracer(None)
unset_module_tracing()
......@@ -435,7 +460,7 @@ class NameSpace:
def unassociate_name_with_obj(self, node: Node):
assert node.name in self.used_names
assert self.used_names[node.name] is node
# assert self.used_names[node.name] is node
self._used_names[node.name] = None
@property
......@@ -1364,6 +1389,8 @@ class TracedModuleBuilder(NodeMixin):
for node in self.nodes:
node.module_type = mod_type
return self._mod
elif isinstance(self._mod, TracedModule) and _graph_surgery_mode():
return self._mod
else:
is_qat = isinstance(self._mod, QATModule) or (
......@@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
is_graph_surgery_mode = _graph_surgery_mode()
if isinstance(self._mod, TracedModule) and is_graph_surgery_mode:
_set_graph_surgery_mode(False)
# prepare args and kwargs for inner graph
if "method_func" in kwargs:
kwargs.pop("method_func")
......@@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin):
)
rst = type(self._mod).forward(*args, **kwargs)
if _convert_node_flag():
if _graph_surgery_mode():
rst = _node_to_tensor(rst)[0][0]
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
......@@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin):
callnode.add_outputs(outputs)
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
_set_graph_surgery_mode(is_graph_surgery_mode)
return rst
def __setattr__(self, name, value):
......@@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin):
return active_module_tracer().patcher.wrap_fn(attr)
if isinstance(attr, (List, Dict)):
flag = _set_convert_node_flag(False)
flag = _set_graph_surgery_mode(False)
unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr)
if m_container:
......@@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin):
" Module and Non-Module objects."
)
set_module_tracing()
_set_convert_node_flag(flag)
_set_graph_surgery_mode(flag)
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
......@@ -1628,20 +1660,25 @@ class _expr_iter:
self._visited_graph = set()
def __iter__(self):
for inp_node in self.graph.inputs:
yield inp_node.expr
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
if (
self.recursive
and expr.graph is not None
and id(expr.graph) not in self._visited_graph
):
yield from expr.graph.exprs(self.recursive)
self._visited_graph.add(id(expr.graph))
else:
yield expr
yield from self._gen_expr(self.graph)
def _gen_expr(self, graph: InternalGraph):
    """Yield the expressions of ``graph``, inputs first and then the body.

    The ``expr`` of each distinct node in ``graph.inputs`` is yielded once.
    Every entry of ``graph._exprs`` is then yielded in order. When
    ``self.recursive`` is truthy and an expression carries a non-None
    ``graph`` attribute whose ``id`` is not yet in ``self._visited_graph``,
    that id is recorded and the nested graph is walked immediately after
    the expression itself.
    """
    seen_inputs = set()
    for node in graph.inputs:
        if node in seen_inputs:
            continue
        yield node.expr
        seen_inputs.add(node)

    for expr in graph._exprs:
        yield expr
        if not self.recursive:
            continue
        # Not every expression type has a ``graph`` attribute.
        sub_graph = getattr(expr, "graph", None)
        if sub_graph is None or id(sub_graph) in self._visited_graph:
            continue
        # Record before descending so the same graph is never expanded twice.
        self._visited_graph.add(id(sub_graph))
        yield from self._gen_expr(sub_graph)
class _node_iter:
......
......@@ -15,7 +15,7 @@ import megengine.functional as F
import megengine.module as M
import megengine.module.qat as qat
from megengine.module.identity import Identity
from megengine.traced_module import trace_module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node, TensorNode
......@@ -182,7 +182,6 @@ def test_insert_module():
setattr(traced_module, "neg", Neg(name="neg"))
setattr(traced_module, "neg2", Neg(name="neg"))
setattr(traced_module, "param", F.zeros((1,)))
with graph.insert_exprs():
neg_out = self.neg(relu_out)
neg_out = self.neg2(relu_out)
......@@ -199,6 +198,32 @@ def test_insert_module():
if isinstance(n, TensorNode):
assert n.value is None
traced_module, x, expect = _init_module()
setattr(traced_module.block0, "neg", Neg(name=None))
graph = traced_module.graph
self = graph.inputs[0]
out_node = graph.outputs[0]
with graph.insert_exprs():
neg_out = self.block0.neg(out_node)
graph.replace_node({out_node: neg_out})
graph.compile()
np.testing.assert_allclose(expect, -traced_module(x), atol=1e-6)
assert isinstance(traced_module.block0.neg, TracedModule)
assert traced_module.block0.neg.graph is not None
setattr(traced_module.block0.neg, "neg", Neg(name=None))
setattr(traced_module.block0.neg.neg, "relu", M.ReLU())
out_node = graph.outputs[0]
with graph.insert_exprs():
neg_out = self.block0.neg.neg(out_node)
neg_out = self.block0.neg.neg(neg_out)
relu_out = self.block0.neg.neg.relu(neg_out)
graph.replace_node({out_node: relu_out})
graph.compile()
np.testing.assert_allclose(F.relu(-expect), traced_module(x), atol=1e-6)
assert isinstance(traced_module.block0.neg.neg, TracedModule)
assert traced_module.block0.neg.neg.graph is not None
def test_insert_qat_module():
class concat(qat.Concat):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册