feat(traced_module): update graph transform and add _module_name

GitOrigin-RevId: ef63ae0fd0dcdd69c3566e19f8a34d85422a1e1e
......@@ -14,7 +14,6 @@ from .traced_module import (
register_as_builtin,
trace_module,
wrap,
wrap_tensors,
)
_register_all_builtin_module()
......
......@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
return s
def lstrip(s: str, __chars: str):
__chars = re.escape(__chars)
s = re.sub(r"^(?:%s)+(?P<right>.*)$" % __chars, "\g<right>", s)
return s
def strip(s: str, __chars: str):
s = lstrip(rstrip(s, __chars), __chars)
return s
class Expr:
"""
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
......@@ -89,27 +78,40 @@ class Expr:
outputs = (outputs,)
name = None
orig_name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
assert name is not None
orig_name = self.inputs[0]._orig_name
assert isinstance(name, str), "The name of ({}) must be a str".format(
self.inputs[0]
)
assert isinstance(
orig_name, str
), "The orig_name of ({}) must be a str".format(self.inputs[0])
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
orig_name += "_out"
else:
strip_method = strip(self.method, "_")
strip_method = self.method.strip("_")
name = "%s_out" % strip_method
orig_name = name
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"
for i in outputs:
assert isinstance(i, RawTensor)
assert isinstance(i, RawTensor), "The output must be a Tensor"
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(expr=self, name=o_name)
NodeMixin.get_wrapped_type(i)(
expr=self,
name=o_name,
orig_name=orig_name if orig_name else o_name,
)
)
for i, node in zip(outputs, self.outputs,):
......@@ -125,21 +127,26 @@ class Expr:
else:
return inputs, {}
def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]):
def replace_inputs(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
assert node in nodes
index = nodes.index(node)
nodes[index] = repl_node
assert node in self.inputs, "({}) is not in the ({})".format(node, self)
assert (
repl_node.top_graph == node.top_graph
), "({}) and ({}) are not in the same graph".format(node, repl_node)
graph = self.top_graph
repl_expr_idx = graph._exprs.index(repl_node.expr)
self_idx = graph._exprs.index(self)
assert (
repl_expr_idx < self_idx
), "({}) must be generated before ({})".format(repl_node, self)
idx = self.inputs.index(node)
self.inputs[idx] = repl_node
user_idx = node.users.index(self)
assert user_idx >= 0
node.users.pop(user_idx)
repl_node.users.append(self)
node.users.pop(self)
def replace_inputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.inputs)
def replace_outputs(self, repl_dict: Dict[Node, Node]):
self._replace_nodes(repl_dict, self.outputs)
@property
def kwargs(self):
......@@ -159,7 +166,8 @@ class Expr:
def __getstate__(self):
state = self.__dict__.copy()
state.pop("_top_graph", None)
if "_top_graph" in state:
state.pop("_top_graph")
return state
......@@ -167,12 +175,14 @@ class Expr:
class Input(Expr):
name = None
def __init__(self, name=None, type=None):
def __init__(self, name=None, type=None, orig_name=None):
super().__init__()
self.inputs = []
node_cls = type if type else Node
if orig_name is None:
orig_name = name
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=orig_name),
]
self.name = name
......@@ -184,7 +194,7 @@ class Input(Expr):
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope().add_input(oup_node)
active_module_tracer().current_scope()._add_input(oup_node)
return expr.outputs[0]
def __repr__(self):
......@@ -195,7 +205,7 @@ class Input(Expr):
class GetAttr(Expr):
name = None
def __init__(self, module, name, type=None):
def __init__(self, module, name, type=None, orig_name=None):
super().__init__()
assert isinstance(module, ModuleNode)
self.inputs = [
......@@ -205,7 +215,7 @@ class GetAttr(Expr):
self.name = name
node_cls = type if type else Node
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=orig_name),
]
@classmethod
......@@ -218,7 +228,7 @@ class GetAttr(Expr):
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0]
def interpret(self, *inputs):
......@@ -255,7 +265,7 @@ class CallMethod(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr
@property
......@@ -315,7 +325,7 @@ class Apply(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr
def interpret(self, *inputs):
......@@ -382,7 +392,7 @@ class CallFunction(Expr):
@classmethod
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
active_module_tracer().current_scope().insert(expr)
active_module_tracer().current_scope()._insert(expr)
return expr
def interpret(self, *inputs):
......@@ -423,7 +433,7 @@ class Constant(Expr):
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [
node_cls(self, name=name),
node_cls(self, name=name, orig_name=name),
]
self.outputs[0]._name = name if name else "const_" + str(self._id)
......@@ -431,9 +441,23 @@ class Constant(Expr):
def make(cls, *args, **kwargs):
expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
name = active_module_tracer().current_scope()._create_unique_name(name)
full_name = name
if (
isinstance(expr.value, RawTensor)
and id(expr.value) in active_module_tracer().id2name
):
full_name = active_module_tracer().id2name[id(expr.value)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = ("self." + full_name)[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
name = active_module_tracer().current_scope()._create_unique_name(full_name)
expr.outputs[0]._name = name
active_module_tracer().current_scope().insert(expr)
expr.outputs[0]._orig_name = full_name
active_module_tracer().current_scope()._insert(expr)
return expr.outputs[0]
def interpret(self, *inputs):
......@@ -453,7 +477,9 @@ class Constant(Expr):
)
def __getstate__(self):
state = super().__getstate__()
state = self.__dict__.copy()
if "_top_graph" in state:
state.pop("_top_graph")
if isinstance(self.value, RawTensor):
state["value"] = Tensor(self.value)
return state
......@@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [
"__setitem__",
]
BUILTIN_TENSOR_WRAP_METHOD = [
"T",
"to",
"size",
"shape",
"detach",
"device",
"dtype",
"grad",
"item",
"name",
"ndim",
"numpy",
"qparams",
"set_value",
"reset_zero",
"requires_grad",
"_reset",
"_isscalar",
"_setscalar",
"_tuple_shape",
"_unsetscalar",
]
def get_tensor_wrapable_method():
return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD
def active_module_tracer():
return _active_module_tracer
......@@ -101,9 +129,10 @@ class module_tracer:
_active_scopes = None
def __init__(self, wrap_fn):
def __init__(self, wrap_fn, id2name):
self._active_scopes = []
self.patcher = Patcher(wrap_fn)
self.id2name = id2name
@classmethod
def register_as_builtin(cls, mod):
......@@ -127,6 +156,10 @@ class module_tracer:
return None
class NotExist:
pass
class PatchedFn:
frame_dict = None
name = None
......@@ -138,14 +171,17 @@ class PatchedFn:
self.origin_fn = (
self.frame_dict[name]
if isinstance(frame_dict, collections.abc.Mapping)
else getattr(frame_dict, name)
else getattr(frame_dict, name, NotExist)
)
def set_func(self, func):
if isinstance(self.frame_dict, collections.abc.Mapping):
self.frame_dict[self.name] = func
else:
setattr(self.frame_dict, self.name, func)
if func is not NotExist:
setattr(self.frame_dict, self.name, func)
else:
delattr(self.frame_dict, self.name)
class Patcher:
......
......@@ -30,14 +30,17 @@ class Node:
_id = None
_top_graph = None # type: weakref.ReferenceType
_name = None
_orig_name = None
_format_spec = ""
def __init__(self, expr: "Expr", name: str = None):
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
Node.__total_id += 1
self._name = name
self._orig_name = orig_name
self.actual_node = [] # type: List[Node]
def __setstate__(self, d):
self.__dict__ = d
......@@ -48,7 +51,7 @@ class Node:
return self.__format__(format_spec)
def __format__(self, format_spec: str) -> str:
if format_spec == "" or format_spec is None:
if not format_spec:
format_spec = Node._format_spec
name = self._name
if name is None:
......@@ -100,9 +103,8 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
_owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.actual_mnode = []
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
super().__init__(expr, name, orig_name)
def __getstate__(self):
return {
......@@ -110,6 +112,7 @@ class ModuleNode(Node):
"users": self.users,
"_id": self._id,
"_name": self._name,
"_orig_name": self._orig_name,
"module_type": self.module_type,
}
......@@ -125,23 +128,67 @@ class TensorNode(Node):
``TensorNode`` represents the Tensor objects.
"""
shape = None # type: Tuple[int]
dtype = None # type: numpy.dtype
qparams = None
device = None
_shape = None # type: Tuple[int]
_dtype = None # type: numpy.dtype
_qparams = None
_device = None
_value = None # type: Tensor
def __getstate__(self):
return {
"expr": self.expr,
"users": self.users,
"_id": self._id,
"qparams": self.qparams,
"shape": self.shape,
"dtype": self.dtype,
"device": self.device,
"_qparams": self._qparams,
"_shape": self._shape,
"_dtype": self._dtype,
"_device": self._device,
"_name": self._name,
"_orig_name": self._orig_name,
}
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
self._shape = shape
@property
def dtype(self):
return self._dtype
@dtype.setter
def dtype(self, dtype):
self._dtype = dtype
@property
def device(self):
return self._device
@device.setter
def device(self, device):
self._device = device
@property
def qparams(self):
return self._qparams
@qparams.setter
def qparams(self, qparams):
self._qparams = qparams
@property
def value(self):
return self._value
@value.setter
def value(self, value):
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None)
self._value = value
class NodeMixin(abc.ABC):
__node = None
......@@ -156,13 +203,13 @@ class NodeMixin(abc.ABC):
assert isinstance(node, TensorNode)
assert isinstance(value, RawTensor)
if isinstance(value, RawTensor):
node.dtype = value.dtype
node.shape = (
node._dtype = value.dtype
node._shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
node.device = value.device
node._device = value.device
if hasattr(value, "_qparams") and value._qparams is not None:
node.qparams = value.qparams
node._qparams = value.qparams
@classmethod
def wrap(cls, value, node):
......
......@@ -133,7 +133,7 @@ def _is_leaf(obj):
def _leaf_type(node):
if isinstance(node, (RawTensor, TensorNode)):
return (Tensor, TensorNode, ArgsIndex)
elif isinstance(node, (NodeMixin, Module)):
elif isinstance(node, (NodeMixin, Module, ModuleNode)):
return (Module, ModuleNode, NodeMixin, ArgsIndex)
else:
return (type(node), ArgsIndex)
......
......@@ -64,9 +64,10 @@ def test_search():
def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_node = graph.get_function_by_type(F.relu).as_unique().outputs
neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node)
graph.replace_node({relu_node[0]: neg_node})
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
with graph.insert_exprs():
neg_out = F.neg(relu_out)
graph.replace_node({relu_out: neg_out})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部