提交 bee305be 编写于 作者: M Megvii Engine Team

feat(traced_module): add functional trace and CallMethod/Function expr

GitOrigin-RevId: ad2cdc1b61aa1dcd309b4ae725fad5efa978cdff
上级 763c56f3
...@@ -9,12 +9,13 @@ ...@@ -9,12 +9,13 @@
import collections import collections
from typing import List from typing import Callable, List
from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .module_tracer import active_module_tracer from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
...@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode ...@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode
class Expr:
    """
    ``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply,
    GetAttr, Input, Constant) on ``Node``.
    """

    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]

    def add_input(self, node):
        """Append ``node`` to the inputs of this expression."""
        self.inputs.append(node)

    def add_outputs(self, outputs):
        """
        Create one output node per value in ``outputs`` and bind each value
        to its node.

        :param outputs: a single value or a sequence of values.
        """
        self.outputs = []
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        # All nodes are created first and bound afterwards, in two passes.
        for i in outputs:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

        for i, node in zip(outputs, self.outputs):
            NodeMixin.wrap_safe(i, node)

    @classmethod
    def get_args_node(cls, arg):
        """
        Create nodes by ``arg``, which may be a container.
        Return the same structure with arg.

        If ``arg`` was not Tensor or Module, it will be stored as const.

        :param arg: tensor, module or const.
        """
        if isinstance(arg, (RawTensor, Module)):
            if not NodeMixin.get(arg, None):
                NodeMixin.wrap_safe(arg, Constant.make(arg))
            return NodeMixin.get(arg)

        # ``str`` is a Sequence whose items are again ``str``; descending into
        # it would never terminate, so text and bytes are treated as consts.
        if isinstance(arg, collections.abc.Sequence) and not isinstance(
            arg, (str, bytes)
        ):
            seq_cls = type(arg)
            return seq_cls([Expr.get_args_node(a) for a in arg])

        # TODO: assert arg type
        return arg  # as const

    @classmethod
    def get_arg_value(cls, inp_node, node2value):
        """
        Get values from node2value by inp_node, which may be a container.
        Return the same structure with inp_node.

        If ``inp_node`` was not in node2value, it is a const.

        :param inp_node: nodes.
        :param node2value: dict from node to tensor and module.
        """
        try:
            if inp_node in node2value:
                return node2value[inp_node]
        except TypeError:
            # Unhashable containers (e.g. a list of nodes) can never be keys
            # of ``node2value``; fall through and resolve their items instead.
            pass

        # Text and bytes are consts, not containers of nodes (see
        # ``get_args_node``).
        if isinstance(inp_node, collections.abc.Sequence) and not isinstance(
            inp_node, (str, bytes)
        ):
            seq_cls = type(inp_node)
            return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])

        return inp_node
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
...@@ -83,23 +138,22 @@ class GetAttr(Expr): ...@@ -83,23 +138,22 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr): class CallMethod(Expr):
def __init__(self, module): def __init__(self, module, method="__call__"):
assert isinstance(module, ModuleNode) assert isinstance(module, (TensorNode, ModuleNode))
self.inputs = [ self.inputs = [
module, module,
] ]
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node): def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node) self.inputs.append(node)
if arg_name is not None:
def add_outputs(self, references): self.arg_names.append(arg_name)
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)
for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
...@@ -110,15 +164,16 @@ class Call(Expr): ...@@ -110,15 +164,16 @@ class Call(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
mod = inputs[0] mod = inputs[0]
args = inputs[1:] args = inputs[1:]
outputs = mod(*args) outputs = getattr(mod, self.method)(*args, **self.kwargs)
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = (outputs,) outputs = (outputs,)
return outputs return outputs
def __repr__(self): def __repr__(self):
return "{} = Call({})({})".format( return "{} = CallMethod({}, {})({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.inputs[0], self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]), ", ".join(str(i) for i in self.inputs[1:]),
) )
...@@ -132,17 +187,6 @@ class Apply(Expr): ...@@ -132,17 +187,6 @@ class Apply(Expr):
self.opdef = opdef self.opdef = opdef
self.inputs = [] self.inputs = []
def add_input(self, node):
self.inputs.append(node)
def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)
for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
...@@ -179,6 +223,40 @@ class Apply(Expr): ...@@ -179,6 +223,40 @@ class Apply(Expr):
return list(outputs) return list(outputs)
class CallFunction(Expr):
    """
    Expression that records a call to a function.

    Input nodes are stored together with the argument name each one is bound
    to; ``kwargs`` holds constant keyword arguments passed along at replay.
    """

    def __init__(self, func):
        assert isinstance(func, Callable)
        self.func = func
        self.inputs = []
        self.arg_names = []
        self.kwargs = {}  # const kwargs

    def add_input(self, node, arg_name):
        """Record ``node`` as an input bound to the argument ``arg_name``."""
        self.inputs.append(node)
        self.arg_names.append(arg_name)

    @classmethod
    def make(cls, *args, **kwargs):
        """Build the expression and insert it into the current tracer scope."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """
        Call ``func`` with ``inputs`` passed by their recorded argument names
        plus the constant kwargs.

        :return: the result itself if it is a sequence, otherwise a 1-tuple.
        """
        named_inputs = dict(zip(self.arg_names, inputs))
        result = self.func(**named_inputs, **self.kwargs)
        if isinstance(result, collections.abc.Sequence):
            return result
        return (result,)

    def __repr__(self):
        out_names = ", ".join(str(i) for i in self.outputs)
        func_path = self.func.__module__ + "." + self.func.__name__
        in_names = ", ".join(str(i) for i in self.inputs)
        return "{} = {}({})".format(out_names, func_path, in_names)
# expr outputs = self.value # expr outputs = self.value
class Constant(Expr): class Constant(Expr):
value = None value = None
......
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ...module import Module
_active_module_tracer = None _active_module_tracer = None
...@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): ...@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer):
class module_tracer: class module_tracer:
# builtin types
_opaque_types = set() _opaque_types = set()
_active_scopes = None _active_scopes = None
def __init__(self): def __init__(self, wrap_fn):
self._active_scopes = [] self._active_scopes = []
self.patcher = Patcher(wrap_fn)
@classmethod @classmethod
def register_as_builtin(cls, mod): def register_as_builtin(cls, mod):
...@@ -50,3 +56,105 @@ class module_tracer: ...@@ -50,3 +56,105 @@ class module_tracer:
if self._active_scopes: if self._active_scopes:
return self._active_scopes[-1] return self._active_scopes[-1]
return None return None
class PatchedFn:
    """
    Remembers the original value stored under ``name`` in ``frame_dict`` so
    that it can be replaced and later put back.

    ``frame_dict`` is either a mapping (accessed by key) or any other object
    (accessed by attribute).
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        # Capture the current value before anything overwrites it.
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = self.frame_dict[name]
        else:
            self.origin_fn = getattr(frame_dict, name)

    def set_func(self, func):
        """Store ``func`` under ``name`` in the wrapped mapping or object."""
        target = self.frame_dict
        if not isinstance(target, collections.abc.Mapping):
            setattr(target, self.name, func)
            return
        target[self.name] = func
class Patcher:
    """
    Replaces functions with ``wrap_fn``-wrapped versions and restores the
    originals when the context manager exits.

    On construction every public function of the listed functional modules
    and every public function defined on the listed classes is patched.
    """

    # ids of the original functions that have been patched; this set lives on
    # the class, so it is shared by all instances.
    patched_fn_ids = set()
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for builtin_module in self._builtin_modules:
            self.patch_module(builtin_module)
        for builtin_cls in self._builtin_methods:
            self.patch_cls(builtin_cls)
        for frame, fn_name in self._builtin_functions:
            if id(frame) in self.visited_frames_ids:
                continue
            self.patch_function(frame, fn_name, self.wrap_fn)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """Replace ``frame_dict[fn]`` (or attribute ``fn``) by its wrapped version."""
        record = PatchedFn(frame_dict, fn)
        original = record.origin_fn
        self.patched_fn_ids.add(id(original))
        record.set_func(wrap_fn(original))
        # Kept so that ``__exit__`` can restore the original.
        self.patched_fn.append(record)

    def patch_method(self, cls, name, wrap_fn):
        """Patch the attribute ``name`` of ``cls``."""
        self.patch_function(cls, name, wrap_fn)

    def patch_cls(self, cls):
        """Patch every public function found in ``cls.__dict__``, once per class."""
        import inspect

        cls_id = id(cls)
        if cls_id in self.visited_frames_ids:
            return
        for k, v in cls.__dict__.items():
            if inspect.isfunction(v) and not k.startswith("_"):
                self.patch_function(cls, k, self.wrap_fn)
        self.visited_frames_ids.add(cls_id)

    def patch_module(self, module):
        """Patch every public function found in ``module.__dict__``, once per module."""
        import inspect

        namespace = module.__dict__
        namespace_id = id(namespace)
        if namespace_id in self.visited_frames_ids:
            return
        for k, v in namespace.items():
            if inspect.isfunction(v) and not k.startswith("_"):
                self.patch_function(namespace, k, self.wrap_fn)
        self.visited_frames_ids.add(namespace_id)

    def auto_patch(self, frame_dict):
        """
        Patch the entries of ``frame_dict`` whose value is a function that
        has already been patched elsewhere (matched by ``id``).
        """
        frame_id = id(frame_dict)
        if frame_id in self.visited_frames_ids:
            return
        for k, v in frame_dict.items():
            if id(v) in self.patched_fn_ids:
                self.patch_function(frame_dict, k, self.wrap_fn)
        self.visited_frames_ids.add(frame_id)

    def __enter__(self):
        return self

    def __exit__(self, type, vlaue, trace):
        # Undo the patches in reverse order of application.
        while self.patched_fn:
            record = self.patched_fn.pop()
            record.set_func(record.origin_fn)
        self.visited_frames_ids.clear()
...@@ -34,6 +34,10 @@ class Node: ...@@ -34,6 +34,10 @@ class Node:
Node.__total_id += 1 Node.__total_id += 1
self._name = name self._name = name
def __setstate__(self, d):
    """
    Restore the instance dictionary from ``d`` and move the class-level id
    counter past both its current value and the restored ``_id``.
    """
    self.__dict__ = d
    highest_seen = max(Node.__total_id, self._id)
    Node.__total_id = highest_seen + 1
def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}".format(self._id) return "%{}".format(self._id)
......
...@@ -8,14 +8,25 @@ ...@@ -8,14 +8,25 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import copy import copy
import functools
from typing import List, Type from typing import List, Type
from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer from .module_tracer import (
Patcher,
active_module_tracer,
module_tracer,
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
...@@ -54,7 +65,9 @@ class InternalGraph: ...@@ -54,7 +65,9 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs): for n, v in zip(self._inputs, inputs):
node2value[n] = v node2value[n] = v
for expr in self._exprs: for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
for n, v in zip(expr.outputs, values): for n, v in zip(expr.outputs, values):
node2value[n] = v node2value[n] = v
return list(node2value[i] for i in self._outputs) return list(node2value[i] for i in self._outputs)
...@@ -67,6 +80,41 @@ class InternalGraph: ...@@ -67,6 +80,41 @@ class InternalGraph:
) )
def _wrapped_function(orig_func):
    """
    Wrap ``orig_func`` so that calls made while module tracing is active are
    recorded as a ``CallMethod`` or ``CallFunction`` expression.

    When tracing is not active the wrapper simply forwards to ``orig_func``.

    :param orig_func: the function or method to wrap.
    :return: the wrapping function.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*inputs, **kwargs):
        if not is_tracing_module():
            return orig_func(*inputs, **kwargs)

        # Tracing is switched off while the call is recorded and executed.
        # The ``finally`` clause switches it back on even when recording or
        # ``orig_func`` raises, so an exception no longer leaves tracing
        # disabled.
        unset_module_tracing()
        try:
            const_kwargs = {}
            arg_names = orig_func.__code__.co_varnames
            if "." in orig_func.__qualname__:
                # FIXME: a robust way to distinguish method and function. <XP>
                self = inputs[0]
                call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
            else:
                call_node = CallFunction.make(orig_func)

            def add_input(inp, varname=None):
                node = Expr.get_args_node(inp)
                if node is not None:
                    call_node.add_input(node, varname)
                else:
                    const_kwargs[varname] = inp

            # Positional inputs are named after the variable at the same
            # position in the code object.
            for ind, inp in enumerate(inputs):
                add_input(inp, arg_names[ind])
            for k, v in kwargs.items():
                add_input(v, k)
            call_node.kwargs = const_kwargs

            outputs = orig_func(*inputs, **kwargs)
            call_node.add_outputs(outputs)
            return outputs
        finally:
            set_module_tracing()

    return wrapped_fn
class TracedModuleBuilder(NodeMixin): class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
...@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i) mark_constant(i)
for k, v in kwargs.items(): for k, v in kwargs.items():
mark_constant(v) mark_constant(v)
callnode = Call.make(NodeMixin.get(self)) callnode = CallMethod.make(NodeMixin.get(self))
def add_input(x): def add_input(x):
callnode.add_input(NodeMixin.get(x)) callnode.add_input(NodeMixin.get(x))
...@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): ...@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin):
) )
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): def wrap(x):
wrapped = copy.copy(x) # FIXME # wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
NodeMixin.wrap( NodeMixin.wrap(
wrapped, wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
...@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): ...@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin):
args.append(wrap(i)) args.append(wrap(i))
for k, v in kwargs.items(): for k, v in kwargs.items():
kwargs[k] = wrap(v) kwargs[k] = wrap(v)
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs) outputs = type(self._mod).forward(self, *args, **kwargs)
for i in ( for i in (
...@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): ...@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin):
# rebind output to outer graph # rebind output to outer graph
callnode.add_outputs(outputs) callnode.add_outputs(outputs)
for i, node in zip(
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
callnode.outputs,
):
NodeMixin.wrap_safe(i, node)
return outputs return outputs
def __getattr__(self, name): def __getattr__(self, name):
...@@ -229,6 +275,55 @@ class TracedModule(Module): ...@@ -229,6 +275,55 @@ class TracedModule(Module):
rst = rst[0] rst = rst[0]
return rst return rst
@property
def all_exprs(self):
    """
    Visit all ``Expr``s in the graph recursively.

    Note that the visit rewrites the inputs and outputs of sub-module
    expressions in place so that they refer to the nodes of the calling
    expression.

    :return: List[Expr]
    """
    input_exprs = [
        node.expr for node in self.m_node.graph._inputs if node is not self
    ]

    def _flatten_submodule(module, call=None):
        # A module that is not traced stays a single call; its first input
        # is replaced by the module object itself.
        if not isinstance(module, TracedModule):
            call.inputs[0] = module
            return (call,)

        graph = module.m_node.graph
        flattened = []
        for expr in graph._exprs:
            # replace inputs for submodule's expr
            for idx, inp in enumerate(expr.inputs):
                if call and inp in graph._inputs:
                    expr.inputs[idx] = call.inputs[idx]
            # replace outputs for submodule's expr
            for idx, outp in enumerate(expr.outputs):
                if call and outp in graph._outputs:
                    expr.outputs[idx] = call.outputs[idx]

            if isinstance(expr, GetAttr):
                # replace GetAttr with Constant; a GetAttr whose first
                # output is not a TensorNode is left out of the result
                if isinstance(expr.outputs[0], TensorNode):
                    const = Constant(getattr(module, expr.name))
                    const.outputs = expr.outputs
                    flattened.append(const)
                continue

            if isinstance(expr, CallMethod) and isinstance(
                expr.inputs[0], ModuleNode
            ):
                # a call on a module: resolve the module and expand it
                (obj,) = expr.inputs[0].expr.interpret(module)
                flattened.extend(_flatten_submodule(obj, expr))
                continue

            flattened.append(expr)
        return flattened

    return input_exprs + _flatten_submodule(self)
def __getstate__(self): def __getstate__(self):
d = self.__dict__ d = self.__dict__
for k in Module.__dict__: for k in Module.__dict__:
...@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule ...@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
assert active_module_tracer() is None assert active_module_tracer() is None
try: try:
set_module_tracing() set_module_tracing()
set_active_module_tracer(module_tracer()) set_active_module_tracer(module_tracer(_wrapped_function))
global_scope = InternalGraph() with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod) builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_)))
for k, v in kwargs.items(): for k, v in kwargs.items():
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
builder(*inputs, **kwargs) builder(*inputs, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
return builder.build() return builder.build()
finally: finally:
set_active_module_tracer(None) set_active_module_tracer(None)
unset_module_tracing() unset_module_tracing()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册