Commit c7e730bc authored by Megvii Engine Team

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: 09691ebd334072f822226125acb11cebdc218618
Parent f88bd3ae
@@ -13,6 +13,8 @@ from .traced_module import (
cpp_apply_module_trace,
register_as_builtin,
trace_module,
wrap,
wrap_tensors,
)
_register_all_builtin_module()
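`wrap` and `wrap_tensors` are now re-exported at the package level. A hedged sketch of the intended use, assuming `wrap` registers a plain function so the tracer records a single `CallFunction` node instead of tracing into its body (the decorator form and this behavior are assumptions, not confirmed by this diff):

```python
import megengine.functional as F
from megengine.traced_module import wrap

@wrap
def scaled_relu(x):
    # assumed: recorded as one CallFunction node during trace_module
    return F.relu(x) * 0.5
```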
@@ -11,7 +11,7 @@ import builtins
import collections
import copy
import inspect
from typing import Callable, List
from typing import Any, Callable, Dict, List
from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -29,10 +29,24 @@ class Expr:
    ``Expr`` represents an operation (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, ``GetAttr``, ``Input``, ``Constant``) on ``Node``s.
"""
__total_id = 0
inputs = None # type: List[Node]
outputs = None # type: List[Node]
const_val = None # type: List[Any]
arg_def = None # type: TreeDef
out_def = None # type: TreeDef
_top_graph = None # type: weakref.ReferenceType
def __init__(self) -> None:
self._id = Expr.__total_id
Expr.__total_id += 1
self._disable_remove = False
def enable_remove(self):
self._disable_remove = False
def disable_remove(self):
self._disable_remove = True
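`enable_remove`/`disable_remove` toggle a per-`Expr` guard that graph passes can consult before dropping an expression. A minimal sketch of such a pass; the pass itself is hypothetical, only the `_disable_remove` flag comes from this commit:

```python
def sweep_dead_exprs(exprs):
    # Keep an Expr if any output still has consumers, or if it was
    # pinned via disable_remove(); drop everything else.
    return [
        e for e in exprs
        if e._disable_remove or any(out.users for out in e.outputs)
    ]
```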
def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence):
@@ -70,6 +84,22 @@ class Expr:
else:
return inputs, {}
def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
repl_node.users.append(self)
            node.users.remove(self)
def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)
def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)
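`replace_inputs`/`replace_outputs` rewire an `Expr` through an `{old_node: new_node}` mapping, with `_replace_nodes` keeping the `users` bookkeeping of both nodes consistent. A usage sketch, assuming `expr`, `old_in`, and `new_in` come from an already-traced graph:

```python
# old_in must currently appear in expr.inputs, and both nodes must be
# of the same Node subclass (enforced by the asserts above).
expr.replace_inputs({old_in: new_in})

assert new_in in expr.inputs and expr in new_in.users
assert expr not in old_in.users  # the old node forgets this consumer
```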
@property
def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs)
@@ -80,12 +110,19 @@
args, _ = self.unflatten_args(self.inputs)
return args
@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None
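`_top_graph` is stored as a weak reference so that an `Expr` (and, below, a `Node`) never keeps its owning graph alive; `top_graph` dereferences it and degrades to `None` once the graph is collected. The pattern in isolation, with stand-in classes so it runs on its own:

```python
import weakref

class Graph:  # stand-in for the traced graph object
    pass

class Holder:  # mimics the top_graph property above
    _top_graph = None

    @property
    def top_graph(self):
        return self._top_graph() if self._top_graph else None

g, h = Graph(), Holder()
h._top_graph = weakref.ref(g)
assert h.top_graph is g
del g                       # drop the last strong reference
assert h.top_graph is None  # the weakref is now dead
```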
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
name = None
def __init__(self, name=None, type=None):
super().__init__()
self.inputs = []
node_cls = type if type else Node
self.outputs = [
@@ -100,7 +137,7 @@ class Input(Expr):
return expr.outputs[0]
def __repr__(self):
return "{} = Input({})".format(self.outputs[0], self.name)
return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name)
@@ -108,6 +145,7 @@ class GetAttr(Expr):
name = None
def __init__(self, module, name, type=None):
super().__init__()
assert isinstance(module, ModuleNode)
self.inputs = [
module,
@@ -130,14 +168,15 @@
return (getattr(inputs[0], self.name),)
def __repr__(self):
-        return '{} = GetAttr({}, "{}")'.format(
-            self.outputs[0], self.inputs[0], self.name
+        return '%{}: {} = GetAttr({}, "{}")'.format(
+            self._id, self.outputs[0], self.inputs[0], self.name
)
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
def __init__(self, node, method="__call__"):
super().__init__()
if isinstance(node, type):
assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor
@@ -178,6 +217,8 @@
if inspect.ismethod(meth):
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if self.method == "__setitem__":
outputs = obj
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
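The `__setitem__` special case exists because in-place item assignment returns `None`, which would leave the recorded call without an output; using the mutated object itself keeps the dataflow chain intact. The underlying Python behavior, illustrated with NumPy:

```python
import numpy as np

buf = np.zeros(4)
ret = buf.__setitem__(slice(0, 2), 1.0)  # same as buf[0:2] = 1.0
assert ret is None                       # nothing to record as an output...
outputs = buf                            # ...so the mutated object is used
```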
@@ -186,8 +227,12 @@
def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.args[0],
self.method,
", ".join([args, kwargs]),
@@ -199,6 +244,7 @@ class Apply(Expr):
opdef = None
def __init__(self, opdef):
super().__init__()
assert isinstance(opdef, OpDef)
self.opdef = opdef
self.inputs = []
@@ -213,7 +259,8 @@
return apply(self.opdef, *inputs)
def __repr__(self):
return "{} = {}({})".format(
return "%{}: {} = {}({})".format(
self._id,
", ".join(str(i) for i in self.outputs),
self.opdef,
", ".join(str(i) for i in self.inputs),
@@ -241,6 +288,7 @@
class CallFunction(Expr):
def __init__(self, func):
super().__init__()
assert isinstance(func, Callable)
self.func = func
self.const_val = []
@@ -255,16 +303,20 @@
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs)
-        outputs = (
-            outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
-        )
+        if outputs is None:
+            return outputs
+        outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs
def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs),
outputs = self.outputs
if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__,
", ".join([args, kwargs]),
)
@@ -277,6 +329,7 @@ class Constant(Expr):
_constant_cache = {}
def __init__(self, c):
super().__init__()
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
assert module_tracer.is_builtin(c)
@@ -299,7 +352,9 @@
return (self.value,)
def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], type(self.value))
return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
)
def __getstate__(self):
state = self.__dict__.copy()
@@ -30,6 +30,7 @@ class Node:
__total_id = 0
_id = None
_name = None
_top_graph = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None):
self.expr = expr
@@ -48,6 +49,12 @@
else:
return "%{}".format(self._name)
@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None
class ModuleNode(Node):
"""
@@ -64,21 +71,28 @@ class ModuleNode(Node):
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.actual_mnode = []
def __repr__(self):
if self._name is None:
return "%{}({})".format(self._id, self.module_type.__name__)
return "%{}_({})".format(self._id, self.module_type.__name__)
else:
return "%{}({})".format(self._name, self.module_type.__name__)
return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
def __getstate__(self):
-        d = self.__dict__
-        d.pop("_owner", None)
-        return d
+        return {
+            "expr": self.expr,
+            "users": self.users,
+            "_id": self._id,
+            "_name": self._name,
+            "module_type": self.module_type,
+        }
@property
def owner(self):
-        return self._owner()
+        if self._owner:
+            return self._owner()
+        return None
class TensorNode(Node):
@@ -91,9 +105,9 @@
def __repr__(self):
if self._name is None:
return "%{}(Tensor)".format(self._id)
return "%{}_(Tensor)".format(self._id)
else:
return "%{}(Tensor)".format(self._name)
return "%{}_{}(Tensor)".format(self._id, self._name)
class NodeMixin(abc.ABC):
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from collections import OrderedDict
from typing import Callable, NamedTuple
import numpy as np
@@ -34,9 +35,25 @@ def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps))
def _ordereddict_flatten(inp):
aux_data = []
results = []
for key, value in inp.items():
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)
def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
)
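With `OrderedDict` registered, traced arguments of that type round-trip through flatten/unflatten with key order preserved. A quick sanity check of the two helpers above:

```python
from collections import OrderedDict

inp = OrderedDict([("weight", 1), ("bias", 2)])
leaves, keys = _ordereddict_flatten(inp)
assert leaves == [1, 2] and keys == ("weight", "bias")
assert _ordereddict_unflatten(leaves, keys) == inp  # order preserved
```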
register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),
@@ -99,6 +116,12 @@ class TreeDef:
)
)
def __lt__(self, other):
return self.__hash__() < other.__hash__()
def __gt__(self, other):
return self.__hash__() > other.__hash__()
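`__lt__`/`__gt__` compare `TreeDef`s by hash, which yields a deterministic (though not semantically meaningful) ordering, e.g. for canonically sorting collections of tree structures. A sketch with hypothetical instances:

```python
defs = [tree_def_a, tree_def_b, tree_def_c]  # hypothetical TreeDef objects
defs.sort()                                  # dispatches to TreeDef.__lt__
assert all(hash(a) <= hash(b) for a, b in zip(defs, defs[1:]))
```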
def __eq__(self, other):
return (
self.type == other.type
@@ -57,16 +57,16 @@ def _init_module():
def test_search():
traced_module, *_ = _init_block()
graph = traced_module.graph
-    relu_expr = graph.get_call_function(F.relu).as_unique()
+    relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
-    relu_node = graph.get_call_function(F.relu).as_unique().outputs
-    neg_node = graph.insert_call_function(F.neg, relu_node)
-    graph.replace_node({relu_node[0]: neg_node[0]})
+    relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
+    neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
+    graph.replace_node({relu_node[0]: neg_node})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
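The renamed API makes `test_insert` read as the canonical three-step rewrite recipe: locate an expression by function type, insert a new function on its outputs, then splice the replacement node in and recompile:

```python
# 1. locate: as_unique() collapses the match set to a single Expr
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs
# 2. insert: wrap the located outputs in a new function node
neg_out = graph.insert_function(lambda x: F.neg(x), *relu_out)
# 3. splice, then compile() to finalize the modified graph
graph.replace_node({relu_out[0]: neg_out})
graph.compile()
```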
@@ -74,7 +74,7 @@ def test_insert():
def test_delete():
traced_module, x, expect = _init_block()
graph = traced_module.graph
-    relu_expr = graph.get_call_function(F.relu).as_unique()
+    relu_expr = graph.get_function_by_type(F.relu).as_unique()
node = relu_expr.outputs
repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]})