From e918f0aa7579c6988d99622d8f3946278375dbb3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 29 Jul 2021 17:48:46 +0800
Subject: [PATCH] feat(traced_module): add treedef leaf node check and add
 some graph api

GitOrigin-RevId: 36c069bfee1905b9c390337125e5c0470a79d55e
---
 .../experimental/traced_module/node.py        |  42 ++--
 .../experimental/traced_module/pytree.py      |  85 +++++++-
 .../traced_module/traced_module.py            | 195 +++++++++++-------
 3 files changed, 234 insertions(+), 88 deletions(-)

diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py
index 44506ead..ae7160ce 100644
--- a/imperative/python/megengine/experimental/traced_module/node.py
+++ b/imperative/python/megengine/experimental/traced_module/node.py
@@ -15,7 +15,6 @@ import numpy
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...module import Module
 from ...tensor import Tensor
-from .pytree import TreeDef
 
 
 class Node:
@@ -102,6 +101,8 @@ class TensorNode(Node):
 
     shape = None  # type: Tuple[int]
     dtype = None  # type: numpy.dtype
+    qparam = None
+    device = None
 
     def __repr__(self):
         if self._name is None:
@@ -109,6 +110,17 @@
         else:
             return "%{}_{}(Tensor)".format(self._id, self._name)
 
+    def __getstate__(self):
+        return {
+            "expr": self.expr,
+            "users": self.users,
+            "_id": self._id,
+            "qparam": self.qparam,
+            "shape": self.shape,
+            "dtype": self.dtype,
+            "device": self.device,
+        }
+
 
 class NodeMixin(abc.ABC):
     __node = None
@@ -118,15 +130,25 @@ class NodeMixin(abc.ABC):
         # record the nodes which had been bound to this NodeMixin
         pass
 
+    @classmethod
+    def _record_tensornode_property(cls, node, value):
+        assert isinstance(node, TensorNode)
+        assert isinstance(value, RawTensor)
+        node.dtype = value.dtype
+        node.shape = (
+            value._tuple_shape if isinstance(value, Tensor) else value.shape
+        )
+        node.device = value.device
+        if hasattr(value, "_qparams") and value._qparams is not None:
+            node.qparam = value._qparams
+
     @classmethod
     def wrap(cls, value, node):
         if isinstance(value, (NodeMixin, RawTensor)):
             if isinstance(node, Node):
                 if isinstance(value, RawTensor):
-                    node.dtype = value.dtype
-                    node.shape = (
-                        value._tuple_shape if isinstance(value, Tensor) else value.shape
-                    )
+                    cls._record_tensornode_property(node, value)
                 if isinstance(value, NodeMixin):
                     value._record_wrapped_nodes(node)
                 setattr(value, "_NodeMixin__node", node)
             else:
                 n = node()
                 assert isinstance(n, Node)
                 if isinstance(value, RawTensor):
-                    n.dtype = value.dtype
-                    n.shape = (
-                        value._tuple_shape if isinstance(value, Tensor) else value.shape
-                    )
+                    cls._record_tensornode_property(n, value)
                 if isinstance(value, NodeMixin):
                     value._record_wrapped_nodes(n)
                 setattr(value, "_NodeMixin__node", n)
 
     @classmethod
     def wrap_safe(cls, value, node):
         assert isinstance(value, (NodeMixin, RawTensor))
         if isinstance(value, RawTensor):
-            node.dtype = value.dtype
-            node.shape = (
-                value._tuple_shape if isinstance(value, Tensor) else value.shape
-            )
+            cls._record_tensornode_property(node, value)
         setattr(value, "_NodeMixin__node", node)
         if isinstance(value, NodeMixin):
             value._record_wrapped_nodes(node)
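The node.py hunks above do two things: every wrap path now records tensor
metadata through the single helper `_record_tensornode_property` instead of three
duplicated copies, and `TensorNode.__getstate__` restricts pickling to plain
metadata. A minimal standalone sketch of that pattern (the names `NodeSketch`
and `record_property` are illustrative stand-ins, not megengine APIs):

    import pickle

    class NodeSketch:
        # stand-in for TensorNode: keeps only plain metadata, never the raw tensor
        shape = None
        dtype = None
        device = None
        qparam = None

        def __getstate__(self):
            # mirror TensorNode.__getstate__: serialize metadata fields only
            return {"shape": self.shape, "dtype": self.dtype,
                    "device": self.device, "qparam": self.qparam}

    def record_property(node, value):
        # one helper shared by all wrap()/wrap_safe() call sites, as in
        # NodeMixin._record_tensornode_property above
        node.shape = getattr(value, "shape", None)
        node.dtype = getattr(value, "dtype", None)
        node.device = getattr(value, "device", None)
        node.qparam = getattr(value, "_qparams", None)

    n = NodeSketch()
    record_property(n, object())            # any value-like object works here
    blob = pickle.dumps(n)                  # round-trips via __getstate__
    assert pickle.loads(blob).shape is None

Centralizing the copying is what makes the new `device` and `qparam` fields
consistent across all three wrap entry points.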
diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py
index 9ca05347..5abb92be 100644
--- a/imperative/python/megengine/experimental/traced_module/pytree.py
+++ b/imperative/python/megengine/experimental/traced_module/pytree.py
@@ -13,13 +13,45 @@ from typing import Callable, NamedTuple
 
 import numpy as np
 
+from ...core._imperative_rt.common import CompNode
+from ...core._imperative_rt.core2 import Tensor as RawTensor
+from ...core._wrap import Device
+from ...core.tensor.dtype import QuantDtypeMeta
+from ...module import Module
+from ...quantization.utils import LSQParams, QParams, QuantMode
+from ...tensor import Parameter, Tensor
+from .node import ModuleNode, Node, NodeMixin, TensorNode
+
 SUPPORTED_TYPE = {}
 
+# if type(obj) or obj itself is in SUPPORTED_LEAF_TYPE, the object can be treated as a leaf node of the pytree
+SUPPORTED_LEAF_TYPE = {
+    RawTensor,
+    Tensor,
+    Parameter,
+    str,
+    int,
+    float,
+    bool,
+    QuantDtypeMeta,
+    CompNode,
+    Device,
+    type(None),
+    type(Ellipsis),
+    QuantMode,
+}
+
+# if isinstance(obj, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object can be treated as a leaf node of the pytree
+SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]
+
 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
 
 
-def register_supported_type(type, flatten, unflatten):
-    SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+def register_supported_type(type, flatten=None, unflatten=None):
+    if flatten and unflatten:
+        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+    else:
+        SUPPORTED_LEAF_CLS.append(type)
 
 
 def _dict_flatten(inp):
@@ -48,6 +80,22 @@ def _ordereddict_unflatten(inps, aux_data):
     return OrderedDict(zip(aux_data, inps))
 
 
+def qparams_flatten(inp):
+    aux_data = []
+    results = []
+    for key in inp.__slots__:
+        aux_data.append(key)
+        results.append(getattr(inp, key, None))
+    return results, tuple(aux_data)
+
+
+def qparams_unflatten(inp, aux_data):
+    obj = QParams.__new__(QParams)
+    for k, v in zip(aux_data, inp):
+        setattr(obj, k, v)
+    return obj
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
 register_supported_type(dict, _dict_flatten, _dict_unflatten)
@@ -60,15 +108,40 @@ register_supported_type(
     lambda x, aux_data: slice(x[0], x[1], x[2]),
 )
 
+register_supported_type(QParams, qparams_flatten, qparams_unflatten)
+
+
+def _is_leaf(obj):
+    if isinstance(obj, type):
+        return issubclass(obj, tuple(SUPPORTED_LEAF_CLS)) or obj in SUPPORTED_LEAF_TYPE
+    return (
+        isinstance(obj, tuple(SUPPORTED_LEAF_CLS)) or type(obj) in SUPPORTED_LEAF_TYPE
+    )
+
+
+def _leaf_type(node):
+    if isinstance(node, (RawTensor, TensorNode)):
+        return (Tensor, TensorNode)
+    elif isinstance(node, (NodeMixin, Module)):
+        return (Module, ModuleNode, NodeMixin)
+    else:
+        return type(node)
+
+
+def _is_const_leaf(node):
+    if isinstance(node, (RawTensor, NodeMixin, Module)):
+        return False
+    return True
+
 
 def tree_flatten(
     values,
-    leaf_type: Callable = lambda x: type(x),
-    is_leaf: Callable = lambda _: True,
-    is_const_leaf: Callable = lambda _: False,
+    leaf_type: Callable = _leaf_type,
+    is_leaf: Callable = _is_leaf,
+    is_const_leaf: Callable = _is_const_leaf,
 ):
     if type(values) not in SUPPORTED_TYPE:
-        assert is_leaf(values)
+        assert is_leaf(values), values
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
             if isinstance(values, np.ndarray):
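With these pytree changes, `register_supported_type` doubles as a leaf registry:
passing `flatten`/`unflatten` registers a container type, while passing only the
type appends it to SUPPORTED_LEAF_CLS. A rough sketch of both uses (the `Pair`
and `Token` classes are made up for illustration; LeafDef internals are not
shown in this hunk, so the sketch stops at flattening):

    from megengine.experimental.traced_module.pytree import (
        register_supported_type,
        tree_flatten,
    )

    class Pair:
        def __init__(self, first, second):
            self.first = first
            self.second = second

    def _pair_flatten(p):
        # return (children, aux_data), following _dict_flatten above
        return [p.first, p.second], None

    def _pair_unflatten(children, aux_data):
        return Pair(*children)

    # container: flatten/unflatten provided, goes into SUPPORTED_TYPE
    register_supported_type(Pair, _pair_flatten, _pair_unflatten)

    # opaque leaf: no flatten/unflatten, appended to SUPPORTED_LEAF_CLS
    class Token:
        pass

    register_supported_type(Token)

    leaves, treedef = tree_flatten(Pair(1, Token()))
    # leaves is the flat list of leaf values: ints hit SUPPORTED_LEAF_TYPE,
    # Token instances pass the isinstance check against SUPPORTED_LEAF_CLS

Unregistered types now fail fast: `assert is_leaf(values), values` reports the
offending object instead of silently treating everything as a leaf.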
diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py
index 58d9d8e0..03089400 100644
--- a/imperative/python/megengine/experimental/traced_module/traced_module.py
+++ b/imperative/python/megengine/experimental/traced_module/traced_module.py
@@ -26,6 +26,12 @@ from ...core._imperative_rt.core2 import (
 from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
+from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
+from ...quantization.observer import (
+    ExponentialMovingAverageObserver,
+    MinMaxObserver,
+    SyncMinMaxObserver,
+)
 from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
 from .module_tracer import (
@@ -40,15 +46,6 @@ from .pytree import tree_flatten
 logger = get_logger(__name__)
 
 
-def _leaf_type(node):
-    if isinstance(node, (RawTensor, TensorNode)):
-        return (Tensor, TensorNode)
-    elif isinstance(node, (NodeMixin, Module, ModuleNode)):
-        return (Module, ModuleNode, NodeMixin)
-    else:
-        return type(node)
-
-
 def _is_leaf(node):
     assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
         type(node)
     )
     return isinstance(node, RawTensor)
 
 
-def _is_const_leaf(node):
-    if isinstance(node, (RawTensor, NodeMixin, Module)):
-        return False
-    return True
-
-
 def wrap_tensors(tensors: Tensor, nodes: TensorNode):
     inp_tensors = copy.deepcopy(tensors)
-    inp_tensors, inp_def_v = tree_flatten(
-        inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-    )
-    inp_nodes, inp_def_n = tree_flatten(
-        nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-    )
+    inp_tensors, inp_def_v = tree_flatten(inp_tensors)
+    inp_nodes, inp_def_n = tree_flatten(nodes)
     for v, n in zip(inp_tensors, inp_nodes):
         if isinstance(n, TensorNode) and isinstance(v, Tensor):
             NodeMixin.wrap_safe(v, n)
@@ -124,6 +111,9 @@ class InternalGraph:
         self._exprs = []
         self._inputs = []
         self._outputs = []
+        self._watch_point = []
+        self._end_point = []
+        self._rst = collections.defaultdict(list)
 
     def insert(self, expr):
         self._exprs.append(expr)
@@ -177,6 +167,7 @@ class InternalGraph:
         for idx, i in enumerate(self._inputs):
             if i in repl_dict:
                 self._inputs[idx] = repl_dict[i]
+
         for idx, o in enumerate(self._outputs):
             if o in repl_dict:
                 self._outputs[idx] = repl_dict[o]
@@ -224,11 +215,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "reset_inputs only support the top-level graph"
 
-        inputs, tree_def = tree_flatten(
-            ((moudle, *args), kwargs),
-            leaf_type=_leaf_type,
-            is_const_leaf=_is_const_leaf,
-        )
+        inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
 
         def create_node(val: Tensor):
             node = Input(type=TensorNode).outputs[0]
@@ -302,7 +289,6 @@
             formal_inp_node = create_node(True)
             inputs, tree_def = tree_flatten(
                 ((*args, formal_inp_node), kwargs),
-                leaf_type=_leaf_type,
                 is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
             )
             self._inputs[:] = inputs[:]
             args = args + (create_node(False),)
             inputs, tree_def = tree_flatten(
                 (args, kwargs),
-                leaf_type=_leaf_type,
                 is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
             )
             e.inputs[:] = inputs[:]
@@ -328,7 +313,7 @@ class InternalGraph:
 
     def reset_outputs(self, outputs):
         outputs, out_def = tree_flatten(
-            outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
+            outputs, is_leaf=lambda x: isinstance(x, TensorNode),
         )
 
         forma_mnode = self.inputs[0]
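The reset_inputs/reset_outputs family above (together with insert_function just
below) rewires a compiled graph in place. A hypothetical usage sketch, assuming
`net` is a megengine Module and `x` a Tensor; the exact argument and return
shapes of these helpers follow the tree_flatten conventions and may differ in
detail:

    import megengine.functional as F
    from megengine.experimental.traced_module.traced_module import trace_module

    tm = trace_module(net, x)          # top-level TracedModule
    graph = tm.graph
    out_node = graph.outputs[0]

    # splice an extra op onto the traced output, then make its result the
    # graph output; insert_function returns the node(s) of the new call
    new_out = graph.insert_function(F.relu, out_node)
    graph.reset_outputs(new_out)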
@@ -393,9 +378,7 @@
         org_out_def = moudle.argdef_outdef_map[tree_def]
         org_outs = org_out_def.unflatten(self._outputs)
         outputs, out_def = tree_flatten(
-            (org_outs, node),
-            leaf_type=_leaf_type,
-            is_leaf=lambda x: isinstance(x, TensorNode),
+            (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
         )
         self._outputs[:] = outputs
@@ -404,9 +387,7 @@
             actual_node = create_node(node, e)
             org_outs = org_out_def.unflatten(e.outputs)
             outputs, out_def = tree_flatten(
-                (org_outs, actual_node),
-                leaf_type=_leaf_type,
-                is_leaf=lambda x: isinstance(x, TensorNode),
+                (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
             )
             e.outputs[:] = outputs
             e.out_def = out_def
@@ -419,9 +400,7 @@
 
     def insert_function(self, func: Callable, *args, **kwargs):
         assert isinstance(func, Callable)
-        inp_nodes, inp_def = tree_flatten(
-            (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inp_nodes, inp_def = tree_flatten((args, kwargs))
 
         insert_idx = -1
         for i in inp_nodes:
@@ -449,7 +428,7 @@
         if rst is None:
             return None
 
-        outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
         node_outputs = []
         for out in outputs:
             assert isinstance(out, RawTensor)
@@ -510,15 +489,40 @@ class InternalGraph:
 
     def interpret(self, *inputs):
         node2value = {}
+        end_nodes_set = set(self._end_point)
+        endnode2value = {}
+
+        def get_all_endnode_val(n, v):
+            if n in end_nodes_set:
+                endnode2value[n] = v
+                end_nodes_set.remove(n)
+                return not end_nodes_set
+            return False
+
         for n, v in zip(self._inputs, inputs):
             node2value[n] = v
+            if n in self._watch_point:
+                self._rst[n].append(v)
+            if n in self._end_point and get_all_endnode_val(n, v):
+                return list(endnode2value[i] for i in self._end_point)
+
         for expr in self._exprs:
             values = expr.interpret(*list(node2value[i] for i in expr.inputs))
             if values is not None:
                 for n, v in zip(expr.outputs, values):
                     node2value[n] = v
+                    if n in self._watch_point:
+                        self._rst[n].append(v)
+                    if self._end_point and get_all_endnode_val(n, v):
+                        return list(endnode2value[i] for i in self._end_point)
+
         return list(node2value[i] for i in self._outputs)
 
+    def eval(self, *inputs):
+        assert len(inputs) == len(self._inputs) - 1
+        inp = [self._inputs[0].owner] + list(inputs)
+        return self.interpret(*inp)
+
     def __repr__(self):
         return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
             ", ".join(str(i) for i in self._inputs),
@@ -541,9 +545,7 @@ def _wrapped_function(orig_func):
     def wrapped_fn(*args, **kwargs):
         if is_tracing_module():
             unset_module_tracing()
-            inputs, tree_def = tree_flatten(
-                (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-            )
+            inputs, tree_def = tree_flatten((args, kwargs))
             for i in inputs:
                 if not NodeMixin.get(i, None):
                     if isinstance(i, (RawTensor, NodeMixin)):
@@ -575,9 +577,7 @@ def _wrapped_function(orig_func):
                 if meth_name == "__setitem__":
                     rst = self
                 if rst is not None:
-                    outputs, out_def = tree_flatten(
-                        rst, leaf_type=_leaf_type, is_leaf=_is_leaf
-                    )
+                    outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                     call_node.out_def = out_def
                 else:
                     outputs = None
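interpret() now doubles as the engine for watch points and end points: values
are appended to `_rst` as they are produced, and once every end-point node has
a value the loop returns early. The new eval() is a small convenience on top of
it. A hedged usage sketch (assumes `tm` is a TracedModule traced from `net`
with input `x`; both calls return the flat list of output values):

    graph = tm.graph

    # interpret() expects one value per graph input, where input 0 is the
    # node of the traced module itself ...
    y = graph.interpret(tm, x)

    # ... which is exactly what eval() prepends via self._inputs[0].owner,
    # hence the len(inputs) == len(self._inputs) - 1 assertion
    y = graph.eval(x)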
@@ -604,13 +604,17 @@ class TracedModuleBuilder(NodeMixin):
         "_NodeMixin__node",
         "_is_builtin",
         "build",
+        "_record_wrapped_nodes",
         "_argdef_graph_map",
         "_argdef_outdef_map",
         "nodes",
+        "__class__",
+        "__dict__",
     ]
 
     def __init__(self, mod, is_top_module=False):
         super(TracedModuleBuilder, self).__init__()
+        assert isinstance(mod, Module)
         self._mod = mod
         self._body = None
         self._is_top = is_top_module
         self._argdef_graph_map = {}
         self._argdef_outdef_map = {}
         self.nodes = set()
+        # The builder is passed to self._mod.forward as the 'self' argument. If
+        # 'forward' uses super().xxx to call a method of one of its base
+        # classes, tracing would raise an exception, because the builder does
+        # not inherit from self._mod.__bases__. Rebind self.__class__ so that
+        # the builder inherits from both TracedModuleBuilder and mod.__class__.
+        self.__class__ = type(
+            "TracedModuleBuilder",
+            (TracedModuleBuilder, mod.__class__),
+            dict(TracedModuleBuilder.__dict__),
+        )
 
     def build(self):
         if self._is_builtin:
@@ -631,8 +642,6 @@
             )
             for _, g in self._argdef_graph_map.items():
                 g.compile()
-            # for node in self.nodes:
-            #     node._owner = weakref.ref(traced_module)
 
             for k, v in self.__dict__.items():
                 if k not in TracedModuleBuilder.__builder_attributes__:
@@ -653,9 +662,7 @@
             if node is None:  # capture as constant
                 NodeMixin.wrap(x, lambda: Constant.make(x))
 
-        inputs, tree_def = tree_flatten(
-            ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inputs, tree_def = tree_flatten(((self, *args), kwargs))
         for i in inputs:
             mark_constant(i)
         callnode = CallMethod.make(NodeMixin.get(self))
@@ -667,7 +674,7 @@
         if self._is_builtin:
             unset_module_tracing()
             rst = self._mod(*args, **kwargs)
-            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             set_module_tracing()
         if self._is_builtin:
             self._body = None
@@ -706,7 +713,7 @@
                 getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
             )
             rst = type(self._mod).forward(*args, **kwargs)
-            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             for i in (
                 outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
             ):
@@ -725,6 +732,12 @@
         self._argdef_outdef_map[callnode.arg_def] = out_def
         return rst
 
+    def __setattr__(self, name, value):
+        object.__setattr__(self, name, value)
+
+    def __repr__(self):
+        return repr(self._mod)
+
     def __getattr__(self, name):
         if name not in self._mod.__dict__:
             attr = getattr(type(self._mod), name).__get__(self, type(self))
@@ -743,11 +756,22 @@ class TracedModuleBuilder(NodeMixin):
 
     def __getattribute__(self, name):
         if name in TracedModuleBuilder.__builder_attributes__:
-            return super().__getattribute__(name)
+            return object.__getattribute__(self, name)
         else:
-            wrapped = super().__getattribute__(name)
+            wrapped = object.__getattribute__(self, name)
             if name in self._mod.__dict__:
-                assert not self._is_builtin
+                mod_attr = getattr(self._mod, name)
+
+                if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
+                    wrapped = mod_attr
+                    setattr(self, name, wrapped)
+
+                if isinstance(mod_attr, Module):
+                    assert mod_attr is wrapped._mod
+                else:
+                    assert mod_attr is wrapped
+
+                # assert not self._is_builtin
                 if isinstance(wrapped, (NodeMixin, RawTensor)):
                     NodeMixin.wrap(
                         wrapped,
                         lambda: GetAttr.make(
                             NodeMixin.get(self),
                             name,
                             type=NodeMixin.get_wrapped_type(wrapped),
                         ),
                     )
             """
             else:
                 node = NodeMixin.get(wrapped)
                 expr = node.expr
                 assert isinstance(expr, GetAttr)
                 if expr not in active_module_tracer().current_scope()._exprs:
                     active_module_tracer().current_scope().insert(expr)
             """
         return wrapped
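The __class__ rebinding in TracedModuleBuilder.__init__ is worth a standalone
illustration. Because the builder is handed to mod.forward as `self`, a
zero-argument super() inside forward would fail unless the builder's class also
inherits from the module's class. A self-contained sketch of the trick
(Base/Child/Builder are toy classes, not megengine code):

    class Base:
        def forward(self):
            return "base"

    class Child(Base):
        def forward(self):
            # zero-arg super() requires isinstance(self, Child)
            return "child->" + super().forward()

    class Builder:
        def __init__(self, mod):
            self._mod = mod
            # re-base the instance: Builder + type(mod), mirroring
            # TracedModuleBuilder.__init__ above
            self.__class__ = type(
                "Builder", (Builder, type(mod)), dict(Builder.__dict__)
            )

    b = Builder(Child())
    print(b.forward())   # "child->base": super() inside Child.forward resolves

Without the rebinding, Child.forward invoked with the builder as `self` would
raise a TypeError from super(), since the plain Builder is not a Child.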
@@ -924,20 +940,57 @@ class TracedModule(Module):
         self.argdef_graph_map = argdef_graph_map
         self.argdef_outdef_map = argdef_outdef_map
         self._is_top = is_top
+        self.watch_points = []
+        self.watch_node_value = {}
+        self.end_points = []
 
     def forward(self, *args, **kwargs):
-        inputs, treedef = tree_flatten(
-            ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inputs, treedef = tree_flatten(((self, *args), kwargs))
         assert treedef in self.argdef_graph_map
         inputs = filter(
             lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
         )  # allow TracedModuleBuilder for retrace.
         outputs = self.argdef_graph_map[treedef].interpret(*inputs)
+        if self.watch_points:
+            self.watch_node_value = {}
+            for n in self.watch_points:
+                self.watch_node_value[n] = n.top_graph._rst.pop(n)
+
+        if self.end_points:
+            return outputs
+
         out_def = self.argdef_outdef_map[treedef]
         outputs = out_def.unflatten(outputs)
 
         return outputs
 
+    def set_watch_points(self, nodes):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        self.watch_points = nodes
+        for n in nodes:
+            n.top_graph._watch_point.append(n)
+
+    def clear_watch_points(self):
+        for n in self.watch_points:
+            n.top_graph._watch_point = []
+        self.watch_points = []
+        self.watch_node_value = {}
+
+    def set_end_points(self, nodes):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        self.end_points = nodes
+        graphs = list(self.argdef_graph_map.values())
+        for n in nodes:
+            assert n.top_graph in graphs
+            n.top_graph._end_point.append(n)
+
+    def clear_end_points(self):
+        for n in self.end_points:
+            n.top_graph._end_point = []
+        self.end_points = []
+
     @property
     def graph(self) -> InternalGraph:
         if self._is_top:
@@ -1014,6 +1067,9 @@ class TracedModule(Module):
                 node2obj[graph._inputs[0]] = module
                 if call:
                     node2obj[call.inputs[0]] = module
+
+                # replace inputs of submodule's exprs
+                if call:
                     repl_dict = dict(zip(graph._inputs, call.inputs))
                     for ind, out in enumerate(graph.outputs):
                         if isinstance(out.expr, Input):
@@ -1028,8 +1084,8 @@ class TracedModule(Module):
                             repl_dict[out] = call.outputs[ind]
 
                     graph._replace_inputs_outputs(repl_dict)
-
                 for expr in graph._exprs:
+
                     if isinstance(expr, GetAttr):
                         # replace GetAttr with Constant
                         if isinstance(expr.outputs[0], TensorNode):
@@ -1129,6 +1185,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     param kwargs: the keyword arguments passed to forward method of ``mod``
     """
     assert active_module_tracer() is None
+    assert isinstance(mod, Module)
     try:
         use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
@@ -1140,9 +1197,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
         builder = TracedModuleBuilder(mod, True)
         NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
-        inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
+        inputs, _ = tree_flatten((args, kwargs))
         for _, i in enumerate(inputs):
-            assert isinstance(i, Tensor), "not support "
+            # assert isinstance(i, Tensor), "not support "
             if isinstance(i, RawTensor):
                 NodeMixin.wrap_safe(
                     i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
-- 
GitLab
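Taken together, the TracedModule additions give a small debugging workflow:
watch points record intermediate values during forward, and end points truncate
interpretation early. A hypothetical session, assuming `tm = trace_module(net, x)`
and that `node` is a TensorNode picked from tm.graph:

    tm.set_watch_points(node)          # single node or a sequence of nodes
    y = tm(x)
    vals = tm.watch_node_value[node]   # value(s) recorded by interpret()
    tm.clear_watch_points()

    tm.set_end_points(node)
    partial = tm(x)                    # the flat end-point values; forward
                                       # returns before out_def.unflatten
    tm.clear_end_points()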