Commit bee305be authored by Megvii Engine Team

feat(traced_module): add functional trace and CallMethod/Function expr

GitOrigin-RevId: ad2cdc1b61aa1dcd309b4ae725fad5efa978cdff
Parent 763c56f3
......@@ -9,12 +9,13 @@
import collections
from typing import List
from typing import Callable, List
from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
......@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode
class Expr:
"""
``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``.
``Expr`` represents the operations (i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
"""
inputs = None # type: List[Node]
outputs = None # type: List[Node]
def add_input(self, node):
self.inputs.append(node)
def add_outputs(self, outputs):
self.outputs = []
if not isinstance(outputs, collections.abc.Sequence):
outputs = (outputs,)
for i in outputs:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
for i, node in zip(outputs, self.outputs):
NodeMixin.wrap_safe(i, node)
@classmethod
def get_args_node(cls, arg):
"""
Create nodes from ``arg``, which may be a container.
Return the same structure as ``arg``.
If ``arg`` is neither a Tensor nor a Module, it is kept as a const.
:param arg: a tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
else:
# TODO: assert arg type
return arg # as const
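# --- illustrative sketch (not part of the patch) ---
# A standalone mock of the structure-preserving mapping ``get_args_node``
# performs: traceable values become nodes, sequences are rebuilt with the same
# container type, and plain Python values are kept as consts. ``FakeTensor``
# and ``FakeNode`` are hypothetical stand-ins, not MegEngine classes.
import collections.abc

class FakeTensor:
    pass

class FakeNode:
    def __init__(self, owner):
        self.owner = owner
    def __repr__(self):
        return "Node<{}>".format(type(self.owner).__name__)

def mock_get_args_node(arg):
    if isinstance(arg, FakeTensor):
        return FakeNode(arg)                       # wrap traceable values
    if isinstance(arg, collections.abc.Sequence) and not isinstance(arg, str):
        return type(arg)(mock_get_args_node(a) for a in arg)
    return arg                                     # everything else stays const

print(mock_get_args_node((FakeTensor(), [FakeTensor(), 3], "pad")))
# -> (Node<FakeTensor>, [Node<FakeTensor>, 3], 'pad')
# --- end sketch ---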
@classmethod
def get_arg_value(cls, inp_node, node2value):
"""
Get values from ``node2value`` by ``inp_node``, which may be a container.
Return the same structure as ``inp_node``.
If ``inp_node`` is not in ``node2value``, it is a const.
:param inp_node: a node or a container of nodes.
:param node2value: a dict mapping nodes to tensors and modules.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
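# --- illustrative sketch (not part of the patch) ---
# The interpretation-time counterpart of ``get_args_node``: nodes recorded in
# an Expr's inputs are resolved back to concrete values through node2value,
# containers are rebuilt element by element, and consts pass through untouched.
# The string keys "n0"/"n1" below are toy node handles, not real Node objects.
import collections.abc

def mock_get_arg_value(inp_node, node2value):
    if inp_node in node2value:
        return node2value[inp_node]
    if isinstance(inp_node, collections.abc.Sequence) and not isinstance(inp_node, str):
        return type(inp_node)(mock_get_arg_value(i, node2value) for i in inp_node)
    return inp_node                                # const

node2value = {"n0": 1.0, "n1": 2.0}
print(mock_get_arg_value(("n0", ("n1", 3), "mean"), node2value))
# -> (1.0, (2.0, 3), 'mean')
# --- end sketch ---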
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
......@@ -83,23 +138,22 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
def __init__(self, module):
assert isinstance(module, ModuleNode)
class CallMethod(Expr):
def __init__(self, module, method="__call__"):
assert isinstance(module, (TensorNode, ModuleNode))
self.inputs = [
module,
]
self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node):
def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)
def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)
for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
if arg_name is not None:
self.arg_names.append(arg_name)
@classmethod
def make(cls, *args, **kwargs):
......@@ -110,15 +164,16 @@ class Call(Expr):
def interpret(self, *inputs):
mod = inputs[0]
args = inputs[1:]
outputs = mod(*args)
outputs = getattr(mod, self.method)(*args, **self.kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
return outputs
def __repr__(self):
return "{} = Call({})({})".format(
return "{} = CallMethod({}, {})({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.method,
", ".join(str(i) for i in self.inputs[1:]),
)
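# --- illustrative sketch (not part of the patch) ---
# How a recorded CallMethod is replayed: ``interpret`` receives the concrete
# receiver plus the positional inputs, re-invokes the stored method name via
# getattr, and forwards the const kwargs captured at trace time. ToyCallMethod
# and Scale are simplified, hypothetical stand-ins.
class ToyCallMethod:
    def __init__(self, method, **const_kwargs):
        self.method = method
        self.kwargs = const_kwargs
    def interpret(self, *inputs):
        mod, args = inputs[0], inputs[1:]
        outputs = getattr(mod, self.method)(*args, **self.kwargs)
        return outputs if isinstance(outputs, tuple) else (outputs,)

class Scale:
    def __init__(self, factor):
        self.factor = factor
    def apply(self, x, bias=0):
        return x * self.factor + bias

expr = ToyCallMethod("apply", bias=1)
print(expr.interpret(Scale(3), 4))   # -> (13,)
# --- end sketch ---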
......@@ -132,17 +187,6 @@ class Apply(Expr):
self.opdef = opdef
self.inputs = []
def add_input(self, node):
self.inputs.append(node)
def add_outputs(self, references):
self.outputs = []
if not isinstance(references, collections.Sequence):
references = (references,)
for i in references:
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
......@@ -179,6 +223,40 @@ class Apply(Expr):
return list(outputs)
class CallFunction(Expr):
def __init__(self, func):
assert isinstance(func, Callable)
self.func = func
self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
return expr
def interpret(self, *inputs):
inp_dict = {name: inp for inp, name in zip(inputs, self.arg_names)}
outputs = self.func(**inp_dict, **self.kwargs)
outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
)
return outputs
def __repr__(self):
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs),
)
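# --- illustrative sketch (not part of the patch) ---
# How a recorded CallFunction is replayed: the positional values received by
# ``interpret`` are paired with the argument names remembered at trace time and
# the call is rebuilt with keyword arguments, merged with the const kwargs.
# ToyCallFunction and ``linear`` are simplified, hypothetical stand-ins.
class ToyCallFunction:
    def __init__(self, func, arg_names, **const_kwargs):
        self.func = func
        self.arg_names = arg_names
        self.kwargs = const_kwargs
    def interpret(self, *inputs):
        inp_dict = {name: inp for inp, name in zip(inputs, self.arg_names)}
        outputs = self.func(**inp_dict, **self.kwargs)
        return outputs if isinstance(outputs, tuple) else (outputs,)

def linear(x, w, b=0):
    return x * w + b

expr = ToyCallFunction(linear, ["x", "w"], b=5)
print(expr.interpret(2, 3))   # -> (11,)
# --- end sketch ---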
# expr outputs = self.value
class Constant(Expr):
value = None
......
......@@ -6,7 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
_active_module_tracer = None
......@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer):
class module_tracer:
# builtin types
_opaque_types = set()
_active_scopes = None
def __init__(self):
def __init__(self, wrap_fn):
self._active_scopes = []
self.patcher = Patcher(wrap_fn)
@classmethod
def register_as_builtin(cls, mod):
......@@ -50,3 +56,105 @@ class module_tracer:
if self._active_scopes:
return self._active_scopes[-1]
return None
class PatchedFn:
frame_dict = None
name = None
origin_fn = None
def __init__(self, frame_dict, name):
self.frame_dict = frame_dict
self.name = name
self.origin_fn = (
self.frame_dict[name]
if isinstance(frame_dict, collections.abc.Mapping)
else getattr(frame_dict, name)
)
def set_func(self, func):
if isinstance(self.frame_dict, collections.abc.Mapping):
self.frame_dict[self.name] = func
else:
setattr(self.frame_dict, self.name, func)
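# --- illustrative sketch (not part of the patch) ---
# What PatchedFn does in isolation: remember the original callable stored in a
# frame dict (a module __dict__, a class, or any mapping) under ``name`` so it
# can be swapped for a wrapper and later restored. The plain ``namespace`` dict
# below is a hypothetical stand-in for a module's __dict__.
import collections.abc
import math

class MiniPatchedFn:
    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        self.origin_fn = (
            frame_dict[name]
            if isinstance(frame_dict, collections.abc.Mapping)
            else getattr(frame_dict, name)
        )
    def set_func(self, func):
        if isinstance(self.frame_dict, collections.abc.Mapping):
            self.frame_dict[self.name] = func
        else:
            setattr(self.frame_dict, self.name, func)

namespace = {"sqrt": math.sqrt}
pf = MiniPatchedFn(namespace, "sqrt")
pf.set_func(lambda x: ("traced", pf.origin_fn(x)))   # install a wrapper
print(namespace["sqrt"](9))   # -> ('traced', 3.0)
pf.set_func(pf.origin_fn)     # restore the original
print(namespace["sqrt"](9))   # -> 3.0
# --- end sketch ---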
class Patcher:
patched_fn_ids = set()
_builtin_functions = []
_builtin_modules = [
F,
F.distributed,
F.elemwise,
F.inplace,
F.loss,
F.math,
F.metric,
F.nn,
F.quantized,
F.tensor,
F.utils,
F.vision,
]
_builtin_methods = [
Tensor,
ArrayMethodMixin,
]
def __init__(self, wrap_fn):
self.patched_fn = []
self.visited_frames_ids = set()
self.wrap_fn = wrap_fn
for module in self._builtin_modules:
self.patch_module(module)
for cls in self._builtin_methods:
self.patch_cls(cls)
for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)
def patch_function(self, frame_dict, fn, wrap_fn):
patched_fn = PatchedFn(frame_dict, fn)
self.patched_fn_ids.add(id(patched_fn.origin_fn))
patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
self.patched_fn.append(patched_fn)
def patch_method(self, cls, name, wrap_fn):
self.patch_function(cls, name, wrap_fn)
def patch_cls(self, cls):
import inspect
if id(cls) not in self.visited_frames_ids:
for k, v in cls.__dict__.items():
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(cls, k, self.wrap_fn)
self.visited_frames_ids.add(id(cls))
def patch_module(self, module):
import inspect
if id(module.__dict__) not in self.visited_frames_ids:
for k, v in module.__dict__.items():
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__))
def auto_patch(self, frame_dict):
if id(frame_dict) not in self.visited_frames_ids:
for k, v in frame_dict.items():
if id(v) in self.patched_fn_ids:
self.patch_function(frame_dict, k, self.wrap_fn)
self.visited_frames_ids.add(id(frame_dict))
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
while self.patched_fn:
pf = self.patched_fn.pop()
pf.set_func(pf.origin_fn)
self.visited_frames_ids.clear()
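# --- illustrative sketch (not part of the patch) ---
# The life cycle the Patcher implements: on construction every public function
# found in the registered namespaces is replaced by wrap_fn(original), and
# __exit__ unwinds the patches in reverse order. MiniPatcher is a heavily
# reduced, hypothetical version operating on one plain dict.
import functools

class MiniPatcher:
    def __init__(self, namespace, wrap_fn):
        self.namespace = namespace
        self.patched = []          # (name, original) pairs, for restoring
        for name, fn in list(namespace.items()):
            if callable(fn) and not name.startswith("_"):
                self.patched.append((name, fn))
                namespace[name] = wrap_fn(fn)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, exc_tb):
        while self.patched:
            name, fn = self.patched.pop()
            self.namespace[name] = fn

def wrap_fn(orig):
    @functools.wraps(orig)
    def wrapped(*args, **kwargs):
        print("traced call:", orig.__name__, args, kwargs)
        return orig(*args, **kwargs)
    return wrapped

ns = {"relu": lambda x: max(x, 0)}
with MiniPatcher(ns, wrap_fn):
    ns["relu"](-2)      # prints "traced call: <lambda> (-2,) {}"
print(ns["relu"](-2))   # back to the unwrapped function -> 0
# --- end sketch ---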
......@@ -34,6 +34,10 @@ class Node:
Node.__total_id += 1
self._name = name
def __setstate__(self, d):
self.__dict__ = d
Node.__total_id = max(Node.__total_id, self._id) + 1
def __repr__(self):
if self._name is None:
return "%{}".format(self._id)
......
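# --- illustrative sketch (not part of the patch) ---
# Why the new __setstate__ bumps the class-wide counter: after unpickling a
# node whose _id was assigned elsewhere, freshly created nodes must keep
# getting unused ids. MiniNode is a reduced, hypothetical model of that
# bookkeeping, not the real Node class.
import pickle

class MiniNode:
    __total_id = 0
    def __init__(self):
        self._id = MiniNode.__total_id
        MiniNode.__total_id += 1
    def __setstate__(self, d):
        self.__dict__ = d
        MiniNode.__total_id = max(MiniNode.__total_id, self._id) + 1

restored = pickle.loads(pickle.dumps(MiniNode()))   # restored._id == 0
fresh = MiniNode()
print(restored._id, fresh._id)   # -> 0 2  (fresh nodes never reuse a restored id)
# --- end sketch ---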
......@@ -8,14 +8,25 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
from typing import List, Type
from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .module_tracer import (
Patcher,
active_module_tracer,
module_tracer,
set_active_module_tracer,
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
......@@ -54,7 +65,9 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
values = expr.interpret(
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)
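# --- illustrative sketch (not part of the patch) ---
# The evaluation loop InternalGraph.interpret follows: seed node2value with the
# graph inputs, run every recorded expr in order, store its outputs, then read
# the graph outputs back. ToyExpr stands in for a recorded Expr and the string
# nodes are purely hypothetical.
class ToyExpr:
    def __init__(self, fn, inputs, outputs):
        self.fn, self.inputs, self.outputs = fn, inputs, outputs
    def interpret(self, *values):
        return (self.fn(*values),)

def run_graph(graph_inputs, exprs, graph_outputs, *values):
    node2value = dict(zip(graph_inputs, values))
    for expr in exprs:
        results = expr.interpret(*(node2value[i] for i in expr.inputs))
        for node, v in zip(expr.outputs, results):
            node2value[node] = v
    return [node2value[o] for o in graph_outputs]

# the toy "graph" computes (x + y) * 2
exprs = [
    ToyExpr(lambda a, b: a + b, ["x", "y"], ["t0"]),
    ToyExpr(lambda a: a * 2, ["t0"], ["out"]),
]
print(run_graph(["x", "y"], exprs, ["out"], 3, 4))   # -> [14]
# --- end sketch ---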
......@@ -67,6 +80,41 @@ class InternalGraph:
)
def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs):
if is_tracing_module():
unset_module_tracing()
const_kwargs = {}
arg_names = orig_func.__code__.co_varnames
if orig_func.__qualname__.split(".").__len__() > 1:
# FIXME: a robust way to distinguish method and function. <XP>
self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
else:
call_node = CallFunction.make(orig_func)
def add_input(inp, varname=None):
node = Expr.get_args_node(inp)
if node is not None:
call_node.add_input(node, varname)
else:
const_kwargs[varname] = inp
for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*inputs, **kwargs)
return wrapped_fn
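# --- illustrative sketch (not part of the patch) ---
# The two introspection tricks _wrapped_function relies on: __qualname__ tells
# a class method apart from a module-level function (the FIXME above notes this
# heuristic is not fully robust), and __code__.co_varnames yields the argument
# names used to label recorded inputs. ``free_fn`` and ``Box`` are hypothetical.
def free_fn(x, y):
    return x + y

class Box:
    def scale(self, x, factor=2):
        return x * factor

def describe(fn):
    kind = "method" if len(fn.__qualname__.split(".")) > 1 else "function"
    arg_names = fn.__code__.co_varnames[: fn.__code__.co_argcount]
    return kind, arg_names

print(describe(free_fn))    # -> ('function', ('x', 'y'))
print(describe(Box.scale))  # -> ('method', ('self', 'x', 'factor'))
# --- end sketch ---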
class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
......@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = Call.make(NodeMixin.get(self))
callnode = CallMethod.make(NodeMixin.get(self))
def add_input(x):
callnode.add_input(NodeMixin.get(x))
......@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin):
)
# prepare args and kwargs for inner graph
def wrap(x):
wrapped = copy.copy(x) # FIXME
# wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
......@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin):
args.append(wrap(i))
for k, v in kwargs.items():
kwargs[k] = wrap(v)
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(self, *args, **kwargs)
for i in (
......@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin):
# rebind output to outer graph
callnode.add_outputs(outputs)
for i, node in zip(
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
callnode.outputs,
):
NodeMixin.wrap_safe(i, node)
return outputs
def __getattr__(self, name):
......@@ -229,6 +275,55 @@ class TracedModule(Module):
rst = rst[0]
return rst
@property
def all_exprs(self):
"""
Collect all ``Expr``s in the graph recursively, with submodule graphs flattened in.

:return: List[Expr]
in_nodes = [i.expr for i in self.m_node.graph._inputs if i is not self]
def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
call.inputs[0] = module
return (call,)
exprs = []
graph = module.m_node.graph
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
expr.inputs[idx] = call.inputs[idx]
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
expr.outputs[idx] = call.outputs[idx]
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
exprs.append(const)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
else:
exprs.append(expr)
else:
exprs.append(expr)
return exprs
return in_nodes + _flatten_submodule(self)
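# --- illustrative sketch (not part of the patch) ---
# Intended usage of the new ``all_exprs`` property: trace a small network and
# walk the flattened expr list, which now interleaves CallMethod, CallFunction,
# Apply and Constant records from submodules as well. The import path of
# trace_module is an assumption (the file paths are not visible in this diff)
# and TinyNet is a hypothetical example network.
import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module  # assumed path

class TinyNet(M.Module):
    def __init__(self):
        super().__init__()
        self.linear = M.Linear(4, 4)
    def forward(self, x):
        return F.relu(self.linear(x))

tm = trace_module(TinyNet(), mge.tensor([[1.0, 2.0, 3.0, 4.0]]))
for expr in tm.all_exprs:
    print(expr)   # e.g. reprs like "%3 = CallMethod(%1, __call__)(%2)"
# --- end sketch ---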
def __getstate__(self):
d = self.__dict__
for k in Module.__dict__:
......@@ -273,9 +368,9 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
assert active_module_tracer() is None
try:
set_module_tracing()
set_active_module_tracer(module_tracer())
set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod)
......