提交 c7e730bc 编写于 作者: M Megvii Engine Team

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: 09691ebd334072f822226125acb11cebdc218618
上级 f88bd3ae
...@@ -13,6 +13,8 @@ from .traced_module import ( ...@@ -13,6 +13,8 @@ from .traced_module import (
cpp_apply_module_trace, cpp_apply_module_trace,
register_as_builtin, register_as_builtin,
trace_module, trace_module,
wrap,
wrap_tensors,
) )
_register_all_builtin_module() _register_all_builtin_module()
......
...@@ -11,7 +11,7 @@ import builtins ...@@ -11,7 +11,7 @@ import builtins
import collections import collections
import copy import copy
import inspect import inspect
from typing import Callable, List from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
...@@ -29,10 +29,24 @@ class Expr: ...@@ -29,10 +29,24 @@ class Expr:
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
""" """
__total_id = 0
inputs = None # type: List[Node] inputs = None # type: List[Node]
outputs = None # type: List[Node] outputs = None # type: List[Node]
const_val = None # type: List[Any] const_val = None # type: List[Any]
arg_def = None # type: TreeDef arg_def = None # type: TreeDef
out_def = None # type: TreeDef
_top_graph = None # type: weakref.ReferenceType
def __init__(self) -> None:
self._id = Expr.__total_id
Expr.__total_id += 1
self._disable_remove = False
def enable_remove(self):
self._disable_remove = False
def disable_remove(self):
self._disable_remove = True
def add_inputs(self, vals): def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence): if not isinstance(vals, collections.abc.Sequence):
...@@ -70,6 +84,22 @@ class Expr: ...@@ -70,6 +84,22 @@ class Expr:
else: else:
return inputs, {} return inputs, {}
def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
repl_node.users.append(self)
node.users.pop(self)
def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)
def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)
@property @property
def kwargs(self): def kwargs(self):
_, kwargs = self.unflatten_args(self.inputs) _, kwargs = self.unflatten_args(self.inputs)
...@@ -80,12 +110,19 @@ class Expr: ...@@ -80,12 +110,19 @@ class Expr:
args, _ = self.unflatten_args(self.inputs) args, _ = self.unflatten_args(self.inputs)
return args return args
@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
class Input(Expr): class Input(Expr):
name = None name = None
def __init__(self, name=None, type=None): def __init__(self, name=None, type=None):
super().__init__()
self.inputs = [] self.inputs = []
node_cls = type if type else Node node_cls = type if type else Node
self.outputs = [ self.outputs = [
...@@ -100,7 +137,7 @@ class Input(Expr): ...@@ -100,7 +137,7 @@ class Input(Expr):
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
return "{} = Input({})".format(self.outputs[0], self.name) return "%{}: {} = Input({})".format(self._id, self.outputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name) # expr: outputs = getattr(inputs[0], self.name)
...@@ -108,6 +145,7 @@ class GetAttr(Expr): ...@@ -108,6 +145,7 @@ class GetAttr(Expr):
name = None name = None
def __init__(self, module, name, type=None): def __init__(self, module, name, type=None):
super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
self.inputs = [ self.inputs = [
module, module,
...@@ -130,14 +168,15 @@ class GetAttr(Expr): ...@@ -130,14 +168,15 @@ class GetAttr(Expr):
return (getattr(inputs[0], self.name),) return (getattr(inputs[0], self.name),)
def __repr__(self): def __repr__(self):
return '{} = GetAttr({}, "{}")'.format( return '%{}: {} = GetAttr({}, "{}")'.format(
self.outputs[0], self.inputs[0], self.name self._id, self.outputs[0], self.inputs[0], self.name
) )
# expr: outputs = inputs[0].__call__(*inputs[1:]) # expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr): class CallMethod(Expr):
def __init__(self, node, method="__call__"): def __init__(self, node, method="__call__"):
super().__init__()
if isinstance(node, type): if isinstance(node, type):
assert issubclass(node, Tensor) assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor cls = Parameter if issubclass(node, Parameter) else Tensor
...@@ -178,6 +217,8 @@ class CallMethod(Expr): ...@@ -178,6 +217,8 @@ class CallMethod(Expr):
if inspect.ismethod(meth): if inspect.ismethod(meth):
args = args[1:] args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs) outputs = getattr(obj, self.method)(*args, **kwargs)
if self.method == "__setitem__":
outputs = obj
if outputs is None: if outputs is None:
return outputs return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor)) outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
...@@ -186,8 +227,12 @@ class CallMethod(Expr): ...@@ -186,8 +227,12 @@ class CallMethod(Expr):
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args[1:]) args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format( outputs = self.outputs
", ".join(str(i) for i in self.outputs), if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}.{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.args[0], self.args[0],
self.method, self.method,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
...@@ -199,6 +244,7 @@ class Apply(Expr): ...@@ -199,6 +244,7 @@ class Apply(Expr):
opdef = None opdef = None
def __init__(self, opdef): def __init__(self, opdef):
super().__init__()
assert isinstance(opdef, OpDef) assert isinstance(opdef, OpDef)
self.opdef = opdef self.opdef = opdef
self.inputs = [] self.inputs = []
...@@ -213,7 +259,8 @@ class Apply(Expr): ...@@ -213,7 +259,8 @@ class Apply(Expr):
return apply(self.opdef, *inputs) return apply(self.opdef, *inputs)
def __repr__(self): def __repr__(self):
return "{} = {}({})".format( return "%{}: {} = {}({})".format(
self._id,
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.opdef, self.opdef,
", ".join(str(i) for i in self.inputs), ", ".join(str(i) for i in self.inputs),
...@@ -241,6 +288,7 @@ class Apply(Expr): ...@@ -241,6 +288,7 @@ class Apply(Expr):
class CallFunction(Expr): class CallFunction(Expr):
def __init__(self, func): def __init__(self, func):
super().__init__()
assert isinstance(func, Callable) assert isinstance(func, Callable)
self.func = func self.func = func
self.const_val = [] self.const_val = []
...@@ -255,16 +303,20 @@ class CallFunction(Expr): ...@@ -255,16 +303,20 @@ class CallFunction(Expr):
def interpret(self, *inputs): def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs) args, kwargs = self.unflatten_args(inputs)
outputs = self.func(*args, **kwargs) outputs = self.func(*args, **kwargs)
outputs = ( if outputs is None:
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) return outputs
) outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs return outputs
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args) args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items()) kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format( outputs = self.outputs
", ".join(str(i) for i in self.outputs), if self.out_def:
outputs = self.out_def.unflatten(outputs)
return "%{}: {}{}({})".format(
self._id,
str(outputs) + " = " if outputs else "",
self.func.__module__ + "." + self.func.__name__, self.func.__module__ + "." + self.func.__name__,
", ".join([args, kwargs]), ", ".join([args, kwargs]),
) )
...@@ -277,6 +329,7 @@ class Constant(Expr): ...@@ -277,6 +329,7 @@ class Constant(Expr):
_constant_cache = {} _constant_cache = {}
def __init__(self, c): def __init__(self, c):
super().__init__()
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) assert module_tracer.is_builtin(c)
...@@ -299,7 +352,9 @@ class Constant(Expr): ...@@ -299,7 +352,9 @@ class Constant(Expr):
return (self.value,) return (self.value,)
def __repr__(self): def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], type(self.value)) return "%{}: {} = Constant({})".format(
self._id, self.outputs[0], type(self.value)
)
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
......
...@@ -30,6 +30,7 @@ class Node: ...@@ -30,6 +30,7 @@ class Node:
__total_id = 0 __total_id = 0
_id = None _id = None
_name = None _name = None
_top_graph = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
self.expr = expr self.expr = expr
...@@ -48,6 +49,12 @@ class Node: ...@@ -48,6 +49,12 @@ class Node:
else: else:
return "%{}".format(self._name) return "%{}".format(self._name)
@property
def top_graph(self):
if self._top_graph:
return self._top_graph()
return None
class ModuleNode(Node): class ModuleNode(Node):
""" """
...@@ -64,21 +71,28 @@ class ModuleNode(Node): ...@@ -64,21 +71,28 @@ class ModuleNode(Node):
def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name) super().__init__(expr, name)
self.actual_mnode = []
def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}({})".format(self._id, self.module_type.__name__) return "%{}_({})".format(self._id, self.module_type.__name__)
else: else:
return "%{}({})".format(self._name, self.module_type.__name__) return "%{}_{}({})".format(self._id, self._name, self.module_type.__name__)
def __getstate__(self): def __getstate__(self):
d = self.__dict__ return {
d.pop("_owner", None) "expr": self.expr,
return d "users": self.users,
"_id": self._id,
"_name": self._name,
"module_type": self.module_type,
}
@property @property
def owner(self): def owner(self):
return self._owner() if self._owner:
return self._owner()
return None
class TensorNode(Node): class TensorNode(Node):
...@@ -91,9 +105,9 @@ class TensorNode(Node): ...@@ -91,9 +105,9 @@ class TensorNode(Node):
def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
return "%{}(Tensor)".format(self._id) return "%{}_(Tensor)".format(self._id)
else: else:
return "%{}(Tensor)".format(self._name) return "%{}_{}(Tensor)".format(self._id, self._name)
class NodeMixin(abc.ABC): class NodeMixin(abc.ABC):
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
from collections import OrderedDict
from typing import Callable, NamedTuple from typing import Callable, NamedTuple
import numpy as np import numpy as np
...@@ -34,9 +35,25 @@ def _dict_unflatten(inps, aux_data): ...@@ -34,9 +35,25 @@ def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps)) return dict(zip(aux_data, inps))
def _ordereddict_flatten(inp):
aux_data = []
results = []
for key, value in inp.items():
results.append(value)
aux_data.append(key)
return results, tuple(aux_data)
def _ordereddict_unflatten(inps, aux_data):
return OrderedDict(zip(aux_data, inps))
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x)) register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x)) register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten) register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
)
register_supported_type( register_supported_type(
slice, slice,
lambda x: ([x.start, x.stop, x.step], None), lambda x: ([x.start, x.stop, x.step], None),
...@@ -99,6 +116,12 @@ class TreeDef: ...@@ -99,6 +116,12 @@ class TreeDef:
) )
) )
def __lt__(self, other):
return self.__hash__() < other.__hash__()
def __gt__(self, other):
return self.__hash__() > other.__hash__()
def __eq__(self, other): def __eq__(self, other):
return ( return (
self.type == other.type self.type == other.type
......
...@@ -57,16 +57,16 @@ def _init_module(): ...@@ -57,16 +57,16 @@ def _init_module():
def test_search(): def test_search():
traced_module, *_ = _init_block() traced_module, *_ = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique() relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
def test_insert(): def test_insert():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_node = graph.get_call_function(F.relu).as_unique().outputs relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
neg_node = graph.insert_call_function(F.neg, relu_node) neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
graph.replace_node({relu_node[0]: neg_node[0]}) graph.replace_node({relu_node[0]: neg_node})
graph.compile() graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
...@@ -74,7 +74,7 @@ def test_insert(): ...@@ -74,7 +74,7 @@ def test_insert():
def test_delete(): def test_delete():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique() relu_expr = graph.get_function_by_type(F.relu).as_unique()
node = relu_expr.outputs node = relu_expr.outputs
repl_node = relu_expr.inputs repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]}) graph.replace_node({node[0]: repl_node[0]})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册