提交 a3f9073c 编写于 作者: M Megvii Engine Team

feat(traced_module): update graph transform and add _module_name

GitOrigin-RevId: ef63ae0fd0dcdd69c3566e19f8a34d85422a1e1e
上级 b3d0affa
...@@ -14,7 +14,6 @@ from .traced_module import ( ...@@ -14,7 +14,6 @@ from .traced_module import (
register_as_builtin, register_as_builtin,
trace_module, trace_module,
wrap, wrap,
wrap_tensors,
) )
_register_all_builtin_module() _register_all_builtin_module()
......
...@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str): ...@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
return s return s
def lstrip(s: str, __chars: str):
__chars = re.escape(__chars)
s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
return s
def strip(s: str, __chars: str):
s = lstrip(rstrip(s, __chars), __chars)
return s
class Expr: class Expr:
""" """
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
...@@ -89,27 +78,40 @@ class Expr: ...@@ -89,27 +78,40 @@ class Expr:
outputs = (outputs,) outputs = (outputs,)
name = None name = None
orig_name = None
if isinstance(self, CallMethod): if isinstance(self, CallMethod):
name = self.inputs[0]._name name = self.inputs[0]._name
assert name is not None orig_name = self.inputs[0]._orig_name
assert isinstance(name, str), "The name of ({}) must be a str".format(
self.inputs[0]
)
assert isinstance(
orig_name, str
), "The orig_name of ({}) must be a str".format(self.inputs[0])
name = rstrip(name, "_out") name = rstrip(name, "_out")
if self.method == "__call__": if self.method == "__call__":
name += "_out" name += "_out"
orig_name += "_out"
else: else:
strip_method = strip(self.method, "_") strip_method = self.method.strip("_")
name = "%s_out" % strip_method name = "%s_out" % strip_method
orig_name = name
elif isinstance(self, CallFunction): elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out" name = self.func.__name__ + "_out"
elif isinstance(self, Apply): elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out" name = str(self.opdef).lower() + "_out"
for i in outputs: for i in outputs:
assert isinstance(i, RawTensor) assert isinstance(i, RawTensor), "The output must be a Tensor"
o_name = ( o_name = (
active_module_tracer().current_scope()._create_unique_name(name) active_module_tracer().current_scope()._create_unique_name(name)
) )
self.outputs.append( self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name) NodeMixin.get_wrapped_type(i)(
expr=self,
name=o_name,
orig_name=orig_name if orig_name else o_name,
)
) )
for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
...@@ -125,21 +127,26 @@ class Expr: ...@@ -125,21 +127,26 @@ class Expr:
else: else:
return inputs, {} return inputs, {}
def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]): def replace_inputs(self, repl_dict: Dict[Node, Node]):
while repl_dict: while repl_dict:
node, repl_node = repl_dict.popitem() node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node) assert type(node) == type(repl_node)
assert node in nodes assert node in self.inputs, "({}) is not in the ({})".format(node, self)
index = nodes.index(node) assert (
nodes[index] = repl_node repl_node.top_graph == node.top_graph
), "({}) and ({}) are not in the same graph".format(node, repl_node)
graph = self.top_graph
repl_expr_idx = graph._exprs.index(repl_node.expr)
self_idx = graph._exprs.index(self)
assert (
repl_expr_idx < self_idx
), "({}) must be generated before ({})".format(repl_node, self)
idx = self.inputs.index(node)
self.inputs[idx] = repl_node
user_idx = node.users.index(self)
assert user_idx >= 0
node.users.pop(user_idx)
repl_node.users.append(self) repl_node.users.append(self)
node.users.pop(self)
def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)
def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)
@property @property
def kwargs(self): def kwargs(self):
...@@ -159,7 +166,8 @@ class Expr: ...@@ -159,7 +166,8 @@ class Expr:
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
state.pop("_top_graph", None) if "_top_graph" in state:
state.pop("_top_graph")
return state return state
...@@ -167,12 +175,14 @@ class Expr: ...@@ -167,12 +175,14 @@ class Expr:
class Input(Expr): class Input(Expr):
name = None name = None
def __init__(self, name=None, type=None): def __init__(self, name=None, type=None, orig_name=None):
super().__init__() super().__init__()
self.inputs = [] self.inputs = []
node_cls = type if type else Node node_cls = type if type else Node
if orig_name is None:
orig_name = name
self.outputs = [ self.outputs = [
node_cls(self, name=name), node_cls(self, name=name, orig_name=orig_name),
] ]
self.name = name self.name = name
...@@ -184,7 +194,7 @@ class Input(Expr): ...@@ -184,7 +194,7 @@ class Input(Expr):
active_module_tracer().current_scope()._create_unique_name(oup_node._name) active_module_tracer().current_scope()._create_unique_name(oup_node._name)
) )
oup_node._name = name oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node) active_module_tracer().current_scope()._add_input(oup_node)
return expr.outputs[0] return expr.outputs[0]
def __repr__(self): def __repr__(self):
...@@ -195,7 +205,7 @@ class Input(Expr): ...@@ -195,7 +205,7 @@ class Input(Expr):
class GetAttr(Expr): class GetAttr(Expr):
name = None name = None
def __init__(self, module, name, type=None): def __init__(self, module, name, type=None, orig_name=None):
super().__init__() super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
self.inputs = [ self.inputs = [
...@@ -205,7 +215,7 @@ class GetAttr(Expr): ...@@ -205,7 +215,7 @@ class GetAttr(Expr):
self.name = name self.name = name
node_cls = type if type else Node node_cls = type if type else Node
self.outputs = [ self.outputs = [
node_cls(self, name=name), node_cls(self, name=name, orig_name=orig_name),
] ]
@classmethod @classmethod
...@@ -218,7 +228,7 @@ class GetAttr(Expr): ...@@ -218,7 +228,7 @@ class GetAttr(Expr):
module = module.expr.inputs[0] module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name) oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0] return expr.outputs[0]
def interpret(self, *inputs): def interpret(self, *inputs):
...@@ -255,7 +265,7 @@ class CallMethod(Expr): ...@@ -255,7 +265,7 @@ class CallMethod(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
@property @property
...@@ -315,7 +325,7 @@ class Apply(Expr): ...@@ -315,7 +325,7 @@ class Apply(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
def interpret(self, *inputs): def interpret(self, *inputs):
...@@ -382,7 +392,7 @@ class CallFunction(Expr): ...@@ -382,7 +392,7 @@ class CallFunction(Expr):
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
def interpret(self, *inputs): def interpret(self, *inputs):
...@@ -423,7 +433,7 @@ class Constant(Expr): ...@@ -423,7 +433,7 @@ class Constant(Expr):
self.inputs = [] self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c) node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [ self.outputs = [
node_cls(self, name=name), node_cls(self, name=name, orig_name=name),
] ]
self.outputs[0]._name = name if name else "const_" + str(self._id) self.outputs[0]._name = name if name else "const_" + str(self._id)
...@@ -431,9 +441,23 @@ class Constant(Expr): ...@@ -431,9 +441,23 @@ class Constant(Expr):
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor" name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name) full_name = name
if (
isinstance(expr.value, RawTensor)
and id(expr.value) in active_module_tracer().id2name
):
full_name = active_module_tracer().id2name[id(expr.value)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = ("self." + full_name)[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
name = active_module_tracer().current_scope()._create_unique_name(full_name)
expr.outputs[0]._name = name expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr) expr.outputs[0]._orig_name = full_name
active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0] return expr.outputs[0]
def interpret(self, *inputs): def interpret(self, *inputs):
...@@ -453,7 +477,9 @@ class Constant(Expr): ...@@ -453,7 +477,9 @@ class Constant(Expr):
) )
def __getstate__(self): def __getstate__(self):
state = super().__getstate__() state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
if isinstance(self.value, RawTensor): if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value) state["value"] = Tensor(self.value)
return state return state
...@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [ ...@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [
"__setitem__", "__setitem__",
] ]
BUILTIN_TENSOR_WRAP_METHOD = [
"T",
"to",
"size",
"shape",
"detach",
"device",
"dtype",
"grad",
"item",
"name",
"ndim",
"numpy",
"qparams",
"set_value",
"reset_zero",
"requires_grad",
"_reset",
"_isscalar",
"_setscalar",
"_tuple_shape",
"_unsetscalar",
]
def get_tensor_wrapable_method():
return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD
def active_module_tracer(): def active_module_tracer():
return _active_module_tracer return _active_module_tracer
...@@ -101,9 +129,10 @@ class module_tracer: ...@@ -101,9 +129,10 @@ class module_tracer:
_active_scopes = None _active_scopes = None
def __init__(self, wrap_fn): def __init__(self, wrap_fn, id2name):
self._active_scopes = [] self._active_scopes = []
self.patcher = Patcher(wrap_fn) self.patcher = Patcher(wrap_fn)
self.id2name = id2name
@classmethod @classmethod
def register_as_builtin(cls, mod): def register_as_builtin(cls, mod):
...@@ -127,6 +156,10 @@ class module_tracer: ...@@ -127,6 +156,10 @@ class module_tracer:
return None return None
class NotExist:
pass
class PatchedFn: class PatchedFn:
frame_dict = None frame_dict = None
name = None name = None
...@@ -138,14 +171,17 @@ class PatchedFn: ...@@ -138,14 +171,17 @@ class PatchedFn:
self.origin_fn = ( self.origin_fn = (
self.frame_dict[name] self.frame_dict[name]
if isinstance(frame_dict, collections.abc.Mapping) if isinstance(frame_dict, collections.abc.Mapping)
else getattr(frame_dict, name) else getattr(frame_dict, name, NotExist)
) )
def set_func(self, func): def set_func(self, func):
if isinstance(self.frame_dict, collections.abc.Mapping): if isinstance(self.frame_dict, collections.abc.Mapping):
self.frame_dict[self.name] = func self.frame_dict[self.name] = func
else: else:
setattr(self.frame_dict, self.name, func) if func is not NotExist:
setattr(self.frame_dict, self.name, func)
else:
delattr(self.frame_dict, self.name)
class Patcher: class Patcher:
......
...@@ -30,14 +30,17 @@ class Node: ...@@ -30,14 +30,17 @@ class Node:
_id = None _id = None
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None _name = None
_orig_name = None
_format_spec = "" _format_spec = ""
def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
self.expr = expr self.expr = expr
self.users = [] # List[Expr] self.users = [] # List[Expr]
self._id = Node.__total_id self._id = Node.__total_id
Node.__total_id += 1 Node.__total_id += 1
self._name = name self._name = name
self._orig_name = orig_name
self.actual_node = [] # type: List[Node]
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__ = d self.__dict__ = d
...@@ -48,7 +51,7 @@ class Node: ...@@ -48,7 +51,7 @@ class Node:
return self.__format__(format_spec) return self.__format__(format_spec)
def __format__(self, format_spec: str) -> str: def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None: if not format_spec:
format_spec = Node._format_spec format_spec = Node._format_spec
name = self._name name = self._name
if name is None: if name is None:
...@@ -100,9 +103,8 @@ class ModuleNode(Node): ...@@ -100,9 +103,8 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
_owner = None # type: weakref.ReferenceType _owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None): def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
super().__init__(expr, name) super().__init__(expr, name, orig_name)
self.actual_mnode = []
def __getstate__(self): def __getstate__(self):
return { return {
...@@ -110,6 +112,7 @@ class ModuleNode(Node): ...@@ -110,6 +112,7 @@ class ModuleNode(Node):
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
"_name": self._name, "_name": self._name,
"_orig_name": self._orig_name,
"module_type": self.module_type, "module_type": self.module_type,
} }
...@@ -125,23 +128,67 @@ class TensorNode(Node): ...@@ -125,23 +128,67 @@ class TensorNode(Node):
``TensorNode`` represents the Tensor objects. ``TensorNode`` represents the Tensor objects.
""" """
shape = None # type: Tuple[int] _shape = None # type: Tuple[int]
dtype = None # type: numpy.dtype _dtype = None # type: numpy.dtype
qparams = None _qparams = None
device = None _device = None
_value = None # type: Tensor
def __getstate__(self): def __getstate__(self):
return { return {
"expr": self.expr, "expr": self.expr,
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
"qparams": self.qparams, "_qparams": self._qparams,
"shape": self.shape, "_shape": self._shape,
"dtype": self.dtype, "_dtype": self._dtype,
"device": self.device, "_device": self._device,
"_name": self._name, "_name": self._name,
"_orig_name": self._orig_name,
} }
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
self._shape = shape
@property
def dtype(self):
return self._dtype
@dtype.setter
def dtype(self, dtype):
self._dtype = dtype
@property
def device(self):
return self._device
@device.setter
def device(self, device):
self._device = device
@property
def qparams(self):
return self._qparams
@qparams.setter
def qparams(self, qparams):
self._qparams = qparams
@property
def value(self):
return self._value
@value.setter
def value(self, value):
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None)
self._value = value
class NodeMixin(abc.ABC): class NodeMixin(abc.ABC):
__node = None __node = None
...@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC): ...@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC):
assert isinstance(node, TensorNode) assert isinstance(node, TensorNode)
assert isinstance(value, RawTensor) assert isinstance(value, RawTensor)
if isinstance(value, RawTensor): if isinstance(value, RawTensor):
node.dtype = value.dtype node._dtype = value.dtype
node.shape = ( node._shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape value._tuple_shape if isinstance(value, Tensor) else value.shape
) )
node.device = value.device node._device = value.device
if hasattr(value, "_qparams") and value._qparams is not None: if hasattr(value, "_qparams") and value._qparams is not None:
node.qparams = value.qparams node._qparams = value.qparams
@classmethod @classmethod
def wrap(cls, value, node): def wrap(cls, value, node):
......
...@@ -133,7 +133,7 @@ def _is_leaf(obj): ...@@ -133,7 +133,7 @@ def _is_leaf(obj):
def _leaf_type(node): def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)): if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode, ArgsIndex) return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)): elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin, ArgsIndex) return (Module, ModuleNode, NodeMixin, ArgsIndex)
else: else:
return (type(node), ArgsIndex) return (type(node), ArgsIndex)
......
...@@ -64,9 +64,10 @@ def test_search(): ...@@ -64,9 +64,10 @@ def test_search():
def test_insert(): def test_insert():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
relu_node = graph.get_function_by_type(F.relu).as_unique().outputs relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node) with graph.insert_exprs():
graph.replace_node({relu_node[0]: neg_node}) neg_out = F.neg(relu_out)
graph.replace_node({relu_out: neg_out})
graph.compile() graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册