Commit e918f0aa authored by Megvii Engine Team

feat(traced_module): add treedef leaf node check and add some graph api

GitOrigin-RevId: 36c069bfee1905b9c390337125e5c0470a79d55e
Parent c7e730bc
@@ -15,7 +15,6 @@ import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor
from .pytree import TreeDef
class Node:
@@ -102,6 +101,8 @@ class TensorNode(Node):
shape = None # type: Tuple[int]
dtype = None # type: numpy.dtype
qparam = None
device = None
def __repr__(self):
if self._name is None:
@@ -109,6 +110,17 @@ class TensorNode(Node):
else:
return "%{}_{}(Tensor)".format(self._id, self._name)
def __getstate__(self):
return {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"qparam": self.qparam,
"shape": self.shape,
"dtype": self.dtype,
"device": self.device,
}
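Since `__getstate__` keeps only the fields listed above, a pickled TensorNode restores exactly this metadata through the default `__setstate__`. A minimal round-trip sketch (assuming `node` is a TensorNode whose `expr` and `users` are themselves picklable):

    import pickle

    restored = pickle.loads(pickle.dumps(node))
    assert restored.shape == node.shape and restored.device == node.device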
class NodeMixin(abc.ABC):
__node = None
@@ -118,15 +130,25 @@ class NodeMixin(abc.ABC):
# record the nodes which had been bound to this NodeMixin
pass
    @classmethod
    def _record_tensornode_property(cls, node, value):
        assert isinstance(node, TensorNode)
        assert isinstance(value, RawTensor)
        node.dtype = value.dtype
        node.shape = (
            value._tuple_shape if isinstance(value, Tensor) else value.shape
        )
        node.device = value.device
        # keep the attribute name consistent with TensorNode.qparam above
        if getattr(value, "_qparams", None) is not None:
            node.qparam = value._qparams
@classmethod
def wrap(cls, value, node):
if isinstance(value, (NodeMixin, RawTensor)):
if isinstance(node, Node):
if isinstance(value, RawTensor):
node.dtype = value.dtype
node.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
cls._record_tensornode_property(node, value)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
setattr(value, "_NodeMixin__node", node)
@@ -135,10 +157,7 @@ class NodeMixin(abc.ABC):
n = node()
assert isinstance(n, Node)
if isinstance(value, RawTensor):
n.dtype = value.dtype
n.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
cls._record_tensornode_property(n, value)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(n)
setattr(value, "_NodeMixin__node", n)
@@ -147,10 +166,7 @@ class NodeMixin(abc.ABC):
def wrap_safe(cls, value, node):
assert isinstance(value, (NodeMixin, RawTensor))
if isinstance(value, RawTensor):
node.dtype = value.dtype
node.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
cls._record_tensornode_property(node, value)
setattr(value, "_NodeMixin__node", node)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
......
@@ -13,13 +13,45 @@ from typing import Callable, NamedTuple
import numpy as np
from ...core._imperative_rt.common import CompNode
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._wrap import Device
from ...core.tensor.dtype import QuantDtypeMeta
from ...module import Module
from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode
SUPPORTED_TYPE = {}
# if type(obj) or obj itself is in SUPPORTED_LEAF_TYPE, the object can be treated as a leaf node of the pytree
SUPPORTED_LEAF_TYPE = {
RawTensor,
Tensor,
Parameter,
str,
int,
float,
bool,
QuantDtypeMeta,
CompNode,
Device,
type(None),
type(Ellipsis),
QuantMode,
}
# if isinstance(obj, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) holds, the object can be treated as a leaf node of the pytree
SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
def register_supported_type(type, flatten, unflatten):
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
def register_supported_type(type, flatten=None, unflatten=None):
if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
else:
SUPPORTED_LEAF_CLS.append(type)
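With the new signature, supplying both callbacks registers a container type in SUPPORTED_TYPE, while omitting them marks the type as a leaf. A usage sketch (the `Point` and `Opaque` classes are hypothetical, not part of this commit):

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    # container: flattened into children plus aux_data, rebuilt by unflatten
    register_supported_type(
        Point,
        flatten=lambda p: ([p.x, p.y], None),
        unflatten=lambda children, aux_data: Point(*children),
    )

    class Opaque:
        pass

    # leaf: with no callbacks, Opaque is appended to SUPPORTED_LEAF_CLS instead
    register_supported_type(Opaque)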
def _dict_flatten(inp):
@@ -48,6 +80,22 @@ def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))
def qparams_flatten(inp):
aux_data = []
results = []
for key in inp.__slots__:
aux_data.append(key)
results.append(getattr(inp, key, None))
return results, tuple(aux_data)
def qparams_unflatten(inp, aux_data):
obj = QParams.__new__(QParams)
for k, v in zip(aux_data, inp):
setattr(obj, k, v)
return obj
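`qparams_flatten` walks `__slots__` and `qparams_unflatten` rebuilds the object attribute by attribute through `__new__`, so the pair round-trips every slot without calling `QParams.__init__`. A sketch (assuming `qp` is an existing QParams instance):

    children, aux_data = qparams_flatten(qp)     # slot values, slot names
    qp2 = qparams_unflatten(children, aux_data)
    assert all(getattr(qp2, k, None) is getattr(qp, k, None) for k in qp.__slots__)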
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
@@ -60,15 +108,40 @@ register_supported_type(
lambda x, aux_data: slice(x[0], x[1], x[2]),
)
register_supported_type(QParams, qparams_flatten, qparams_unflatten)
def _is_leaf(obj):
if isinstance(obj, type):
return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
return (
isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE
)
def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode)
elif isinstance(node, (NodeMixin, Module)):
return (Module, ModuleNode, NodeMixin)
else:
return type(node)
def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
return True
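These defaults classify values before flattening: tensors, nodes, and modules become real leaves, while plain Python values are folded into the tree definition as constants. For example (a sketch):

    x = Tensor([1.0])
    assert _is_leaf(x) and not _is_const_leaf(x)
    assert _leaf_type(x) == (Tensor, TensorNode)
    assert _is_const_leaf("relu") and _is_const_leaf(3)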
def tree_flatten(
values,
leaf_type: Callable = lambda x: type(x),
is_leaf: Callable = lambda _: True,
is_const_leaf: Callable = lambda _: False,
leaf_type: Callable = _leaf_type,
is_leaf: Callable = _is_leaf,
is_const_leaf: Callable = _is_const_leaf,
):
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values)
assert is_leaf(values), values
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
if isinstance(values, np.ndarray):
......
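With these defaults in place, most call sites below drop their explicit `leaf_type`/`is_const_leaf` arguments. A round-trip sketch (hypothetical tensors):

    values = [Tensor([1.0]), {"w": Tensor([2.0])}]
    leaves, treedef = tree_flatten(values)   # leaves holds the two tensors
    rebuilt = treedef.unflatten(leaves)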
@@ -26,6 +26,12 @@ from ...core._imperative_rt.core2 import (
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
from ...quantization.observer import (
ExponentialMovingAverageObserver,
MinMaxObserver,
SyncMinMaxObserver,
)
from ...tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .module_tracer import (
@@ -40,15 +46,6 @@ from .pytree import tree_flatten
logger = get_logger(__name__)
def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode)
elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin)
else:
return type(node)
def _is_leaf(node):
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
@@ -56,20 +53,10 @@ def _is_leaf(node):
return isinstance(node, RawTensor)
def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
return True
def wrap_tensors(tensors: Tensor, nodes: TensorNode):
inp_tensors = copy.deepcopy(tensors)
inp_tensors, inp_def_v = tree_flatten(
inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inp_nodes, inp_def_n = tree_flatten(
nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inp_tensors, inp_def_v = tree_flatten(inp_tensors)
inp_nodes, inp_def_n = tree_flatten(nodes)
for v, n in zip(inp_tensors, inp_nodes):
if isinstance(n, TensorNode) and isinstance(v, Tensor):
NodeMixin.wrap_safe(v, n)
@@ -124,6 +111,9 @@ class InternalGraph:
self._exprs = []
self._inputs = []
self._outputs = []
self._watch_point = []
self._end_point = []
self._rst = collections.defaultdict(list)
def insert(self, expr):
self._exprs.append(expr)
@@ -177,6 +167,7 @@ class InternalGraph:
for idx, i in enumerate(self._inputs):
if i in repl_dict:
self._inputs[idx] = repl_dict[i]
for idx, o in enumerate(self._outputs):
if o in repl_dict:
self._outputs[idx] = repl_dict[o]
@@ -224,11 +215,7 @@ class InternalGraph:
moudle = forma_mnode.owner
assert moudle._is_top, "reset_inputs only support the top-level graph"
inputs, tree_def = tree_flatten(
((moudle, *args), kwargs),
leaf_type=_leaf_type,
is_const_leaf=_is_const_leaf,
)
inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
def create_node(val: Tensor):
node = Input(type=TensorNode).outputs[0]
@@ -302,7 +289,6 @@ class InternalGraph:
formal_inp_node = create_node(True)
inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
self._inputs[:] = inputs[:]
@@ -313,7 +299,6 @@ class InternalGraph:
args = args + (create_node(False),)
inputs, tree_def = tree_flatten(
(args, kwargs),
leaf_type=_leaf_type,
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
)
e.inputs[:] = inputs[:]
@@ -328,7 +313,7 @@ class InternalGraph:
def reset_outputs(self, outputs):
outputs, out_def = tree_flatten(
outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
outputs, is_leaf=lambda x: isinstance(x, TensorNode),
)
forma_mnode = self.inputs[0]
@@ -393,9 +378,7 @@ class InternalGraph:
org_out_def = moudle.argdef_outdef_map[tree_def]
org_outs = org_out_def.unflatten(self._outputs)
outputs, out_def = tree_flatten(
(org_outs, node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
(org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
)
self._outputs[:] = outputs
@@ -404,9 +387,7 @@ class InternalGraph:
actual_node = create_node(node, e)
org_outs = org_out_def.unflatten(e.outputs)
outputs, out_def = tree_flatten(
(org_outs, actual_node),
leaf_type=_leaf_type,
is_leaf=lambda x: isinstance(x, TensorNode),
(org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
)
e.outputs[:] = outputs
e.out_def = out_def
@@ -419,9 +400,7 @@ class InternalGraph:
def insert_function(self, func: Callable, *args, **kwargs):
assert isinstance(func, Callable)
inp_nodes, inp_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inp_nodes, inp_def = tree_flatten((args, kwargs))
insert_idx = -1
for i in inp_nodes:
@@ -449,7 +428,7 @@ class InternalGraph:
if rst is None:
return None
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
node_outputs = []
for out in outputs:
assert isinstance(out, RawTensor)
@@ -510,15 +489,40 @@ class InternalGraph:
def interpret(self, *inputs):
node2value = {}
end_nodes_set = set(self._end_point)
endnode2value = {}
def get_all_endnode_val(n, v):
if n in end_nodes_set:
endnode2value[n] = v
end_nodes_set.remove(n)
return not end_nodes_set
return False
for n, v in zip(self._inputs, inputs):
node2value[n] = v
if n in self._watch_point:
self._rst[n].append(v)
if n in self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
if values is not None:
for n, v in zip(expr.outputs, values):
node2value[n] = v
if n in self._watch_point:
                        self._rst[n].append(v)
if self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)
return list(node2value[i] for i in self._outputs)
def eval(self, *inputs):
assert len(inputs) == len(self._inputs) - 1
inp = [self._inputs[0].owner] + list(inputs)
return self.interpret(*inp)
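`eval` prepends the module recorded on the first input node, so callers pass only the tensor arguments. A usage sketch (assuming `tm` is a top-level TracedModule and `inp` a Tensor):

    g = tm.graph            # the top-level InternalGraph
    outs = g.eval(inp)      # roughly g.interpret(tm, inp); returns a flat list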
def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
", ".join(str(i) for i in self._inputs),
@@ -541,9 +545,7 @@ def _wrapped_function(orig_func):
def wrapped_fn(*args, **kwargs):
if is_tracing_module():
unset_module_tracing()
inputs, tree_def = tree_flatten(
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inputs, tree_def = tree_flatten((args, kwargs))
for i in inputs:
if not NodeMixin.get(i, None):
if isinstance(i, (RawTensor, NodeMixin)):
@@ -575,9 +577,7 @@ def _wrapped_function(orig_func):
if meth_name == "__setitem__":
rst = self
if rst is not None:
outputs, out_def = tree_flatten(
rst, leaf_type=_leaf_type, is_leaf=_is_leaf
)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
call_node.out_def = out_def
else:
outputs = None
@@ -604,13 +604,17 @@ class TracedModuleBuilder(NodeMixin):
"_NodeMixin__node",
"_is_builtin",
"build",
"_record_wrapped_nodes",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
"__class__",
"__dict__",
]
def __init__(self, mod, is_top_module=False):
super(TracedModuleBuilder, self).__init__()
assert isinstance(mod, Module)
self._mod = mod
self._body = None
self._is_top = is_top_module
@@ -618,6 +622,13 @@
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()
        # The builder is passed to self._mod.forward as the `self` argument. If
        # `forward` calls methods of its base classes via super().xxx, tracing
        # would raise an exception, because the builder does not inherit from
        # self._mod.__bases__.
        # Reassign self.__class__ so that the builder inherits from both
        # TracedModuleBuilder and mod.__class__.
self.__class__ = type(
"TracedModuleBuilder",
(TracedModuleBuilder, mod.__class__),
dict(TracedModuleBuilder.__dict__),
)
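The same re-parenting trick in isolation (hypothetical classes, a sketch of the pattern rather than library code):

    class Backbone:
        def scale(self):
            return 2

    class Net(Backbone):
        def forward(self, x):
            # zero-argument super() requires `self` to be an instance of Net
            return x * super().scale()

    class Builder:
        def __init__(self, mod):
            self._mod = mod
            self.__class__ = type(
                "Builder", (Builder, mod.__class__), dict(Builder.__dict__)
            )

    b = Builder(Net())
    assert isinstance(b, Net)
    assert type(b._mod).forward(b, 3) == 6  # super() now resolves correctly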
def build(self):
if self._is_builtin:
@@ -631,8 +642,6 @@
)
for _, g in self._argdef_graph_map.items():
g.compile()
# for node in self.nodes:
# node._owner = weakref.ref(traced_module)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
@@ -653,9 +662,7 @@
if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x))
inputs, tree_def = tree_flatten(
((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
)
inputs, tree_def = tree_flatten(((self, *args), kwargs))
for i in inputs:
mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self))
@@ -667,7 +674,7 @@
if self._is_builtin:
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
set_module_tracing()
if self._is_builtin:
self._body = None
@@ -706,7 +713,7 @@
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
rst = type(self._mod).forward(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
):
@@ -725,6 +732,12 @@
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst
def __setattr__(self, name, value):
object.__setattr__(self, name, value)
def __repr__(self):
return repr(self._mod)
def __getattr__(self, name):
if name not in self._mod.__dict__:
attr = getattr(type(self._mod), name).__get__(self, type(self))
@@ -743,11 +756,22 @@
def __getattribute__(self, name):
if name in TracedModuleBuilder.__builder_attributes__:
return super().__getattribute__(name)
return object.__getattribute__(self, name)
else:
wrapped = super().__getattribute__(name)
wrapped = object.__getattribute__(self, name)
if name in self._mod.__dict__:
assert not self._is_builtin
mod_attr = getattr(self._mod, name)
if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
wrapped = mod_attr
setattr(self, name, wrapped)
if isinstance(mod_attr, Module):
assert mod_attr is wrapped._mod
else:
assert mod_attr is wrapped
# assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
wrapped,
@@ -757,14 +781,6 @@
type=NodeMixin.get_wrapped_type(wrapped),
),
)
"""
else:
node = NodeMixin.get(wrapped)
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped
@@ -924,20 +940,57 @@ class TracedModule(Module):
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map
self._is_top = is_top
self.watch_points = []
self.watch_node_value = {}
self.end_points = []
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.argdef_graph_map[treedef].interpret(*inputs)
if self.watch_points:
self.watch_node_value = {}
for n in self.watch_points:
self.watch_node_value[n] = n.top_graph._rst.pop(n)
if self.end_points:
return outputs
out_def = self.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs
def set_watch_points(self, nodes):
if not isinstance(nodes, Sequence):
nodes = [nodes]
self.watch_points = nodes
for n in nodes:
n.top_graph._watch_point.append(n)
def clear_watch_points(self):
for n in self.watch_points:
n.top_graph._watch_point = []
self.watch_points = []
self.watch_node_value = {}
def set_end_points(self, nodes):
if not isinstance(nodes, Sequence):
nodes = [nodes]
self.end_points = nodes
graphs = list(self.argdef_graph_map.values())
for n in nodes:
assert n.top_graph in graphs
n.top_graph._end_point.append(n)
def clear_end_points(self):
for n in self.end_points:
n.top_graph._end_point = []
self.end_points = []
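Together these methods let callers probe intermediate values or truncate an interpreted run. A usage sketch (assuming `tm` is a TracedModule and `inp` a Tensor):

    node = tm.graph.outputs[0]        # any TensorNode of interest

    tm.set_watch_points(node)         # record the node's runtime values
    tm(inp)
    recorded = tm.watch_node_value[node]   # values seen during forward
    tm.clear_watch_points()

    tm.set_end_points(node)           # stop interpret() once node is computed
    partial = tm(inp)                 # returned flat, without out_def.unflatten
    tm.clear_end_points()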
@property
def graph(self) -> InternalGraph:
if self._is_top:
@@ -1014,6 +1067,9 @@ class TracedModule(Module):
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
        # replace the inputs of the submodule's exprs
if call:
repl_dict = dict(zip(graph._inputs, call.inputs))
for ind, out in enumerate(graph.outputs):
if isinstance(out.expr, Input):
@@ -1028,8 +1084,8 @@
repl_dict[out] = call.outputs[ind]
graph._replace_inputs_outputs(repl_dict)
        for expr in graph._exprs:
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
@@ -1129,6 +1185,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
    :param kwargs: the keyword arguments passed to the forward method of ``mod``
"""
assert active_module_tracer() is None
assert isinstance(mod, Module)
try:
use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
@@ -1140,9 +1197,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
assert isinstance(i, Tensor), "not support "
# assert isinstance(i, Tensor), "not support "
if isinstance(i, RawTensor):
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
......
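For reference, a typical end-to-end use of the API touched above (hypothetical `net` module and `inp` tensor):

    traced = trace_module(net, inp)   # build the TracedModule
    out = traced(inp)                 # re-run by interpreting the recorded graph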