Commit b1c46ba4 authored by: Megvii Engine Team

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: ac0603057adaedf864f2d0ceb7bfb6d3c5a50640
Parent: 4bb25369
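In brief, this commit adds graph-modification utilities to traced modules: user tracking on Node/Expr, ExprFilter-based queries (get_call_function / get_call_method), insert_call_function / insert_call_method, replace_node, compile() for dead-expr elimination, and out_def bookkeeping so traced modules preserve their output structure. A minimal end-to-end sketch of the new API (mirroring test_insert in the tests added below; Simple is a hypothetical module):

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module

class Simple(M.Module):
    def forward(self, x):
        return F.relu(x) + 1

traced = trace_module(Simple(), F.ones((2, 3)))
graph = traced.graph

relu_out = graph.get_call_function(F.relu).as_unique().outputs  # nodes produced by F.relu
neg_out = graph.insert_call_function(F.neg, relu_out)           # splice F.neg in after it
graph.replace_node({relu_out[0]: neg_out[0]})                   # rewire all downstream users
graph.compile()                                                 # drop now-unused exprs
# traced(x) now computes F.neg(F.relu(x)) + 1, i.e. 1 - F.relu(x)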
@@ -9,6 +9,7 @@
 import builtins
 import collections
+import inspect
 from typing import Callable, List

 from ...core._imperative_rt import OpDef
@@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
 from ...core.ops.special import Const
 from ...module import Module
-from ...tensor import Tensor
+from ...tensor import Parameter, Tensor
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import TreeDef
+from .pytree import TreeDef, tree_flatten


 class Expr:
@@ -38,25 +39,28 @@ class Expr:
         for val in vals:
             node = NodeMixin.get(val, None)
             if isinstance(node, (TensorNode, ModuleNode)):
-                if node not in self.inputs:
-                    self.inputs.append(node)
+                self.inputs.append(node)
+                node.users.append(self)
             else:
                 assert node is None
+                assert type(val) in builtins.__dict__.values()
                 idx = len(self.inputs) + len(self.const_val)
                 self.const_val.append((idx, val))

-    def add_outputs(self, outputs):
+    def add_outputs(self, outputs, check_inplace=True):
         self.outputs = []
-        if not isinstance(outputs, collections.Sequence):
-            outputs = (outputs,)
-
-        for i in outputs:
-            assert isinstance(i, RawTensor)
-            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
-
-        for i, node in zip(outputs, self.outputs):
-            NodeMixin.wrap_safe(i, node)
+        if outputs is not None:
+            if not isinstance(outputs, collections.Sequence):
+                outputs = (outputs,)
+
+            for i in outputs:
+                assert isinstance(i, RawTensor)
+                node = NodeMixin.get(i, None) if check_inplace else None
+                self.outputs.append(
+                    node if node else NodeMixin.get_wrapped_type(i)(self)
+                )
+
+            for i, node in zip(outputs, self.outputs):
+                NodeMixin.wrap_safe(i, node)

     def unflatten_args(self, inputs):
         if self.arg_def is not None:
@@ -110,6 +114,7 @@ class GetAttr(Expr):
         self.inputs = [
             module,
         ]
+        module.users.append(self)
         self.name = name
         node_cls = type if type else Node
         self.outputs = [
@@ -134,12 +139,20 @@ class GetAttr(Expr):

 # expr: outputs = inputs[0].__call__(*inputs[1:])
 class CallMethod(Expr):
-    def __init__(self, module, method="__call__"):
-        assert isinstance(module, (TensorNode, ModuleNode))
-        self.inputs = [
-            module,
-        ]
-        self.const_val = []
+    def __init__(self, node, method="__call__"):
+        if isinstance(node, type):
+            assert issubclass(node, Tensor)
+            cls = Parameter if issubclass(node, Parameter) else Tensor
+            self.inputs = []
+            self.const_val = [(0, cls)]
+        else:
+            assert isinstance(node, (TensorNode, ModuleNode))
+            node.users.append(self)
+            self.inputs = [
+                node,
+            ]
+            self.const_val = []
         self.method = method

     @classmethod
@@ -160,10 +173,13 @@ class CallMethod(Expr):
     def interpret(self, *inputs):
         args, kwargs = self.unflatten_args(inputs)
         obj = args[0]
-        args = args[1:]
+        meth = getattr(obj, self.method)
+        if inspect.ismethod(meth):
+            args = args[1:]
         outputs = getattr(obj, self.method)(*args, **kwargs)
-        if isinstance(outputs, RawTensor):
-            outputs = (outputs,)
+        if outputs is None:
+            return outputs
+        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
         return outputs

     def __repr__(self):
@@ -171,7 +187,7 @@ class CallMethod(Expr):
         kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
         return "{} = {}.{}({})".format(
             ", ".join(str(i) for i in self.outputs),
-            self.inputs[0],
+            self.args[0],
             self.method,
             ", ".join([args, kwargs]),
         )
@@ -209,9 +225,8 @@ class Apply(Expr):
             if node is None:  # capture as constant
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
-        for i in inputs:
-            assert isinstance(i, RawTensor)
-            apply_node.inputs.append(NodeMixin.get(i))
+        apply_node.add_inputs(inputs)
+        assert not apply_node.const_val

         unset_module_tracing()
         outputs = apply(opdef, *inputs)
@@ -283,7 +298,7 @@ class Constant(Expr):
         return (self.value,)

     def __repr__(self):
-        return "{} = Constant({})".format(self.outputs[0], self.value)
+        return "{} = Constant({})".format(self.outputs[0], type(self.value))

     def __getstate__(self):
         state = self.__dict__.copy()
...
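The `users` back-references introduced here (every Node records the Exprs that consume it) are what make `replace_node` and `compile` cheap: rewiring a node becomes a local list update rather than a whole-graph scan. A self-contained toy sketch of the same bookkeeping pattern (illustrative only, not MegEngine code):

class ToyNode:
    def __init__(self, producer=None):
        self.expr = producer  # the Expr that produced this node
        self.users = []       # the Exprs that consume this node

class ToyExpr:
    def __init__(self, inputs):
        self.inputs = list(inputs)
        for n in self.inputs:
            n.users.append(self)       # register as a user, as add_inputs now does
        self.outputs = [ToyNode(self)]

def replace(node, repl):
    # rewire every consumer of `node` to read from `repl` instead
    for user in list(node.users):
        user.inputs[user.inputs.index(node)] = repl
        node.users.remove(user)
        repl.users.append(user)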
@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
     "min",
     "max",
     "mean",
+    "__getitem__",
+    "__setitem__",
 ]
@@ -176,7 +178,8 @@ class Patcher:
             self.patch_module(module)
         for meth in BUILTIN_ARRAY_METHOD:
             self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
         self.patch_method(Tensor, "detach", self.wrap_fn)
+        self.patch_method(Tensor, "__new__", self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
                 self.patch_function(i, j, self.wrap_fn)
@@ -203,7 +206,13 @@ class Patcher:
         import inspect

         if id(module.__dict__) not in self.visited_frames_ids:
-            for k, v in module.__dict__.items():
+            keys = (
+                getattr(module, "__all__")
+                if hasattr(module, "__all__")
+                else module.__dict__.keys()
+            )
+            for k in keys:
+                v = getattr(module, k)
                 if inspect.isfunction(v) and not k.startswith("_"):
                     self.patch_function(module.__dict__, k, self.wrap_fn)
             self.visited_frames_ids.add(id(module.__dict__))
...
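Adding `__getitem__`/`__setitem__` to BUILTIN_ARRAY_METHOD (plus patching `Tensor.__new__`) means tensor indexing, in-place slice assignment, and `Tensor(...)` construction inside `forward` are now recorded as exprs. A minimal sketch of what this enables (FillDiag is a hypothetical module mirroring the PreProcess tests below):

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module

class FillDiag(M.Module):
    def __init__(self):
        super().__init__()
        self.M = F.zeros((1,))

    def forward(self, x):
        m = F.broadcast_to(self.M, (x.shape[0], 3, 3))
        m[:, 0, 0] = x[:, 0]  # traced via the newly patched __setitem__/__getitem__
        return m

tm = trace_module(FillDiag(), F.ones((2, 4)))  # the slice assignment appears in tm.graph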
@@ -6,7 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type

 import numpy

@@ -31,6 +31,7 @@ class Node:
     def __init__(self, expr: "Expr", name: str = None):
         self.expr = expr
+        self.users = []  # List[Expr]
         self._id = Node.__total_id
         Node.__total_id += 1
         self._name = name
@@ -59,11 +60,13 @@ class ModuleNode(Node):
     module_type = Module  # type: Type[Module]
     attr_type_map = None  # type: Dict[str, Type[Any]]
     argdef_graph_map = None  # type: Dict[TreeDef, "InternalGraph"]
+    argdef_outdef_map = None  # type: Dict[TreeDef, TreeDef]

     def __init__(self, expr: "Expr", name: str = None):
         super().__init__(expr, name)
         self.attr_type_map = {}
         self.argdef_graph_map = {}
+        self.argdef_outdef_map = {}

     def __repr__(self):
         if self._name is None:
...
@@ -10,6 +10,8 @@
 import collections
 from typing import Callable, NamedTuple

+import numpy as np
+
 SUPPORTED_TYPE = {}

 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):

 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
+register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
 register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
@@ -52,7 +54,10 @@ def tree_flatten(
         assert is_leaf(values)
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
-            node.const_val = values
+            if isinstance(values, np.ndarray):
+                node.const_val = str(values)
+            else:
+                node.const_val = values
         return [values,], node
     rst = []
...
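The one-character tuple fix above matters more than it looks: `unflatten` used to rebuild tuples as lists, so any value round-tripped through flatten/unflatten changed container type, and this commit starts round-tripping every traced module's outputs through `out_def.unflatten`. A self-contained illustration using the handlers registered above:

# the tuple handlers registered above
flatten = lambda x: (x, None)
unflatten = lambda x, aux_data: tuple(x)  # was list(x) before this commit

leaves, aux = flatten((1, 2))
assert unflatten(leaves, aux) == (1, 2)  # a rebuilt list would fail this comparison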
@@ -10,8 +10,13 @@ import collections
 import copy
 import functools
 from inspect import getmembers, isclass, ismethod
-from typing import Dict, List, Type
+from typing import Callable, Dict, Iterable, List, Sequence, Type

+import numpy as np
+from numpy.lib.arraysetops import isin
+
+from ... import functional as F
+from ... import get_logger
 from ... import module as M
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import (
@@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import (
     set_module_tracing,
     unset_module_tracing,
 )
+from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
 from ...tensor import Tensor
@@ -32,6 +38,8 @@ from .module_tracer import (
 from .node import ModuleNode, Node, NodeMixin, TensorNode
 from .pytree import tree_flatten

+logger = get_logger(__name__)
+

 def _leaf_type(node):
     if isinstance(node, RawTensor):
@@ -42,6 +50,11 @@ def _leaf_type(node):
     return type(node)


+def _is_leaf(node):
+    assert isinstance(node, RawTensor), type(node)
+    return isinstance(node, RawTensor)
+
+
 def _is_const_leaf(node):
     if isinstance(node, (RawTensor, NodeMixin, Module)):
         return False
@@ -80,7 +93,13 @@ class InternalGraph:
     @property
     def exprs(self):
-        return _expr_list(self)
+        return ExprFilter(_expr_iter(self))
+
+    def get_call_function(self, func: Callable = None):
+        return self.exprs.call_function(func)
+
+    def get_call_method(self, method: str = None):
+        return self.exprs.call_method(method)

     def add_input(self, i):
         self._inputs.append(i)
@@ -88,16 +107,131 @@ class InternalGraph:
     def add_output(self, o):
         self._outputs.append(o)

+    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
+        if not isinstance(nodes, Sequence):
+            nodes = (nodes,)
+        ret = list()
+        queue = list(nodes)
+        while queue:
+            node = queue.pop()
+            expr = node.expr
+            if expr not in ret:
+                ret.append(expr)
+            for i in expr.inputs:
+                if i not in queue:
+                    queue.append(i)
+        return ret
+
+    def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        assert isinstance(func, Callable)
+        for i in nodes:
+            assert isinstance(
+                i, TensorNode
+            ), "CallFunction only accepts TensorNode as inputs"
+        expr = CallFunction(func)
+        expr.inputs = nodes
+        for i in nodes:
+            i.users.append(expr)
+        idx = max(self._exprs.index(i.expr) for i in nodes) + 1
+        self._exprs.insert(idx, expr)
+        fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
+        fake_out_val = func(*fake_inp_val)
+
+        def create_node(val: Tensor):
+            node = TensorNode(expr)
+            node.shape = val.shape
+            node.dtype = val.dtype
+            return node
+
+        out_nodes = list(create_node(i) for i in fake_out_val)
+        expr.outputs = out_nodes
+        return out_nodes
+
+    def insert_call_method(self, target, method, args):
+        if not isinstance(args, Sequence):
+            args = [args]
+        assert isinstance(target, (TensorNode, ModuleNode))
+        assert isinstance(method, str)
+        for i in args:
+            assert isinstance(i, TensorNode)
+        expr = CallMethod(target, method)
+        expr.inputs = [target, *args]
+        for i in args:
+            i.users.append(expr)

+        if isinstance(target, TensorNode):
+            fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
+            fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
+            fake_out_val = getattr(fake_target_val, method)(*fake_inp_val)
+
+            def create_node(val: Tensor):
+                node = TensorNode(expr)
+                node.shape = val.shape
+                node.dtype = val.dtype
+                return node
+
+            out_nodes = list(create_node(i) for i in fake_out_val)
+            expr.outputs = out_nodes
+        else:
+            raise NotImplementedError()
+
+        return out_nodes
+
+    def replace_node(self, repl_dict: Dict[Node, Node]):
+        while repl_dict:
+            node, repl_node = repl_dict.popitem()
+            # check graph inputs and outputs
+            assert node not in self.inputs, "Cannot replace inputs"
+            for i, n in enumerate(self.outputs):
+                if n is node:
+                    self.outputs[i] = repl_node
+            # update users of node and repl_node
+            # update inputs of exprs in node.users
+            dep_exprs = self.get_dep_exprs(repl_node)
+            i = 0
+            while i < len(node.users):
+                n = node.users[i]
+                if n in dep_exprs:
+                    logger.info("Found a loop: ignoring this replacement once")
+                    logger.info("node: %s" % node.__repr__())
+                    logger.info("repl_node: %s" % repl_node.__repr__())
+                    i += 1
+                    continue
+                repl_node.users.append(n)
+                node.users.pop(i)
+                idx = n.inputs.index(node)
+                n.inputs[idx] = repl_node
+
+    def compile(self):
+        """
+        Delete unused exprs.
+        """
+        dep_exprs = self.get_dep_exprs(self.outputs)
+        i = 0
+        while i < len(self._exprs):
+            expr = self._exprs[i]
+            if expr in dep_exprs:
+                i += 1
+                continue
+            for n in expr.inputs:
+                n.users.remove(expr)
+            self._exprs.remove(expr)
+
     def interpret(self, *inputs):
+        # TODO: support kwargs?
+        # TODO: skip expressions which are independent and have no side effect
         node2value = {}
         for n, v in zip(self._inputs, inputs):
             node2value[n] = v
         for expr in self._exprs:
             values = expr.interpret(*list(node2value[i] for i in expr.inputs))
-            for n, v in zip(expr.outputs, values):
-                node2value[n] = v
+            if values is not None:
+                for n, v in zip(expr.outputs, values):
+                    node2value[n] = v
         return list(node2value[i] for i in self._outputs)

     def __repr__(self):
@@ -109,7 +243,8 @@ class InternalGraph:

 def _get_meth_name(obj, func):
-    for cls in type(obj).mro():
+    tp = obj if isinstance(obj, type) else type(obj)
+    for cls in tp.mro():
         for k, v in cls.__dict__.items():
             if v == func:
                 return k
@@ -131,15 +266,31 @@ def _wrapped_function(orig_func):
             meth_name = _get_meth_name(args[0], wrapped_fn)
             if meth_name:
                 self = inputs[0]
-                call_node = CallMethod.make(NodeMixin.get(self), meth_name)
+                if meth_name == "__new__":
+                    if all([not isinstance(i, RawTensor) for i in inputs]):
+                        # only trace Tensor.__new__() when there are tensors in args
+                        set_module_tracing()
+                        return orig_func(*args, **kwargs)
+                    if isinstance(args[1], RawTensor):
+                        node = NodeMixin.get(inputs[1])
+                        inputs[1] = copy.copy(inputs[1])
+                        # copy inputs[1] so that `tensor` and `Tensor(tensor)` do not share the
+                        # same m_tensor, which would give them the same _NodeMixin__node in tracing
+                        NodeMixin.wrap_safe(inputs[1], node)
+                        args, kwargs = tree_def.unflatten(inputs)
+                    call_node = CallMethod.make(self, meth_name)
+                else:
+                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
+                call_node.add_inputs(inputs[1:])
             else:
                 call_node = CallFunction.make(orig_func)
-            call_node.add_inputs(inputs)
+                call_node.add_inputs(inputs)
             call_node.arg_def = tree_def
             outputs = orig_func(*args, **kwargs)
-            call_node.add_outputs(outputs)
+            if meth_name == "__new__":
+                call_node.add_outputs(outputs, False)
+            else:
+                call_node.add_outputs(outputs)
             set_module_tracing()
             return outputs
         return orig_func(*args, **kwargs)
@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
                 mark_constant(i)
             callnode = CallMethod.make(NodeMixin.get(self))

-            callnode.add_inputs(inputs)
+            callnode.add_inputs(inputs[1:])

             callnode.arg_def = tree_def

             if self._is_builtin:
                 unset_module_tracing()
-                outputs = self._mod(*args, **kwargs)
+                rst = self._mod(*args, **kwargs)
+                outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
                 set_module_tracing()
             if self._is_builtin:
                 self._body = None
@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
                 NodeMixin.wrap_safe(
                     self, Input.make("self", NodeMixin.get_wrapped_type(self))
                 )
+                origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
                 # prepare args and kwargs for inner graph
                 def wrap(x):
-                    wrapped = copy.copy(x)
+                    # FIXME
                     NodeMixin.wrap(
-                        wrapped,
-                        lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
+                        x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
                     )
-                    return wrapped
+                    return x

                 args = [self]
                 for i in inputs[1:]:
@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
                 active_module_tracer().patcher.auto_patch(
                     getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
                 )
-                outputs = type(self._mod).forward(*args, **kwargs)
+                rst = type(self._mod).forward(*args, **kwargs)
+                outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)

                 for i in (
                     outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
                 ):
                     active_module_tracer().current_scope().add_output(NodeMixin.get(i))

                 NodeMixin.wrap_safe(self, orig_self)
+                for arg, node in zip(inputs[1:], origin_inp_node):
+                    if node:
+                        NodeMixin.wrap_safe(arg, node)
                 active_module_tracer().pop_scope()

             # rebind output to outer graph
             callnode.add_outputs(outputs)
             self_node = NodeMixin.get(self)
             self_node.argdef_graph_map[callnode.arg_def] = self._body
-            return outputs
+            self_node.argdef_outdef_map[callnode.arg_def] = out_def
+            return rst

     def __getattr__(self, name):
         if name not in self._mod.__dict__:
@@ -268,20 +423,29 @@
             return super().__getattribute__(name)
         else:
             wrapped = super().__getattribute__(name)
-            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
-                assert not self._is_builtin
-                NodeMixin.wrap(
-                    wrapped,
-                    lambda: GetAttr.make(
-                        NodeMixin.get(self),
-                        name,
-                        type=NodeMixin.get_wrapped_type(wrapped),
-                    ),
-                )
+            if name in self._mod.__dict__:
+                if not NodeMixin.get(wrapped, None):
+                    assert not self._is_builtin
+                    NodeMixin.wrap(
+                        wrapped,
+                        lambda: GetAttr.make(
+                            NodeMixin.get(self),
+                            name,
+                            type=NodeMixin.get_wrapped_type(wrapped),
+                        ),
+                    )
+                else:
+                    node = NodeMixin.get(wrapped)
+                    expr = GetAttr.make(
+                        NodeMixin.get(self),
+                        name,
+                        type=NodeMixin.get_wrapped_type(wrapped),
+                    ).expr
+                    expr.outputs[0] = node
             return wrapped
-class _expr_list:
+class _expr_iter:
     def __init__(self, graph: InternalGraph):
         self.graph = graph
@@ -295,6 +459,59 @@ class _expr_list:
             yield expr


+class ExprFilter:
+    def __init__(self, expr_iter: Iterable):
+        self._iter = expr_iter
+
+    def __iter__(self):
+        return iter(self._iter)
+
+    def call_function(self, func):
+        return ExprFilterCallFunction(self, func)
+
+    def call_method(self, method):
+        return ExprFilterCallMethod(self, method)
+
+    def as_list(self):
+        return list(self)
+
+    def as_dict(self):
+        raise NotImplementedError("need key")
+
+    def as_unique(self):
+        (expr,) = self
+        return expr
+
+    def as_count(self):
+        return sum(1 for _ in self)
+
+
+class ExprFilterCallFunction(ExprFilter):
+    def __init__(self, expr_iter, func: Callable = None):
+        super().__init__(expr_iter)
+        self.func = func
+
+    def __iter__(self):
+        for i in self._iter:
+            if not isinstance(i, CallFunction):
+                continue
+            if self.func is None or i.func == self.func:
+                yield i
+
+
+class ExprFilterCallMethod(ExprFilter):
+    def __init__(self, expr_iter, method: str = None):
+        super().__init__(expr_iter)
+        self.method = method
+
+    def __iter__(self):
+        for i in self._iter:
+            if not isinstance(i, CallMethod):
+                continue
+            if self.method is None or i.method == self.method:
+                yield i
+
+
 class TracedModule(Module):
     """
     `TracedModule` is the Module created by tracing a normal module. It owns a ModuleNode (m_node). `TracedModule` cannot be called directly. It can be
@@ -312,10 +529,12 @@ class TracedModule(Module):
             ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
         )
         assert treedef in self.m_node.argdef_graph_map
-        inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
+        inputs = filter(
+            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
+        )  # allow TracedModuleBuilder for retrace
         outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
-        if len(outputs) == 1:
-            return outputs[0]
+        out_def = self.m_node.argdef_outdef_map[treedef]
+        outputs = out_def.unflatten(outputs)
         return outputs

     @property
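Because `argdef_outdef_map` now remembers the output tree structure, `TracedModule.__call__` reassembles the interpreter's flat output list into the original return structure, so a module returning a dict or tuple keeps its return type after tracing (the second test_preprocess below relies on this). A minimal sketch with a hypothetical module:

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module

class ReturnsDict(M.Module):
    def forward(self, x):
        return {"out": x + 1, "raw": x}

tm = trace_module(ReturnsDict(), F.ones((2,)))
rst = tm(F.ones((2,)))
assert isinstance(rst, dict) and set(rst) == {"out", "raw"}  # structure restored via out_def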
@@ -339,9 +558,8 @@
             if graph is None:
                 assert not isinstance(module, TracedModule)
                 const = Constant(module)
-                modulenode = const.outputs[0]
-                modulenode.module_type = type(module)
-                call.inputs[0] = modulenode
+                const.outputs[0] = call.inputs[0]
+                const.outputs[0].expr = const
                 return [const, call]

             exprs = []
             for expr in graph._exprs:
@@ -350,30 +568,41 @@
                     if call and inp in graph._inputs:
                         inp_idx = graph._inputs.index(inp)
                         expr.inputs[idx] = call.inputs[inp_idx]
+                        call.inputs[inp_idx].users.append(expr)

                 # replace outputs for submodule's expr
                 for idx, outp in enumerate(expr.outputs):
                     if call and outp in graph._outputs:
                         oup_idx = graph._outputs.index(outp)
                         expr.outputs[idx] = call.outputs[oup_idx]
+                        call.outputs[oup_idx].expr = expr

                 if isinstance(expr, GetAttr):
                     # replace GetAttr with Constant
                     if isinstance(expr.outputs[0], TensorNode):
                         const = Constant(getattr(module, expr.name))
                         const.outputs = expr.outputs
+                        const.outputs[0].expr = const
                         exprs.append(const)
                 elif isinstance(expr, CallMethod):
                     obj_node = expr.inputs[0]
                     if isinstance(obj_node, ModuleNode):
-                        assert isinstance(expr.inputs[0].expr, GetAttr)
-                        (obj,) = expr.inputs[0].expr.interpret(module)
-                        exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
+                        pre_expr = expr.inputs[0].expr
+                        if isinstance(pre_expr, GetAttr):
+                            (obj,) = expr.inputs[0].expr.interpret(module)
+                            exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
+                        else:
+                            # module has been replaced
+                            assert isinstance(pre_expr, Constant)
                     else:
                         exprs.append(expr)
                 else:
                     exprs.append(expr)

+            if call is not None:
+                for i in call.inputs:
+                    i.users.remove(call)
+
             return exprs

         new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """
     assert active_module_tracer() is None
     try:
+        use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
         set_active_module_tracer(module_tracer(_wrapped_function))
         with active_module_tracer().patcher:
             global_scope = InternalGraph()
             active_module_tracer().push_scope(global_scope)
             builder = TracedModuleBuilder(mod, True)
             NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
-            inputs, _ = tree_flatten((args, kwargs))
+            inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
             for _, i in enumerate(inputs):
-                NodeMixin.wrap_safe(
-                    i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
-                )
+                if isinstance(i, RawTensor):
+                    NodeMixin.wrap_safe(
+                        i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
+                    )
             builder(*args, **kwargs)
             active_module_tracer().pop_scope()
             return builder.build()
     finally:
+        set_symbolic_shape(use_sym_shape)
         set_active_module_tracer(None)
         unset_module_tracing()
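Note that `graph.exprs` now returns a lazy `ExprFilter` chain rather than a list, so queries compose without materializing intermediate results until one of the `as_*` terminals runs. A small usage sketch against the API added above (`traced_module` stands for any traced module):

graph = traced_module.graph
n_methods = graph.get_call_method(None).as_count()      # count every CallMethod expr
relu_calls = graph.get_call_function(F.relu).as_list()  # materialize matching CallFunctions
the_relu = graph.get_call_function(F.relu).as_unique()  # asserts exactly one match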
import io
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace

set_symbolic_shape(True)


class Main(M.Module):
    def forward(self, x):
        return x


class PreProcess(M.Module):
    def __init__(self):
        super().__init__()
        self.I = F.ones((1,))
        self.M = F.zeros((1,))

    def forward(self, data, idx, roi):
        N, H, W, C = data.shape
        xmax = roi[:, 1, 0]
        xmin = roi[:, 0, 0]
        ymax = roi[:, 1, 1]
        ymin = roi[:, 0, 1]
        scale = F.maximum((xmax - xmin) / W, (ymax - ymin) / H)
        I = F.broadcast_to(self.I, (N,))
        M = F.broadcast_to(self.M, (N, 3, 3))
        M[:, 0, 0] = scale
        M[:, 0, 2] = xmin
        M[:, 1, 1] = scale
        M[:, 1, 2] = ymin
        M[:, 2, 2] = I
        resized = (
            F.warp_perspective(
                data, M, (H, W), mat_idx=idx, border_mode="CONSTANT", format="NHWC"
            )
            .transpose(0, 3, 1, 2)
            .astype(np.float32)
        )
        return resized


class Net(M.Module):
    def __init__(self, traced_module):
        super().__init__()
        self.pre_process = PreProcess()
        self.traced_module = traced_module

    def forward(self, data, idx, roi):
        x = self.pre_process(data, idx, roi)
        x = self.traced_module(x)
        return x


def test_preprocess():
    module = Main()
    data = F.ones((1, 14, 8, 8), dtype=np.uint8)
    traced_module = trace_module(module, data)
    obj = pickle.dumps(traced_module)
    traced_module = pickle.loads(obj)
    module = Net(traced_module)
    module.eval()
    idx = F.zeros((1,), dtype=np.int32)
    roi = F.ones((1, 2, 2), dtype=np.float32)
    y = module(data, idx, roi)
    traced_module = trace_module(module, data, idx, roi)
    np.testing.assert_array_equal(traced_module(data, idx, roi), y)
    func = trace(traced_module, capture_as_const=True)
    np.testing.assert_array_equal(func(data, idx, roi), y)
    model = io.BytesIO()
    func.dump(model, arg_names=("data", "idx", "roi"))
    model.seek(0)
    infer_cg = cgtools.GraphInference(model)
    np.testing.assert_allclose(
        list(
            infer_cg.run(
                inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
            ).values()
        )[0],
        y,
        atol=1e-6,
    )
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr


class MyBlock(M.Module):
    def __init__(self, in_channels=3, channels=3):
        super(MyBlock, self).__init__()
        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x) + 1
        return x


class MyModule(M.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.block0 = MyBlock()
        self.block1 = MyBlock()

    def forward(self, x):
        x = self.block0(x)
        x = self.block1(x)
        return x


def _init_cls(cls):
    module = cls()
    x = F.ones((1, 3, 3, 3))
    y = module(x)
    traced_module = trace_module(module, x)
    return traced_module, x, y


def _init_block():
    return _init_cls(MyBlock)


def _init_module():
    return _init_cls(MyModule)


def test_search():
    traced_module, *_ = _init_block()
    graph = traced_module.graph
    relu_expr = graph.get_call_function(F.relu).as_unique()
    assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu


def test_insert():
    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    relu_node = graph.get_call_function(F.relu).as_unique().outputs
    neg_node = graph.insert_call_function(F.neg, relu_node)
    graph.replace_node({relu_node[0]: neg_node[0]})
    graph.compile()
    np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)


def test_delete():
    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    relu_expr = graph.get_call_function(F.relu).as_unique()
    node = relu_expr.outputs
    repl_node = relu_expr.inputs
    graph.replace_node({node[0]: repl_node[0]})
    graph.compile()
    np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)


def test_flatten():
    traced_module, x, expect = _init_module()
    traced_module = traced_module.flatten()
    traced_module.graph.compile()
    assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
    assert len(traced_module.graph._exprs) == 12


def test_extra_block():
    class PostProcess(M.Module):
        def forward(self, x):
            return x * 2

    class Net(M.Module):
        def __init__(self, traced_module):
            super().__init__()
            self.post_process = PostProcess()
            self.traced_module = traced_module

        def forward(self, x):
            x = self.traced_module(x)
            x = self.post_process(x)
            return x

    traced_module, x, expect = _init_block()
    module = Net(traced_module)
    np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
    traced_module = trace_module(module, x)
    np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module


class MyBlock(Module):
    def __init__(self, in_channels, channels):
        super(MyBlock, self).__init__()
        self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x) + 1
        return x


class MyModule(Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.block0 = MyBlock(8, 4)
        self.block1 = MyBlock(4, 2)

    def forward(self, x):
        x = self.block0(x)
        x = self.block1(x)
        return x


def test_dump_and_load():
    module = MyModule()
    x = Tensor(np.ones((1, 8, 14, 14)))
    expect = module(x)
    traced_module = trace_module(module, x)
    np.testing.assert_array_equal(expect, traced_module(x))
    obj = pickle.dumps(traced_module)
    pickle.loads(obj)
    np.testing.assert_array_equal(expect, traced_module(x))
import numpy as np

from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module as M


class MyModule1(M):
    def forward(self, x):
        y = Tensor(x)
        y += 1
        x = x + 2
        return x, y


class MyModule2(M):
    def forward(self, x):
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y


def test_trace_module():
    x = Tensor(1)
    m1 = MyModule1()
    tm1 = trace_module(m1, x)

    m2 = MyModule2()
    tm2 = trace_module(m2, x)

    inp = Tensor(2)
    gt = m1(inp)
    output = tm1(inp)
    for a, b in zip(output, gt):
        np.testing.assert_equal(a.numpy(), b.numpy())

    gt1 = m2(inp)
    output1 = tm2(inp)
    for a, b in zip(output1, gt1):
        np.testing.assert_equal(a.numpy(), b.numpy())
import io
import pickle

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace

set_symbolic_shape(True)


class Main(M.Module):
    def forward(self, x):
        return x["data"]


class PreProcess(M.Module):
    def __init__(self):
        super().__init__()
        self.A = F.zeros((1,))
        self.I = F.ones((1,))
        self.bb_out = mge.tensor(
            np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32")
        )

    def forward(self, data, quad):
        """
        data: (1, 3, 48, 160)
        quad: (1, 4, 2)
        """
        N = quad.shape[0]
        dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2)
        I = F.broadcast_to(self.I, quad.shape)
        A = F.broadcast_to(self.A, (N, 8, 8))
        A[:, 0:4, 0:2] = quad
        A[:, 4:8, 5:6] = I[:, :, 0:1]
        A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1]
        A[:, 4:8, 3:5] = quad
        A[:, 0:4, 2:3] = I[:, :, 0:1]
        A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2]
        B = dst.transpose(0, 2, 1).reshape(-1, 8, 1)
        M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape(
            -1, 3, 3
        )
        new_data = F.warp_perspective(data, M, (48, 160))  # (N, 3, 48, 160)
        return {"data": new_data}


class Net(M.Module):
    def __init__(self, traced_module):
        super().__init__()
        self.pre_process = PreProcess()
        self.traced_module = traced_module

    def forward(self, data, quad):
        x = self.pre_process(data, quad)
        x = self.traced_module(x)
        return x


def test_preprocess():
    batch_size = 2
    module = Main()
    data = mge.tensor(
        np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32
    )
    traced_module = trace_module(module, {"data": data})
    obj = pickle.dumps(traced_module)
    traced_module = pickle.loads(obj)
    module = Net(traced_module)
    module.eval()
    quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32)
    expect = module(data, quad)
    traced_module = trace_module(module, data, quad)
    actual = traced_module(data, quad)
    for i, j in zip(expect, actual):
        np.testing.assert_array_equal(i, j)
    func = trace(traced_module, capture_as_const=True)
    actual = func(data, quad)
    for i, j in zip(expect, actual):
        np.testing.assert_array_equal(i, j)
    model = io.BytesIO()
    func.dump(model, arg_names=("data", "quad"))
    model.seek(0)
    infer_cg = cgtools.GraphInference(model)
    actual = list(
        infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
    )[0]
    np.testing.assert_allclose(expect, actual)