Commit e6c271ae authored by Megvii Engine Team

fix(mge/traced_module): fix some bugs for graph surgery

GitOrigin-RevId: 6328a84cbc8554c847530ce990331145ca82e043
Parent 2d54ad18
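The patch renames the module-level switch `_enable_node_to_tensor` to `_enable_graph_surgery_mode` and funnels every read and write through `_graph_surgery_mode()` / `_set_graph_surgery_mode()`, which return the previous value so callers can restore it around code that must not be traced. The snippet below is a minimal standalone sketch of that save-and-restore pattern; the names (`_MODE`, `set_mode`, `temporarily_disabled`) are illustrative and are not MegEngine API.

# Minimal sketch of the save/restore flag pattern used throughout the patch.
# Names here (_MODE, set_mode, temporarily_disabled) are illustrative only.
from contextlib import contextmanager

_MODE = False  # module-level switch, analogous to _enable_graph_surgery_mode


def get_mode() -> bool:
    return _MODE


def set_mode(mode: bool) -> bool:
    """Set the switch and return its previous value so callers can restore it."""
    global _MODE
    prev, _MODE = _MODE, mode
    return prev


@contextmanager
def temporarily_disabled():
    # mirrors the `flag = _set_graph_surgery_mode(False) ... _set_graph_surgery_mode(flag)`
    # sequences in the diff, expressed as a context manager
    prev = set_mode(False)
    try:
        yield
    finally:
        set_mode(prev)


if __name__ == "__main__":
    set_mode(True)
    with temporarily_disabled():
        assert get_mode() is False
    assert get_mode() is True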
@@ -122,18 +122,18 @@ def _is_leaf(node):
     return isinstance(node, RawTensor)


-_enable_node_to_tensor = False
+_enable_graph_surgery_mode = False


-def _convert_node_flag():
-    return _enable_node_to_tensor
+def _graph_surgery_mode():
+    return _enable_graph_surgery_mode


-def _set_convert_node_flag(flag: bool = False):
-    global _enable_node_to_tensor
-    pre_flag = _enable_node_to_tensor
-    _enable_node_to_tensor = flag
-    return pre_flag
+def _set_graph_surgery_mode(mode: bool):
+    global _enable_graph_surgery_mode
+    pre_mode = _enable_graph_surgery_mode
+    _enable_graph_surgery_mode = mode
+    return pre_mode


 def _node_to_tensor(*args, **kwargs):
@@ -145,11 +145,11 @@ def _node_to_tensor(*args, **kwargs):
                 active_module_tracer().current_scope()._add_input(n)
             value = n.value
             if value is None:
-                flag = _set_convert_node_flag(False)
+                flag = _set_graph_surgery_mode(False)
                 unset_module_tracing()
                 value = F.zeros(shape=n._shape, dtype=n._dtype)
                 set_module_tracing()
-                _set_convert_node_flag(flag)
+                _set_graph_surgery_mode(flag)
             orig_n = NodeMixin.get(value, None)
             if orig_n is None or "setitem" not in orig_n._name:
                 NodeMixin.wrap_safe(value, n)
@@ -180,8 +180,9 @@ def _tensor_to_node(tensors):
 def _wrap_method_to_tensor_node():
-    def _any_method(name):
+    def _any_method(name, func):
         def _any(*args, **kwargs):
+            if is_tracing_module() and _graph_surgery_mode():
                 args, kwargs = _node_to_tensor(*args, **kwargs)
                 attr = getattr(args[0], name)
                 outs = attr
@@ -192,6 +193,13 @@ def _wrap_method_to_tensor_node():
                     return None
                 outs = _tensor_to_node(outs)
                 return outs
+            else:
+                outs = func
+                if callable(func):
+                    outs = func(*args, **kwargs)
+                if isinstance(func, property):
+                    outs = func.__get__(*args, **kwargs)
+                return outs

         return _any

@@ -199,9 +207,9 @@ def _wrap_method_to_tensor_node():
     for method in get_tensor_wrapable_method():
         patch = PatchedFn(TensorNode, method)
         if type(getattr(Tensor, method)) == property:
-            patch.set_func(property(_any_method(method)))
+            patch.set_func(property(_any_method(method, patch.origin_fn)))
         else:
-            patch.set_func(_any_method(method))
+            patch.set_func(_any_method(method, patch.origin_fn))
         tensor_method_patch.append(patch)
     return tensor_method_patch
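`_any_method` now also receives the original attribute (`patch.origin_fn`), so that outside graph-surgery mode a patched `TensorNode` method can fall back to the real implementation: it calls the attribute directly if it is a plain callable, or goes through the descriptor protocol if it is a `property`. Below is a self-contained sketch of just that dispatch; the classes and names (`Box`, `make_fallback`) are made up for demonstration and stand in for the MegEngine patcher.

# Illustrative sketch of the callable-vs-property fallback added to _any_method.
# Only the dispatch logic mirrors the diff; Box and make_fallback are hypothetical.
class Box:
    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        return self._value

    def doubled(self):
        return self._value * 2


def make_fallback(func):
    def _any(*args, **kwargs):
        outs = func
        if callable(func):
            # ordinary method: call it with the original arguments
            outs = func(*args, **kwargs)
        if isinstance(func, property):
            # property object: invoke its getter via the descriptor protocol
            outs = func.__get__(*args, **kwargs)
        return outs

    return _any


if __name__ == "__main__":
    assert make_fallback(Box.doubled)(Box(3)) == 6
    assert make_fallback(Box.__dict__["value"])(Box(3)) == 3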
@@ -209,7 +217,7 @@ def _wrap_method_to_tensor_node():
 def _convert_node_and_tensor(orig_func):
     @functools.wraps(orig_func)
     def _convert(*args, **kwargs):
-        if _convert_node_flag() and is_tracing_module():
+        if is_tracing_module() and _graph_surgery_mode():
             args, kwargs = _node_to_tensor(*args, **kwargs)
             rst = orig_func(*args, **kwargs, method_func=_convert)
             rst = _tensor_to_node(rst)
@@ -224,6 +232,7 @@ def _convert_node_and_tensor(orig_func):
 def _wrap_mnode_getattr(orig_getattr):
     @functools.wraps(orig_getattr)
     def wraped_fn(self, name):
+        if is_tracing_module() and _graph_surgery_mode():
             obj = self.owner
             current_graph = active_module_tracer().current_scope()
             if self.top_graph is not None:
@@ -247,9 +256,12 @@ def _wrap_mnode_getattr(orig_getattr):
                 )
             if isinstance(attr, (NodeMixin, RawTensor)):
                 node = NodeMixin.get(attr)
-                if isinstance(node, ModuleNode):
+                if isinstance(node, ModuleNode) and isinstance(attr, (NodeMixin, Module)):
                     node._owner = weakref.ref(attr)
             return node
+        else:
+            node = object.__getattribute__(self, name)
+            return node

     return wraped_fn
@@ -257,10 +269,13 @@ def _wrap_mnode_getattr(orig_getattr):
 def _wrap_mnode_call(orig_call):
     @functools.wraps(orig_call)
     def wraped_fn(self, *args, **kwargs):
+        if is_tracing_module() and _graph_surgery_mode():
             obj = self.owner
             if self.top_graph is not None:
                 active_module_tracer().current_scope()._add_input(self)
             rst = obj(*args, **kwargs)
+        else:
+            raise TypeError("'ModuleNode' object is not callable")
         return rst

     return wraped_fn
@@ -284,7 +299,7 @@ class _InsertExprs:
         Node._set_next_id(node_id)
         Expr._set_next_id(expr_id)
         set_module_tracing()
-        _set_convert_node_flag(True)
+        _set_graph_surgery_mode(True)
         assert active_module_tracer() is None
         set_active_module_tracer(
             module_tracer(lambda x: _convert_node_and_tensor(_wrapped_function(x)))
@@ -303,20 +318,30 @@ class _InsertExprs:
         if va is not None:
             return False
         active_module_tracer().patcher.__exit__(ty, va, tr)
-        _set_convert_node_flag(False)
         while self._tensor_method_patch:
             pf = self._tensor_method_patch.pop()
             pf.set_func(pf.origin_fn)
+        # delete ModuleNode.__call__ to avoid entering the
+        # ModuleNode.__init__ method when calling a ModuleNode object.
+        delattr(ModuleNode, "__call__")
         module = self.graph.inputs[0].owner

+        def build_traced_module(
+            module: TracedModuleBuilder, target_module: TracedModule
+        ):
             for k, v in module.__dict__.items():
                 if isinstance(v, TracedModuleBuilder):
-                    v = v.build()
-                    setattr(module, k, v)
+                    traced_v = v.build()
+                    build_traced_module(v, traced_v)
+                    setattr(target_module, k, traced_v)
+
+        build_traced_module(module, module)
         set_symbolic_shape(self.use_sym_shape)
+        _set_graph_surgery_mode(False)
         set_active_module_tracer(None)
         unset_module_tracing()
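`__exit__` now materializes nested builders recursively: `build_traced_module` builds each `TracedModuleBuilder` child, descends into it, and only then attaches the resulting `TracedModule`, so sub-sub-modules created during graph surgery (as exercised by the new test below) are converted as well. The following is a generic sketch of that depth-first build pattern, using hypothetical `Builder`/`Built` stand-ins rather than the real classes.

# Hedged sketch of the depth-first "build nested builders" idea behind build_traced_module.
# Builder/Built are stand-ins; the real code walks TracedModuleBuilder.__dict__ and
# produces TracedModule instances via .build().
class Built:
    pass


class Builder:
    def __init__(self, name):
        self.name = name

    def build(self) -> Built:
        built = Built()
        built.name = self.name
        return built


def build_recursively(src, dst):
    # Replace every Builder attribute on `src` with its built counterpart on `dst`,
    # descending first so nested builders are materialized before being attached.
    for key, value in list(vars(src).items()):
        if isinstance(value, Builder):
            built_value = value.build()
            build_recursively(value, built_value)
            setattr(dst, key, built_value)


if __name__ == "__main__":
    root = Builder("root")
    root.child = Builder("child")
    root.child.grandchild = Builder("grandchild")
    build_recursively(root, root)  # mirrors build_traced_module(module, module) in the diff
    assert isinstance(root.child, Built)
    assert isinstance(root.child.grandchild, Built)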
@@ -435,7 +460,7 @@ class NameSpace:
     def unassociate_name_with_obj(self, node: Node):
         assert node.name in self.used_names
-        assert self.used_names[node.name] is node
+        # assert self.used_names[node.name] is node
         self._used_names[node.name] = None

     @property
@@ -1364,6 +1389,8 @@ class TracedModuleBuilder(NodeMixin):
             for node in self.nodes:
                 node.module_type = mod_type
+            return self._mod
+        elif isinstance(self._mod, TracedModule) and _graph_surgery_mode():
             return self._mod
         else:
             is_qat = isinstance(self._mod, QATModule) or (
@@ -1409,6 +1436,10 @@ class TracedModuleBuilder(NodeMixin):
     def __call__(self, *args, **kwargs):
         assert isinstance(self._mod, Module)
+        is_graph_surgery_mode = _graph_surgery_mode()
+        if isinstance(self._mod, TracedModule) and is_graph_surgery_mode:
+            _set_graph_surgery_mode(False)
+
         # prepare args and kwargs for inner graph
         if "method_func" in kwargs:
             kwargs.pop("method_func")
@@ -1514,7 +1545,7 @@ class TracedModuleBuilder(NodeMixin):
             )
             rst = type(self._mod).forward(*args, **kwargs)
-            if _convert_node_flag():
+            if _graph_surgery_mode():
                 rst = _node_to_tensor(rst)[0][0]
             outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
@@ -1536,6 +1567,7 @@ class TracedModuleBuilder(NodeMixin):
         callnode.add_outputs(outputs)
         self._argdef_graph_map[callnode.arg_def] = self._body
         self._argdef_outdef_map[callnode.arg_def] = out_def
+        _set_graph_surgery_mode(is_graph_surgery_mode)
         return rst

     def __setattr__(self, name, value):
@@ -1556,7 +1588,7 @@ class TracedModuleBuilder(NodeMixin):
             return active_module_tracer().patcher.wrap_fn(attr)
         if isinstance(attr, (List, Dict)):
-            flag = _set_convert_node_flag(False)
+            flag = _set_graph_surgery_mode(False)
             unset_module_tracing()
             has_module, m_container = replace_container_with_module_container(attr)
             if m_container:
@@ -1567,7 +1599,7 @@ class TracedModuleBuilder(NodeMixin):
                     " Module and Non-Module objects."
                 )
             set_module_tracing()
-            _set_convert_node_flag(flag)
+            _set_graph_surgery_mode(flag)
         if isinstance(attr, Module):
             attr = TracedModuleBuilder(attr)
@@ -1628,20 +1660,25 @@ class _expr_iter:
         self._visited_graph = set()

     def __iter__(self):
-        for inp_node in self.graph.inputs:
+        yield from self._gen_expr(self.graph)
+
+    def _gen_expr(self, graph: InternalGraph):
+        visit_inp = set()
+        for inp_node in graph.inputs:
+            if inp_node not in visit_inp:
                 yield inp_node.expr
-        for expr in self.graph._exprs:
-            if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
+                visit_inp.add(inp_node)
+        for expr in graph._exprs:
             yield expr
             if (
                 self.recursive
-                and hasattr(expr, "graph")
                 and expr.graph is not None
                 and id(expr.graph) not in self._visited_graph
             ):
-                yield from expr.graph.exprs(self.recursive)
                 self._visited_graph.add(id(expr.graph))
-            else:
-                yield expr
+                yield from self._gen_expr(expr.graph)


 class _node_iter:
......
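`_expr_iter` is rewritten around a recursive generator: `_gen_expr` yields each graph input's expr once (guarded by `visit_inp`), then every expr of the graph, and recurses into sub-graphs with `yield from`, deduplicating them by `id` in `self._visited_graph`. The toy sketch below reproduces only the expr-traversal part with stand-in `Graph`/`Expr` classes (input-node handling omitted); it is not the megengine.traced_module API.

# Toy sketch of the recursive yield-from traversal used by _expr_iter._gen_expr.
# Graph/Expr here are minimal stand-ins, not megengine.traced_module types.
class Expr:
    def __init__(self, name, graph=None):
        self.name = name
        self.graph = graph  # sub-graph for call-style exprs, else None


class Graph:
    def __init__(self, exprs):
        self._exprs = exprs


def iter_exprs(graph, recursive=True, _visited=None):
    visited = set() if _visited is None else _visited
    for expr in graph._exprs:
        yield expr
        if recursive and expr.graph is not None and id(expr.graph) not in visited:
            visited.add(id(expr.graph))  # mark before descending, as in the diff
            yield from iter_exprs(expr.graph, recursive, visited)


if __name__ == "__main__":
    inner = Graph([Expr("inner_add")])
    outer = Graph([Expr("call_inner", graph=inner), Expr("outer_relu")])
    assert [e.name for e in iter_exprs(outer)] == ["call_inner", "inner_add", "outer_relu"]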
@@ -15,7 +15,7 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.module.qat as qat
 from megengine.module.identity import Identity
-from megengine.traced_module import trace_module
+from megengine.traced_module import TracedModule, trace_module
 from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
 from megengine.traced_module.node import ModuleNode, Node, TensorNode
@@ -182,7 +182,6 @@ def test_insert_module():
     setattr(traced_module, "neg", Neg(name="neg"))
     setattr(traced_module, "neg2", Neg(name="neg"))
     setattr(traced_module, "param", F.zeros((1,)))
-
     with graph.insert_exprs():
         neg_out = self.neg(relu_out)
         neg_out = self.neg2(relu_out)
@@ -199,6 +198,32 @@ def test_insert_module():
         if isinstance(n, TensorNode):
             assert n.value is None

+    traced_module, x, expect = _init_module()
+    setattr(traced_module.block0, "neg", Neg(name=None))
+    graph = traced_module.graph
+    self = graph.inputs[0]
+    out_node = graph.outputs[0]
+    with graph.insert_exprs():
+        neg_out = self.block0.neg(out_node)
+    graph.replace_node({out_node: neg_out})
+    graph.compile()
+    np.testing.assert_allclose(expect, -traced_module(x), atol=1e-6)
+    assert isinstance(traced_module.block0.neg, TracedModule)
+    assert traced_module.block0.neg.graph is not None
+
+    setattr(traced_module.block0.neg, "neg", Neg(name=None))
+    setattr(traced_module.block0.neg.neg, "relu", M.ReLU())
+    out_node = graph.outputs[0]
+    with graph.insert_exprs():
+        neg_out = self.block0.neg.neg(out_node)
+        neg_out = self.block0.neg.neg(neg_out)
+        relu_out = self.block0.neg.neg.relu(neg_out)
+    graph.replace_node({out_node: relu_out})
+    graph.compile()
+    np.testing.assert_allclose(F.relu(-expect), traced_module(x), atol=1e-6)
+    assert isinstance(traced_module.block0.neg.neg, TracedModule)
+    assert traced_module.block0.neg.neg.graph is not None


 def test_insert_qat_module():
     class concat(qat.Concat):
......