提交 15712807 编写于 作者: M Megvii Engine Team

feat(traced_module): add name to Node

GitOrigin-RevId: 39c28090678d0da23c313d594103405896e872ec
上级 e918f0aa
...@@ -11,6 +11,7 @@ import builtins ...@@ -11,6 +11,7 @@ import builtins
import collections import collections
import copy import copy
import inspect import inspect
import re
from typing import Callable, Dict, List from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
...@@ -21,7 +22,24 @@ from ...module import Module ...@@ -21,7 +22,24 @@ from ...module import Module
from ...tensor import Parameter, Tensor from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef, tree_flatten from .pytree import ArgsIndex, TreeDef, tree_flatten
def rstrip(s: str, __chars: str):
    """Strip every trailing occurrence of the substring ``__chars`` from ``s``.

    Unlike :meth:`str.rstrip`, which removes a *set of characters*, this
    removes the whole substring repeatedly from the end, e.g.
    ``rstrip("x_out_out", "_out")`` -> ``"x"``.
    """
    __chars = re.escape(__chars)
    # NOTE: the replacement must be a raw string — "\g<left>" is an invalid
    # escape sequence (DeprecationWarning today, SyntaxError in the future).
    s = re.sub(r"^(?P<left>.*?)(?:%s)+$" % __chars, r"\g<left>", s)
    return s
def lstrip(s: str, __chars: str):
    """Strip every leading occurrence of the substring ``__chars`` from ``s``.

    Unlike :meth:`str.lstrip`, which removes a *set of characters*, this
    removes the whole substring repeatedly from the front, e.g.
    ``lstrip("_out_out_x", "_out")`` -> ``"_x"``.
    """
    __chars = re.escape(__chars)
    # NOTE: the replacement must be a raw string — "\g<right>" is an invalid
    # escape sequence (DeprecationWarning today, SyntaxError in the future).
    s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, r"\g<right>", s)
    return s
def strip(s: str, __chars: str):
    """Strip every leading and trailing occurrence of the substring
    ``__chars`` from ``s`` (substring-wise, unlike :meth:`str.strip`)."""
    return lstrip(rstrip(s, __chars), __chars)
class Expr: class Expr:
...@@ -67,9 +85,29 @@ class Expr: ...@@ -67,9 +85,29 @@ class Expr:
if not isinstance(outputs, collections.Sequence): if not isinstance(outputs, collections.Sequence):
outputs = (outputs,) outputs = (outputs,)
name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
assert name is not None
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
else:
strip_method = strip(self.method, "_")
name = "%s_out" % strip_method
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"
for i in outputs: for i in outputs:
assert isinstance(i, RawTensor) assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
)
for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node) NodeMixin.wrap_safe(i, node)
...@@ -133,11 +171,16 @@ class Input(Expr): ...@@ -133,11 +171,16 @@ class Input(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope().add_input(expr.outputs[0]) oup_node = expr.outputs[0]
name = (
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node)
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name) return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
...@@ -154,22 +197,31 @@ class GetAttr(Expr): ...@@ -154,22 +197,31 @@ class GetAttr(Expr):
self.name = name self.name = name
node_cls = type if type else Node node_cls = type if type else Node
self.outputs = [ self.outputs = [
node_cls(self), node_cls(self, name=name),
] ]
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
module = expr.inputs[0]
oup_name = expr.name
while module._name != "self":
oup_name = module._name + "_" + oup_name
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope().insert(expr)
expr.outputs[0]._name = expr.name
return expr.outputs[0] return expr.outputs[0]
def interpret(self, *inputs): def interpret(self, *inputs):
return (getattr(inputs[0], self.name),) return (getattr(inputs[0], self.name),)
def __repr__(self): def __repr__(self):
return '%{}: {} = GetAttr({}, "{}")'.format( out_type = "Tensor"
self._id, self.outputs[0], self.inputs[0], self.name if isinstance(self.outputs[0], ModuleNode):
out_type = self.outputs[0].module_type.__name__
return '%{}:\t{} = getattr({}, "{}") -> ({})'.format(
self._id, self.outputs[0], self.inputs[0], self.name, out_type
) )
...@@ -230,11 +282,14 @@ class CallMethod(Expr): ...@@ -230,11 +282,14 @@ class CallMethod(Expr):
outputs = self.outputs outputs = self.outputs
if self.out_def: if self.out_def:
outputs = self.out_def.unflatten(outputs) outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format( method = ".%s" % self.method
if method == ".__call__":
method = ""
return "%{}:\t{}{}{}({})".format(
self._id, self._id,
str(outputs) + " = " if outputs else "", str(outputs) + " = " if outputs else "",
self.args[0], self.args[0],
self.method, method,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
...@@ -259,7 +314,7 @@ class Apply(Expr): ...@@ -259,7 +314,7 @@ class Apply(Expr):
return apply(self.opdef, *inputs) return apply(self.opdef, *inputs)
def __repr__(self): def __repr__(self):
return "%{}: {} = {}({})".format( return "%{}:\t{} = {}({})".format(
self._id, self._id,
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.opdef, self.opdef,
...@@ -314,10 +369,10 @@ class CallFunction(Expr): ...@@ -314,10 +369,10 @@ class CallFunction(Expr):
outputs = self.outputs outputs = self.outputs
if self.out_def: if self.out_def:
outputs = self.out_def.unflatten(outputs) outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format( return "%{}:\t{}{}({})".format(
self._id, self._id,
str(outputs) + " = " if outputs else "", str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__, self.func.__module__.rsplit(".")[-1] + "." + self.func.__name__,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
...@@ -328,21 +383,25 @@ class Constant(Expr): ...@@ -328,21 +383,25 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model # TODO: constant cache to reduce the size of dumped model
_constant_cache = {} _constant_cache = {}
def __init__(self, c): def __init__(self, c, name=None):
super().__init__() super().__init__()
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) assert module_tracer.is_builtin(c)
self.value = c self.value = c
self.name = name
self.inputs = [] self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c) node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [ self.outputs = [
node_cls(self), node_cls(self, name=name),
] ]
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name)
expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope().insert(expr)
return expr.outputs[0] return expr.outputs[0]
...@@ -352,8 +411,14 @@ class Constant(Expr): ...@@ -352,8 +411,14 @@ class Constant(Expr):
return (self.value,) return (self.value,)
def __repr__(self): def __repr__(self):
return "%{}: {} = Constant({})".format( name = self.name
self._id, self.outputs[0], type(self.value) if name is None:
name = type(self.value)
node_type = "Module"
if isinstance(self.outputs[0], TensorNode):
node_type = "Tensor"
return "%{}:\t{} = Constant({}) -> ({})".format(
self._id, self.outputs[0], name, node_type
) )
def __getstate__(self): def __getstate__(self):
......
...@@ -28,8 +28,9 @@ class Node: ...@@ -28,8 +28,9 @@ class Node:
expr = None expr = None
__total_id = 0 __total_id = 0
_id = None _id = None
_name = None
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None
_format_spec = ""
def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
self.expr = expr self.expr = expr
...@@ -43,10 +44,35 @@ class Node: ...@@ -43,10 +44,35 @@ class Node:
Node.__total_id = max(Node.__total_id, self._id) + 1 Node.__total_id = max(Node.__total_id, self._id) + 1
def __repr__(self): def __repr__(self):
if self._name is None: format_spec = Node._format_spec
return "%{}".format(self._id) return self.__format__(format_spec)
def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None:
format_spec = Node._format_spec
name = self._name
if name is None:
name = ""
if format_spec in ["i", "p", "ip", "pi"]:
if "p" in format_spec:
graph = self.top_graph
prefix_name = ""
if graph is not None:
prefix_name = graph._name
if graph._prefix_name:
prefix_name = "{}_{}".format(
graph._prefix_name, prefix_name.lstrip("_")
)
if name:
name = "_" + name.lstrip("_")
name = "{}{}".format(prefix_name, name)
if "i" in format_spec:
if name:
name = "_" + name.lstrip("_")
name = "%{}{}".format(self._id, name)
return name
else: else:
return "%{}".format(self._name) return name if name else ("%d" % self._id)
@property @property
def top_graph(self): def top_graph(self):
...@@ -54,6 +80,12 @@ class Node: ...@@ -54,6 +80,12 @@ class Node:
return self._top_graph() return self._top_graph()
return None return None
@classmethod
def set_format_spec(cls, str):
old_format_spec = cls._format_spec
cls._format_spec = str
return old_format_spec
class ModuleNode(Node): class ModuleNode(Node):
""" """
...@@ -72,12 +104,6 @@ class ModuleNode(Node): ...@@ -72,12 +104,6 @@ class ModuleNode(Node):
super().__init__(expr, name) super().__init__(expr, name)
self.actual_mnode = [] self.actual_mnode = []
def __repr__(self):
if self._name is None:
return "%{}_({})".format(self._id, self.module_type.__name__)
else:
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
def __getstate__(self): def __getstate__(self):
return { return {
"expr": self.expr, "expr": self.expr,
...@@ -104,12 +130,6 @@ class TensorNode(Node): ...@@ -104,12 +130,6 @@ class TensorNode(Node):
qparam = None qparam = None
device = None device = None
def __repr__(self):
if self._name is None:
return "%{}_(Tensor)".format(self._id)
else:
return "%{}_{}(Tensor)".format(self._id, self._name)
def __getstate__(self): def __getstate__(self):
return { return {
"expr": self.expr, "expr": self.expr,
...@@ -119,6 +139,7 @@ class TensorNode(Node): ...@@ -119,6 +139,7 @@ class TensorNode(Node):
"shape": self.shape, "shape": self.shape,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,
"_name": self._name,
} }
......
...@@ -22,6 +22,16 @@ from ...quantization.utils import LSQParams, QParams, QuantMode ...@@ -22,6 +22,16 @@ from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor from ...tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
class ArgsIndex:
    """Placeholder leaf used during pytree flattening.

    Records the positional ``index`` of a flattened argument and a display
    ``name`` for it; ``repr`` shows only the name so that formatted call
    signatures stay readable.
    """

    def __init__(self, index=0, name="") -> None:
        # position of the argument in the flattened input list
        self.index = index
        # human-readable name used by __repr__
        self.name = name

    def __repr__(self) -> str:
        return self.name
SUPPORTED_TYPE = {} SUPPORTED_TYPE = {}
# if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree # if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree
...@@ -39,6 +49,7 @@ SUPPORTED_LEAF_TYPE = { ...@@ -39,6 +49,7 @@ SUPPORTED_LEAF_TYPE = {
type(None), type(None),
type(Ellipsis), type(Ellipsis),
QuantMode, QuantMode,
ArgsIndex,
} }
# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
...@@ -121,11 +132,11 @@ def _is_leaf(obj): ...@@ -121,11 +132,11 @@ def _is_leaf(obj):
def _leaf_type(node): def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)): if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode) return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)): elif isinstance(node, (NodeMixin, Module)):
return (Module, ModuleNode, NodeMixin) return (Module, ModuleNode, NodeMixin, ArgsIndex)
else: else:
return type(node) return (type(node), ArgsIndex)
def _is_const_leaf(node): def _is_const_leaf(node):
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections import collections
import copy import copy
import fnmatch
import functools import functools
import inspect import keyword
import re
import weakref import weakref
from inspect import getmembers, isclass, ismethod from inspect import getcallargs, getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from ... import functional as F from ... import functional as F
...@@ -41,11 +44,19 @@ from .module_tracer import ( ...@@ -41,11 +44,19 @@ from .module_tracer import (
set_active_module_tracer, set_active_module_tracer,
) )
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten from .pytree import ArgsIndex, tree_flatten
logger = get_logger(__name__) logger = get_logger(__name__)
def _is_builtin_name(name: str) -> bool:
return (
name in builtins.__dict__
or name in keyword.kwlist
or name in {"inf", "nan", "NoneType"}
)
def _is_leaf(node): def _is_leaf(node):
assert isinstance(node, RawTensor), "doesn't support {} in return values".format( assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node) type(node)
...@@ -67,6 +78,7 @@ class _InsertExprs: ...@@ -67,6 +78,7 @@ class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True): def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
self.graph = graph self.graph = graph
self.global_scope = InternalGraph() self.global_scope = InternalGraph()
self.global_scope._used_names.update(graph._used_names)
self.expr = expr self.expr = expr
self.after = after self.after = after
...@@ -91,6 +103,7 @@ class _InsertExprs: ...@@ -91,6 +103,7 @@ class _InsertExprs:
for expr in self.global_scope._exprs: for expr in self.global_scope._exprs:
self.graph._exprs.insert(index, expr) self.graph._exprs.insert(index, expr)
index += 1 index += 1
self.graph._used_names.update(self.global_scope._used_names)
class InternalGraph: class InternalGraph:
...@@ -107,17 +120,37 @@ class InternalGraph: ...@@ -107,17 +120,37 @@ class InternalGraph:
_inputs = None # type: List[Node] _inputs = None # type: List[Node]
_outputs = None # type: List[Node] _outputs = None # type: List[Node]
def __init__(self): def __init__(self, name: str = None, prefix_name: str = ""):
self._exprs = [] self._exprs = []
self._inputs = [] self._inputs = []
self._outputs = [] self._outputs = []
self._watch_point = [] self._watch_point = []
self._end_point = [] self._end_point = []
self._used_names = {}
self._rst = collections.defaultdict(list) self._rst = collections.defaultdict(list)
self._name = name
self._prefix_name = prefix_name
def insert(self, expr): def insert(self, expr):
self._exprs.append(expr) self._exprs.append(expr)
def _create_unique_name(self, name: str) -> str:
assert isinstance(name, str)
name = re.sub("[^0-9a-zA-Z_]+", "_", name)
if name[0].isdigit():
name = "_{}".format(name)
while name in self._used_names or _is_builtin_name(name):
match = re.match(r"(.*)_(\d+)$", name)
if match is None:
name = name + "_1"
else:
base, num = match.group(1, 2)
name = "{}_{}".format(base, int(num) + 1)
self._used_names.setdefault(name)
return name
@property @property
def inputs(self): def inputs(self):
return self._inputs return self._inputs
...@@ -150,13 +183,16 @@ class InternalGraph: ...@@ -150,13 +183,16 @@ class InternalGraph:
def get_node_by_id(self, node_id: List[int] = None): def get_node_by_id(self, node_id: List[int] = None):
return self.node_filter.node_id(node_id) return self.node_filter.node_id(node_id)
def get_node_by_name(self, name: str = None, ignorecase: bool = True):
return self.node_filter.name(name, ignorecase)
def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)
def add_output(self, o): def add_output(self, o):
self._outputs.append(o) self._outputs.append(o)
def _replace_inputs_outputs(self, repl_dict): def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):
for node, repl_node in repl_dict.items(): for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs assert node in self._inputs or node in self._outputs
...@@ -175,13 +211,29 @@ class InternalGraph: ...@@ -175,13 +211,29 @@ class InternalGraph:
for expr in self._exprs: for expr in self._exprs:
for idx, i in enumerate(expr.inputs): for idx, i in enumerate(expr.inputs):
assert i._name is not None
if i in repl_dict: if i in repl_dict:
expr.inputs[idx] = repl_dict[i] expr.inputs[idx] = repl_dict[i]
elif isinstance(i, TensorNode) and prefix_name not in i._name:
if i.top_graph != active_module_tracer().current_scope():
i._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + i._name.lstrip("_"))
)
for idx, o in enumerate(expr.outputs): for idx, o in enumerate(expr.outputs):
assert o._name is not None
if o in repl_dict: if o in repl_dict:
expr.outputs[idx] = repl_dict[o] expr.outputs[idx] = repl_dict[o]
expr.outputs[idx].expr = expr expr.outputs[idx].expr = expr
elif isinstance(o, TensorNode) and prefix_name not in i._name:
if o.top_graph != active_module_tracer().current_scope():
o._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name + o._name.lstrip("_"))
)
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence): if not isinstance(nodes, Sequence):
...@@ -258,7 +310,7 @@ class InternalGraph: ...@@ -258,7 +310,7 @@ class InternalGraph:
# return formal_node_inputs[1:], actual_nodes # return formal_node_inputs[1:], actual_nodes
return formal_node_inputs[1:] return formal_node_inputs[1:]
def add_input_node(self, shape, dtype="float32"): def add_input_node(self, shape, dtype="float32", name="args"):
forma_mnode = self.inputs[0] forma_mnode = self.inputs[0]
actual_mnodes = forma_mnode.actual_mnode actual_mnodes = forma_mnode.actual_mnode
...@@ -271,11 +323,11 @@ class InternalGraph: ...@@ -271,11 +323,11 @@ class InternalGraph:
if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
call_nodes.append(c_expr) call_nodes.append(c_expr)
def create_node(is_input: bool = True): def create_node(name=None, is_input: bool = True):
if is_input: if is_input:
node = Input(type=TensorNode).outputs[0] node = Input(type=TensorNode, name=name).outputs[0]
else: else:
node = TensorNode(expr=None) node = TensorNode(expr=None, name=None)
node.shape = shape node.shape = shape
node.dtype = dtype node.dtype = dtype
return node return node
...@@ -286,7 +338,7 @@ class InternalGraph: ...@@ -286,7 +338,7 @@ class InternalGraph:
org_argdef = call_nodes[0].arg_def org_argdef = call_nodes[0].arg_def
args, kwargs = org_argdef.unflatten(self._inputs) args, kwargs = org_argdef.unflatten(self._inputs)
formal_inp_node = create_node(True) formal_inp_node = create_node(self._create_unique_name(name), True)
inputs, tree_def = tree_flatten( inputs, tree_def = tree_flatten(
((*args, formal_inp_node), kwargs), ((*args, formal_inp_node), kwargs),
is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
...@@ -524,11 +576,21 @@ class InternalGraph: ...@@ -524,11 +576,21 @@ class InternalGraph:
return self.interpret(*inp) return self.interpret(*inp)
def __repr__(self): def __repr__(self):
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( return self.__format__()
def __format__(self, format_spec: str = "") -> str:
saved_format_spec = Node.set_format_spec(format_spec)
name = ""
if self._name:
name = "%s.Graph" % self._name
res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
name,
", ".join(str(i) for i in self._inputs), ", ".join(str(i) for i in self._inputs),
"\n\t".join("{}".format(str(i)) for i in self._exprs), "\n\t".join("{}".format(str(i)) for i in self._exprs),
", ".join(str(i) for i in self._outputs), ", ".join(str(i) for i in self._outputs),
) )
Node.set_format_spec(saved_format_spec)
return res
def _get_meth_name(obj, func): def _get_meth_name(obj, func):
...@@ -621,6 +683,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -621,6 +683,7 @@ class TracedModuleBuilder(NodeMixin):
self._is_builtin = module_tracer.is_builtin(mod) self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {} self._argdef_graph_map = {}
self._argdef_outdef_map = {} self._argdef_outdef_map = {}
self.nodes = set() self.nodes = set()
# The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__. # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__.
# modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
...@@ -631,7 +694,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -631,7 +694,7 @@ class TracedModuleBuilder(NodeMixin):
) )
def build(self): def build(self):
if self._is_builtin: if self._is_builtin or isinstance(self._mod, TracedModule):
for node in self.nodes: for node in self.nodes:
node.module_type = type(self._mod) node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod) # node._owner = weakref.ref(self._mod)
...@@ -671,21 +734,38 @@ class TracedModuleBuilder(NodeMixin): ...@@ -671,21 +734,38 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def callnode.arg_def = tree_def
if self._is_builtin: if (
self._is_builtin
or tree_def in self._argdef_graph_map
or isinstance(self._mod, TracedModule)
):
unset_module_tracing() unset_module_tracing()
rst = self._mod(*args, **kwargs) rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
set_module_tracing() set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
elif tree_def in self._argdef_graph_map:
self._body = self._argdef_graph_map[tree_def]
else:
self._mod._is_top = False
self._body = self._mod.graph
name = NodeMixin.get(self)._name
if name:
self._body._name = name
else: else:
self_node = None self_node = None
if tree_def in self._argdef_graph_map: orig_self = NodeMixin.get(self)
self_node = self._argdef_graph_map[tree_def].inputs[0] top_graph = active_module_tracer().current_scope()
self._body = InternalGraph() graph_prefix_name = top_graph._name
if top_graph._prefix_name:
graph_prefix_name = "{}_{}".format(
top_graph._prefix_name, graph_prefix_name.lstrip("_")
)
self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
active_module_tracer().push_scope(self._body) active_module_tracer().push_scope(self._body)
# rebind self to new input node # rebind self to new input node
orig_self = NodeMixin.get(self)
if self_node: if self_node:
NodeMixin.wrap_safe(self, self_node) NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node) active_module_tracer().current_scope().add_input(self_node)
...@@ -698,16 +778,37 @@ class TracedModuleBuilder(NodeMixin): ...@@ -698,16 +778,37 @@ class TracedModuleBuilder(NodeMixin):
) )
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): index_args, index_kwargs = tree_def.unflatten(
[
ArgsIndex(0),
*list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
]
)
key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
idx2key = {}
for k, v in key2idx.items():
if isinstance(v, ArgsIndex):
idx2key[v.index] = k
else:
flatten_argidx, _ = tree_flatten(v)
for _i, v in enumerate(flatten_argidx):
if isinstance(v, ArgsIndex):
idx2key[v.index] = k + "_%d" % _i
def wrap(x, name):
if isinstance(x, (RawTensor, NodeMixin)): if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap( NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)), x,
lambda: Input.make(
type=NodeMixin.get_wrapped_type(x), name=name
),
) )
return x return x
args = [self] args = [self]
for i in inputs[1:]: for i, v in enumerate(inputs[1:]):
args.append(wrap(i)) args.append(wrap(v, idx2key[i + 1]))
args, kwargs = tree_def.unflatten(args) args, kwargs = tree_def.unflatten(args)
active_module_tracer().patcher.auto_patch( active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
...@@ -857,6 +958,9 @@ class NodeFilter(BaseFilter): ...@@ -857,6 +958,9 @@ class NodeFilter(BaseFilter):
def node_id(self, node_id: List[int]): def node_id(self, node_id: List[int]):
return NodeFilterNodeId(self, node_id) return NodeFilterNodeId(self, node_id)
def name(self, name: str, ignorecase: bool = True):
return NodeFilterName(self, name, ignorecase)
class NodeFilterType(NodeFilter): class NodeFilterType(NodeFilter):
def __init__(self, expr_iter, owner_type, node_type): def __init__(self, expr_iter, owner_type, node_type):
...@@ -887,6 +991,33 @@ class NodeFilterNodeId(NodeFilter): ...@@ -887,6 +991,33 @@ class NodeFilterNodeId(NodeFilter):
yield node yield node
class NodeFilterName(NodeFilter):
    """Filter that yields only nodes whose qualified name matches a
    glob-style ``pattern`` (exact match also accepted)."""

    _re = None  # compiled regex translated from the glob pattern

    def __init__(self, node_iter, pattern, ignorecase):
        super().__init__(node_iter)
        self.pattern = pattern
        self._re = self.make_re(pattern, ignorecase)

    @classmethod
    def make_re(cls, pattern, ignorecase=True):
        """Translate a glob ``pattern`` into a compiled regular expression."""
        assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
        assert isinstance(ignorecase, bool)
        flags = re.IGNORECASE if ignorecase else 0
        return re.compile(fnmatch.translate(pattern), flags=flags)

    def __iter__(self):
        for node in self._iter:
            graph = node.top_graph
            # build the qualified name: [prefix_]graphname_nodename,
            # trimming leading underscores between the joined parts
            qual_name = "{}_{}".format(graph._name, node._name.lstrip("_"))
            if graph._prefix_name:
                qual_name = "{}_{}".format(
                    graph._prefix_name, qual_name.lstrip("_")
                )
            if self.pattern == qual_name or self._re.match(qual_name):
                yield node
class ExprFilterCallFunction(ExprFilter): class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None): def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter) super().__init__(expr_iter)
...@@ -1052,12 +1183,29 @@ class TracedModule(Module): ...@@ -1052,12 +1183,29 @@ class TracedModule(Module):
:return: :class:`TracedModule` :return: :class:`TracedModule`
""" """
new_module = copy.deepcopy(self) new_module = copy.deepcopy(self)
module2name = {}
def _flatten_subgraph(graph, module, call=None): assert active_module_tracer() is None
set_active_module_tracer(module_tracer(lambda x: x))
active_module_tracer().push_scope(new_module.graph)
for n, m in new_module.named_modules():
module2name[id(m)] = n
def _flatten_subgraph(
graph: InternalGraph, module: Module, call=None, prefix_name=""
):
if graph is not None and prefix_name and prefix_name[-1] != "_":
prefix_name += "_"
if graph is None: if graph is None:
assert not isinstance(module, TracedModule) assert not isinstance(module, TracedModule)
const = Constant(module) const = Constant(module, "self.%s" % module2name[id(module)])
const.outputs[0] = call.inputs[0] m_node = call.inputs[0]
if m_node.top_graph != active_module_tracer().current_scope():
m_node._name = (
active_module_tracer()
.current_scope()
._create_unique_name(prefix_name)
)
const.outputs[0] = m_node
const.outputs[0].expr = const const.outputs[0].expr = const
return [const, call] return [const, call]
if call is not None: if call is not None:
...@@ -1083,7 +1231,7 @@ class TracedModule(Module): ...@@ -1083,7 +1231,7 @@ class TracedModule(Module):
continue continue
repl_dict[out] = call.outputs[ind] repl_dict[out] = call.outputs[ind]
graph._replace_inputs_outputs(repl_dict) graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)
for expr in graph._exprs: for expr in graph._exprs:
if isinstance(expr, GetAttr): if isinstance(expr, GetAttr):
...@@ -1109,7 +1257,14 @@ class TracedModule(Module): ...@@ -1109,7 +1257,14 @@ class TracedModule(Module):
if hasattr(obj, "argdef_graph_map") if hasattr(obj, "argdef_graph_map")
else None else None
) )
exprs.extend(_flatten_subgraph(expr_graph, obj, expr)) exprs.extend(
_flatten_subgraph(
expr_graph,
obj,
expr,
prefix_name + obj_node._name.lstrip("_"),
)
)
else: else:
# module has been replaced. # module has been replaced.
assert isinstance(pre_expr, Constant) assert isinstance(pre_expr, Constant)
...@@ -1126,7 +1281,18 @@ class TracedModule(Module): ...@@ -1126,7 +1281,18 @@ class TracedModule(Module):
return exprs return exprs
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
new_module.graph.compile()
set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs):
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
return new_module return new_module
def __getstate__(self): def __getstate__(self):
...@@ -1149,19 +1315,7 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: ...@@ -1149,19 +1315,7 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
module_tracer.register_as_builtin(mod_cls) module_tracer.register_as_builtin(mod_cls)
def wrap(func: Union[Callable]): wrap = _wrapped_function
assert callable(func)
if hasattr(func, "__code__"):
assert not isinstance(func, str)
fn_name = func.__code__.co_name
currentframe = inspect.currentframe()
assert currentframe is not None
f = currentframe.f_back
assert f is not None
if f.f_code.co_name != "<module>":
raise NotImplementedError("wrap must be called at the top level of a module")
Patcher._builtin_functions.append((f.f_globals, fn_name))
return func
def _register_all_builtin_module(): def _register_all_builtin_module():
...@@ -1192,11 +1346,11 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ...@@ -1192,11 +1346,11 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer(module_tracer(_wrapped_function)) set_active_module_tracer(module_tracer(_wrapped_function))
with active_module_tracer().patcher: with active_module_tracer().patcher:
global_scope = InternalGraph() global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope) active_module_tracer().push_scope(global_scope)
builder = TracedModuleBuilder(mod, True) builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) name = mod._name if mod._name else mod.__class__.__name__
NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support " # assert isinstance(i, Tensor), "not support "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册