@@ -1425,6 +1769,8 @@ class ExprFilterCallMethod(ExprFilter):
classExprFilterExprId(ExprFilter):
"""See :meth:`~.InternalGraph.get_expr_by_id`"""
def__init__(self,expr_iter,expr_id:List[int]):
super().__init__(expr_iter)
ifnotisinstance(expr_id,Sequence):
...
...
@@ -1438,8 +1784,16 @@ class ExprFilterExprId(ExprFilter):
classTracedModule(Module):
r"""`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it."""
r"""``TracedModule`` is the Module created by tracing normal module.
It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule``
will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs``
and interpret it.
.. note::
``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module`
for more details.
"""
# m_node = None # type: ModuleNode
argdef_graph_map=None
argdef_outdef_map=None
...
...
@@ -1475,19 +1829,97 @@ class TracedModule(Module):
returnoutputs
defset_watch_points(self,nodes):
r"""Initialize the :attr:`~.TracedModule.watch_points`.
You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime.