From bee305beb2d852a1d4b75aa09e18c665ce77fa6e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 30 Jun 2021 19:27:45 +0800 Subject: [PATCH] feat(traced_module): add functional trace and CallMethod/Function expr GitOrigin-RevId: ad2cdc1b61aa1dcd309b4ae725fad5efa978cdff --- .../experimental/traced_module/expr.py | 132 ++++++++++++---- .../traced_module/module_tracer.py | 110 ++++++++++++- .../experimental/traced_module/node.py | 4 + .../traced_module/traced_module.py | 145 +++++++++++++++--- 4 files changed, 338 insertions(+), 53 deletions(-) diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index ad9ed301..ec07af73 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -9,12 +9,13 @@ import collections -from typing import List +from typing import Callable, List from ...core._imperative_rt import OpDef from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing from ...core.ops.special import Const +from ...module import Module from ...tensor import Tensor from .module_tracer import active_module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode @@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode class Expr: """ - ``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``. + ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. """ inputs = None # type: List[Node] outputs = None # type: List[Node] + def add_input(self, node): + self.inputs.append(node) + + def add_outputs(self, outputs): + self.outputs = [] + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + for i in outputs: + self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) + + for i, node in zip(outputs, self.outputs,): + NodeMixin.wrap_safe(i, node) + + @classmethod + def get_args_node(cls, arg): + """ + Create nodes by ``arg``, which may be a container. + Return the same structure with arg. + + If ``arg`` was not Tensor or Module, it will be stored as const. + + :param arg: tensor, module or const. + """ + if isinstance(arg, (RawTensor, Module)): + if not NodeMixin.get(arg, None): + NodeMixin.wrap_safe(arg, Constant.make(arg)) + return NodeMixin.get(arg) + elif isinstance(arg, collections.abc.Sequence): + seq_cls = type(arg) + return seq_cls([Expr.get_args_node(a) for a in arg]) + else: + # TODO: assert arg type + return arg # as const + + @classmethod + def get_arg_value(cls, inp_node, node2value): + """ + Get values from node2value by inp_node, which may be a container. + Return the same structure with inp_node. + + If ``inp_node`` was not in node2value, it is a const. + + :param inp_node: nodes. + :param node2value: dict from node to tensor and module. + """ + if inp_node in node2value: + return node2value[inp_node] + elif isinstance(inp_node, collections.abc.Sequence): + seq_cls = type(inp_node) + return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node]) + else: + return inp_node + # expr: None (i.e. fake expression which is used to mark input) class Input(Expr): @@ -83,23 +138,22 @@ class GetAttr(Expr): # expr: outputs = inputs[0].__call__(*inputs[1:]) -class Call(Expr): - def __init__(self, module): - assert isinstance(module, ModuleNode) +class CallMethod(Expr): + def __init__(self, module, method="__call__"): + assert isinstance(module, (TensorNode, ModuleNode)) self.inputs = [ module, ] + self.method = method + self.arg_names = [] + self.kwargs = {} # const kwargs - def add_input(self, node): + def add_input(self, node, arg_name=None): + if arg_name == "self": # FIXME: + return self.inputs.append(node) - - def add_outputs(self, references): - self.outputs = [] - if not isinstance(references, collections.Sequence): - references = (references,) - - for i in references: - self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) + if arg_name is not None: + self.arg_names.append(arg_name) @classmethod def make(cls, *args, **kwargs): @@ -110,15 +164,16 @@ class Call(Expr): def interpret(self, *inputs): mod = inputs[0] args = inputs[1:] - outputs = mod(*args) + outputs = getattr(mod, self.method)(*args, **self.kwargs) if isinstance(outputs, RawTensor): outputs = (outputs,) return outputs def __repr__(self): - return "{} = Call({})({})".format( + return "{} = CallMethod({}, {})({})".format( ", ".join(str(i) for i in self.outputs), self.inputs[0], + self.method, ", ".join(str(i) for i in self.inputs[1:]), ) @@ -132,17 +187,6 @@ class Apply(Expr): self.opdef = opdef self.inputs = [] - def add_input(self, node): - self.inputs.append(node) - - def add_outputs(self, references): - self.outputs = [] - if not isinstance(references, collections.Sequence): - references = (references,) - - for i in references: - self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) - @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) @@ -179,6 +223,40 @@ class Apply(Expr): return list(outputs) +class CallFunction(Expr): + def __init__(self, func): + assert isinstance(func, Callable) + self.func = func + self.inputs = [] + self.arg_names = [] + self.kwargs = {} # const kwargs + + def add_input(self, node, arg_name): + self.inputs.append(node) + self.arg_names.append(arg_name) + + @classmethod + def make(cls, *args, **kwargs): + expr = cls(*args, **kwargs) + active_module_tracer().current_scope().insert(expr) + return expr + + def interpret(self, *inputs): + inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) + outputs = self.func(**inp_dict, **self.kwargs) + outputs = ( + outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) + ) + return outputs + + def __repr__(self): + return "{} = {}({})".format( + ", ".join(str(i) for i in self.outputs), + self.func.__module__ + "." + self.func.__name__, + ", ".join(str(i) for i in self.inputs), + ) + + # expr outputs = self.value class Constant(Expr): value = None diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py index 0a0b2807..f8eb67cf 100644 --- a/imperative/python/megengine/experimental/traced_module/module_tracer.py +++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py @@ -6,7 +6,11 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections +from ... import Tensor +from ... import functional as F +from ...core.tensor.array_method import ArrayMethodMixin from ...module import Module _active_module_tracer = None @@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): class module_tracer: + # builtin types _opaque_types = set() _active_scopes = None - def __init__(self): + def __init__(self, wrap_fn): self._active_scopes = [] + self.patcher = Patcher(wrap_fn) @classmethod def register_as_builtin(cls, mod): @@ -50,3 +56,105 @@ class module_tracer: if self._active_scopes: return self._active_scopes[-1] return None + + +class PatchedFn: + frame_dict = None + name = None + origin_fn = None + + def __init__(self, frame_dict, name): + self.frame_dict = frame_dict + self.name = name + self.origin_fn = ( + self.frame_dict[name] + if isinstance(frame_dict, collections.abc.Mapping) + else getattr(frame_dict, name) + ) + + def set_func(self, func): + if isinstance(self.frame_dict, collections.abc.Mapping): + self.frame_dict[self.name] = func + else: + setattr(self.frame_dict, self.name, func) + + +class Patcher: + + patched_fn_ids = set() + _builtin_functions = [] + _builtin_modules = [ + F, + F.distributed, + F.elemwise, + F.inplace, + F.loss, + F.math, + F.metric, + F.nn, + F.quantized, + F.tensor, + F.utils, + F.vision, + ] + _builtin_methods = [ + Tensor, + ArrayMethodMixin, + ] + + def __init__(self, wrap_fn): + self.patched_fn = [] + self.visited_frames_ids = set() + self.wrap_fn = wrap_fn + for module in self._builtin_modules: + self.patch_module(module) + + for cls in self._builtin_methods: + self.patch_cls(cls) + + for i, j in self._builtin_functions: + if id(i) not in self.visited_frames_ids: + self.patch_function(i, j, self.wrap_fn) + + def patch_function(self, frame_dict, fn, wrap_fn): + patched_fn = PatchedFn(frame_dict, fn) + self.patched_fn_ids.add(id(patched_fn.origin_fn)) + patched_fn.set_func(wrap_fn(patched_fn.origin_fn)) + self.patched_fn.append(patched_fn) + + def patch_method(self, cls, name, wrap_fn): + self.patch_function(cls, name, wrap_fn) + + def patch_cls(self, cls): + import inspect + + if id(cls) not in self.visited_frames_ids: + for k, v in cls.__dict__.items(): + if inspect.isfunction(v) and not k.startswith("_"): + self.patch_function(cls, k, self.wrap_fn) + self.visited_frames_ids.add(id(cls)) + + def patch_module(self, module): + import inspect + + if id(module.__dict__) not in self.visited_frames_ids: + for k, v in module.__dict__.items(): + if inspect.isfunction(v) and not k.startswith("_"): + self.patch_function(module.__dict__, k, self.wrap_fn) + self.visited_frames_ids.add(id(module.__dict__)) + + def auto_patch(self, frame_dict): + if id(frame_dict) not in self.visited_frames_ids: + for k, v in frame_dict.items(): + if id(v) in self.patched_fn_ids: + self.patch_function(frame_dict, k, self.wrap_fn) + self.visited_frames_ids.add(id(frame_dict)) + + def __enter__(self): + return self + + def __exit__(self, type, vlaue, trace): + while self.patched_fn: + pf = self.patched_fn.pop() + pf.set_func(pf.origin_fn) + self.visited_frames_ids.clear() diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index 6e6b5a9a..e45aa281 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -34,6 +34,10 @@ class Node: Node.__total_id += 1 self._name = name + def __setstate__(self, d): + self.__dict__ = d + Node.__total_id = max(Node.__total_id, self._id) + 1 + def __repr__(self): if self._name is None: return "%{}".format(self._id) diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index 880c6404..9dd76035 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -8,14 +8,25 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections import copy +import functools from typing import List, Type from ... import module as M -from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing +from ...core._imperative_rt.core2 import ( + is_tracing_module, + set_module_tracing, + unset_module_tracing, +) +from ...core.tensor.array_method import ArrayMethodMixin from ...module import Module from ...tensor import Tensor -from .expr import Apply, Call, Constant, Expr, GetAttr, Input -from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer +from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input +from .module_tracer import ( + Patcher, + active_module_tracer, + module_tracer, + set_active_module_tracer, +) from .node import ModuleNode, Node, NodeMixin, TensorNode @@ -54,7 +65,9 @@ class InternalGraph: for n, v in zip(self._inputs, inputs): node2value[n] = v for expr in self._exprs: - values = expr.interpret(*list(node2value[i] for i in expr.inputs)) + values = expr.interpret( + *list(Expr.get_arg_value(i, node2value) for i in expr.inputs) + ) for n, v in zip(expr.outputs, values): node2value[n] = v return list(node2value[i] for i in self._outputs) @@ -67,6 +80,41 @@ class InternalGraph: ) +def _wrapped_function(orig_func): + @functools.wraps(orig_func) + def wrapped_fn(*inputs, **kwargs): + if is_tracing_module(): + unset_module_tracing() + const_kwargs = {} + arg_names = orig_func.__code__.co_varnames + if orig_func.__qualname__.split(".").__len__() > 1: + # FIXME: a robust way to distinguish method and function. + self = inputs[0] + call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) + else: + call_node = CallFunction.make(orig_func) + + def add_input(inp, varname=None): + node = Expr.get_args_node(inp) + if node is not None: + call_node.add_input(node, varname) + else: + const_kwargs[varname] = inp + + for ind, inp in enumerate(inputs): + add_input(inp, arg_names[ind]) + for k, v in kwargs.items(): + add_input(v, k) + call_node.kwargs = const_kwargs + outputs = orig_func(*inputs, **kwargs) + call_node.add_outputs(outputs) + set_module_tracing() + return outputs + return orig_func(*inputs, **kwargs) + + return wrapped_fn + + class TracedModuleBuilder(NodeMixin): _mod = None # type: Module @@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): mark_constant(i) for k, v in kwargs.items(): mark_constant(v) - callnode = Call.make(NodeMixin.get(self)) + callnode = CallMethod.make(NodeMixin.get(self)) def add_input(x): callnode.add_input(NodeMixin.get(x)) @@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): ) # prepare args and kwargs for inner graph def wrap(x): - wrapped = copy.copy(x) # FIXME + # wrapped = copy.copy(x) # FIXME + wrapped = x # FIXME: NodeMixin.wrap( wrapped, lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), @@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): args.append(wrap(i)) for k, v in kwargs.items(): kwargs[k] = wrap(v) - + active_module_tracer().patcher.auto_patch( + getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) + ) outputs = type(self._mod).forward(self, *args, **kwargs) for i in ( @@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): # rebind output to outer graph callnode.add_outputs(outputs) - for i, node in zip( - outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), - callnode.outputs, - ): - NodeMixin.wrap_safe(i, node) return outputs def __getattr__(self, name): @@ -229,6 +275,55 @@ class TracedModule(Module): rst = rst[0] return rst + @property + def all_exprs(self): + """ + Visit all ``Expr``s in the graph recursively. + + :return: List[Expr] + """ + + in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] + + def _flatten_submodule(module, call=None): + if not isinstance(module, TracedModule): + call.inputs[0] = module + return (call,) + + exprs = [] + + graph = module.m_node.graph + for expr in graph._exprs: + + # replace inputs for submodule's expr + for idx, inp in enumerate(expr.inputs): + if call and inp in graph._inputs: + expr.inputs[idx] = call.inputs[idx] + # replace outputs for submodule's expr + for idx, outp in enumerate(expr.outputs): + if call and outp in graph._outputs: + expr.outputs[idx] = call.outputs[idx] + + if isinstance(expr, GetAttr): + # replace GetAttr with Constant + if isinstance(expr.outputs[0], TensorNode): + const = Constant(getattr(module, expr.name)) + const.outputs = expr.outputs + exprs.append(const) + elif isinstance(expr, CallMethod): + obj_node = expr.inputs[0] + if isinstance(obj_node, ModuleNode): + (obj,) = expr.inputs[0].expr.interpret(module) + exprs.extend(_flatten_submodule(obj, expr)) + else: + exprs.append(expr) + else: + exprs.append(expr) + + return exprs + + return in_nodes + _flatten_submodule(self) + def __getstate__(self): d = self.__dict__ for k in Module.__dict__: @@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule assert active_module_tracer() is None try: set_module_tracing() - set_active_module_tracer(module_tracer()) - global_scope = InternalGraph() - - active_module_tracer().push_scope(global_scope) + set_active_module_tracer(module_tracer(_wrapped_function)) + with active_module_tracer().patcher: + global_scope = InternalGraph() + active_module_tracer().push_scope(global_scope) - builder = TracedModuleBuilder(mod) - NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) + builder = TracedModuleBuilder(mod) + NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) - for _, i in enumerate(inputs): - NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) - for k, v in kwargs.items(): - NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) + for _, i in enumerate(inputs): + NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) + for k, v in kwargs.items(): + NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) - builder(*inputs, **kwargs) - active_module_tracer().pop_scope() + builder(*inputs, **kwargs) + active_module_tracer().pop_scope() - return builder.build() + return builder.build() finally: set_active_module_tracer(None) unset_module_tracing() -- GitLab