未验证 提交 d41f5de5 编写于 作者: X Xiaoyu Xu 提交者: GitHub

add doc of graph (#6093)

* add doc

* add doc

* rm useless code

* add doc

* refine docstring of nn.Graph

* refine doc of graph

* add log doc

* auto format by CI

* add doctest for graph.py
Co-authored-by: NYao Chi <later@usopp.net>
Co-authored-by: Noneflow-ci-bot <ci-bot@oneflow.org>
上级 4e03c3ee
oneflow.nn.Graph
================================================
Graph class for building neural networks
---------------------------------------------------
======================================================
Base class for running neural networks in Graph Mode.
------------------------------------------------------
.. currentmodule:: oneflow.nn
.. autoclass:: oneflow.nn.Graph
:members:
:members: __init__,
build,
__call__,
add_optimizer,
set_grad_scaler,
name,
training,
debug,
__repr__,
:member-order: bysource
......@@ -389,22 +389,33 @@ class Block(object):
class BlockConfig(object):
r"""Configurations on Block in nn.Graph.
"""
def __init__(self):
self._stage_id = None
self._activation_checkpointing = None
@property
def stage_id(self):
r"""Get stage id of Block in pipeline parallelism.
"""
return self._stage_id
@stage_id.setter
def stage_id(self, value: int = None):
r"""Set stage id of Block in pipeline parallelism.
"""
self._stage_id = value
@property
def activation_checkpointing(self):
r"""Get whether do activation checkpointing in this Block.
"""
return self._activation_checkpointing
@activation_checkpointing.setter
def activation_checkpointing(self, value: bool = False):
r"""Set whether do activation checkpointing in this Block.
"""
self._activation_checkpointing = value
......@@ -20,6 +20,9 @@ import oneflow._oneflow_internal.oneflow.core.job.job_conf as job_conf_cfg
class GraphConfig(object):
r"""For configuration of nn.Graph.
"""
def __init__(self):
super().__init__()
self.proto = job_conf_cfg.JobConfigProto()
......@@ -73,6 +76,11 @@ class GraphConfig(object):
self.proto.set_enable_fuse_cast_scale(mode)
def set_gradient_accumulation_steps(self, value):
"""Set num of steps to accumulate gradient.
Args:
value (int): num of steps
"""
self.proto.set_num_gradient_accumulation_steps(value)
def _generate_optimizer_and_variable_configs(
......
......@@ -36,9 +36,62 @@ from oneflow.nn.optimizer.lr_scheduler import LrScheduler
class Graph(object):
r"""Base class for training or evaluating a neural network in graph mode.
To use graph mode for model training or evaluation in OneFlow, you should:
1. Define your customized graph as a subclass of ``nn.Graph``.
2. Add ``super().__init__()`` in your subclass's ``__init__()``.
3. Add modules to your graph as regular attributes.
4. Define computation logical in ``build()`` method.
5. Instantiate your graph then call it.
.. code-block:: python
>>> import oneflow as flow
>>> class LinearGraph(flow.nn.Graph):
... def __init__(self):
... super().__init__()
... # Add a module to the graph.
... self.linear = flow.nn.Linear(3, 8, False)
... def build(self, x):
... # Use the module to build the computation logic of the graph.
... return self.linear(x)
# Instantiate the graph
>>> linear_graph = LinearGraph()
>>> x = flow.randn(4, 3)
# First call on graph will run graph's build() method to
# trace a computatioin graph. Then the computation graph will be
# optimized and executed for the first time.
>>> linear_graph(x).shape
flow.Size([4, 8])
# Later call on graph will execute the computation graph directly.
>>> linear_graph(x).shape
flow.Size([4, 8])
Note that Graph cannot be nested at the moment.
"""
_child_init_cnt = dict()
def __init__(self):
"""
Initializes internal Graph states. It MUST be called in ``__init__`` method of subclass.
.. code-block:: python
>>> import oneflow as flow
>>> class SubclassGraph(flow.nn.Graph):
... def __init__(self):
... super().__init__() # MUST be called
... # Then define the graph attributes
... def build(self):
... pass
"""
self._generate_name()
self.config = GraphConfig()
self._blocks = OrderedDict()
......@@ -56,44 +109,85 @@ class Graph(object):
session.TryInit()
session.AddCGraph(self._c_nn_graph)
@property
def name(self):
return self._name
@property
def training(self):
return self.config.training
@property
def _config_proto(self):
return self.config.proto
@property
def _optimization_conf_proto(self):
session = session_ctx.GetDefaultSession()
assert type(session) is MultiClientSession
return session.resource
@property
def _graph_proto(self):
return self._job_proto
def debug(self, mode: bool = True) -> None:
if get_rank() != 0:
return
else:
print("Note that nn.Graph.debug() only print debug info on rank 0.")
self._debug = mode
for name, block in self._blocks.items():
assert block.type == BlockType.MODULE
block.debug(mode)
def build(self, *args):
r"""The ``build()`` method must be overridden to define neural network
computaion logic.
The ``build()`` method of nn.Graph is very similar to the ``forward()``
method of nn.Module. It is used to describe the computatioin logical of
a neural network.
When a graph object being called for the first time, the ``build()``
method will be called implicitly to build the computatioin graph.
.. code-block:: python
>>> import oneflow as flow
>>> class MyGraph(flow.nn.Graph):
... def __init__(self):
... super().__init__()
... self.linear = flow.nn.Linear(3, 8, False)
... def build(self, x):
... return self.linear(x)
>>> linear_graph = MyGraph()
>>> x = flow.randn(4, 3)
>>> y = linear_graph(x) # The build() method is called implicitly
"""
raise NotImplementedError()
def add_optimizer(
self, optim: Optimizer, *, lr_sch: LrScheduler = None,
):
r"""Add an optimizer, an learning rate scheduler to the graph.
To do training with nn.Graph, you should do 2 more things:
1. Add at least one optimier(learning rate schedulers are optional) with ``add_optimizer()`` method.
2. Call loss tensor's ``backward()`` method in ``build()`` method.
Note that the computaion graph will automatically execute these methods:
* optimizer's ``clip_grad()`` if a optimizer is set to do grad cliping.
* optimizer's ``step()``.
* optimizer's ``zero_grad()``.
* learn rate scheduler's ``step()``.
Also note that only scalar tensor are allowed to call ``backward()``
in ``nn.Graph.build()`` for the moment. So you may call ``Tensor.sum()``
or ``Tensor.mean()`` to make the loss tensor a scalar tensor.
.. code-block:: python
>>> import oneflow as flow
>>> loss_fn = flow.nn.MSELoss(reduction="sum")
>>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
>>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)
>>> class LinearTrainGraph(flow.nn.Graph):
... def __init__(self):
... super().__init__()
... self.model = model
... self.loss_fn = loss_fn
... # Add an optimizer
... self.add_optimizer(optimizer)
... def build(self, x, y):
... y_pred = self.model(x)
... loss = self.loss_fn(y_pred, y)
... # Call loss tensor's backward(), loss tensor must be a scalar tensor
... loss.backward()
... return loss
>>> linear_graph = LinearTrainGraph()
>>> x = flow.randn(10, 3)
>>> y = flow.randn(10)
>>> for t in range(3):
... loss = linear_graph(x, y)
Args:
optim (oneflow.optim.Optimizer): The optimizer.
lr_sch : The learning rate scheduler, see oneflow.optim.lr_scheduler.
"""
opt_dict = dict()
assert optim is not None, "optimizer cannot be None"
assert isinstance(
......@@ -109,9 +203,127 @@ class Graph(object):
self._opts.append(opt_dict)
def set_grad_scaler(self, grad_scaler: GradScaler = None):
r"""Set the GradScaler for gradient and loss scaling.
"""
assert isinstance(grad_scaler, GradScaler)
self._grad_scaler = grad_scaler
def __call__(self, *args):
r"""Call nn.Graph subclass instance to run your customized graph.
Call your customized graph after the instantiation:
.. code-block:: python
g = CustomGraph()
out_tensors = g(input_tensors)
Note that the first call takes longer than later calls, because nn.Graph
will do the computaion graph generation and optimization at the first call.
``nn.Graph.__call__(*args)`` only accept positional arguements of
Tensor/List[Tensor]/TensorTuple/None at the moment.
Donot override this function.
"""
if not self._is_compiled:
self._compile(*args)
return self._run(*args)
@property
def name(self):
r"""Name auto-generated for this graph.
"""
return self._name
@property
def training(self):
r"""In traninig mode if the graph has an optimizer.
"""
return self.config.training
def debug(self, mode: bool = True) -> None:
r"""Open or close debug mode of the graph.
If in debug mode, logs of computation graph building will be
printed on rank 0.
.. code-block:: python
g = CustomGraph()
# Open debug mode
g.debug()
out_tensors = g(input_tensors) # Will print log for debug at the first call
"""
if get_rank() != 0:
return
self._debug = mode
for name, block in self._blocks.items():
assert block.type == BlockType.MODULE
block.debug(mode)
def __repr__(self):
r"""For printing the graph structure.
The graph structure can be printed after graph instantiation.
After the first call of graph, inputs and outputs will be added to
the graph structure.
.. code-block:: python
g = CustomGraph()
print(g)
out_tensors = g(input_tensors)
print(g) # Inputs and Outputs infos are added
"""
child_lines = []
if len(self._args_repr) > 0:
for in_str in self._args_repr:
input_str = add_indent(in_str, 2)
child_lines.append(input_str)
if len(self._blocks) > 0:
for n, m in self._blocks.items():
mod_str = repr(m)
mod_str = add_indent(mod_str, 2)
child_lines.append(mod_str)
if len(self._outs_repr) > 0:
for out_str in self._outs_repr:
output_str = add_indent(out_str, 2)
child_lines.append(output_str)
main_str = self._shallow_repr() + ": ("
if len(child_lines) > 0:
main_str += "\n " + "\n ".join(child_lines) + "\n"
main_str += ")"
return main_str
def _shallow_repr(self):
shallow_repr = "(GRAPH:" + self._name + ":" + self.__class__.__name__ + ")"
return shallow_repr
@property
def _config_proto(self):
return self.config.proto
@property
def _optimization_conf_proto(self):
session = session_ctx.GetDefaultSession()
assert type(session) is MultiClientSession
return session.resource
@property
def _graph_proto(self):
return self._job_proto
def _generate_name(self):
child_name = self.__class__.__name__
if Graph._child_init_cnt.get(child_name) is None:
......@@ -254,12 +466,6 @@ class Graph(object):
raise
return self._eager_outputs
def __call__(self, *args):
if not self._is_compiled:
self._compile(*args)
return self._run(*args)
def _build_io(self, io_type, build_func, *args):
assert io_type in ("input", "output")
io_type_upper = io_type.upper()
......@@ -416,14 +622,35 @@ class Graph(object):
return state_op_names, state_tensor_tuple
def _add_block(self, name: str, module: Module = None) -> None:
r"""Adds a module to the current graph as a block.
The block can be accessed as an attribute using the given name.
r"""Adds module to the graph as a block so that the module will
be called in nn.Graph.build.
Args:
name (string): name of the child block. The child block can be
accessed from this graph using the given name
name (str): name of the child block. The child block can be accessed from this graph using the given name.
module (Module): child module to be added to the graph.
Just assign nn.Module in nn.Graph, _add_block will be called to add the
module as a Block:
.. code-block:: python
>>> import oneflow as flow
>>> class LinearGraph(flow.nn.Graph):
... def __init__(self):
... super().__init__()
... # add a nn.Module as a block to graph.
... self.linear = flow.nn.Linear(3, 8, False)
... def build(self, x):
... # call the nn.Module block.
... return self.linear(x)
The block can be accessed as an attribute using the given name.
>>> g = LinearGraph()
>>> g.linear
(MODULE:linear:Linear(in_features=3, out_features=8, bias=False)): (
(PARAMETER:linear.weight:tensor(..., size=(8, 3), dtype=oneflow.float32, requires_grad=True)): ()
)
"""
if "_name" not in self.__dict__:
raise AttributeError(
......@@ -466,30 +693,8 @@ class Graph(object):
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __repr__(self):
child_lines = []
if len(self._args_repr) > 0:
for in_str in self._args_repr:
input_str = add_indent(in_str, 2)
child_lines.append(input_str)
if len(self._blocks) > 0:
for n, m in self._blocks.items():
mod_str = repr(m)
mod_str = add_indent(mod_str, 2)
child_lines.append(mod_str)
if len(self._outs_repr) > 0:
for out_str in self._outs_repr:
output_str = add_indent(out_str, 2)
child_lines.append(output_str)
main_str = self._shallow_repr() + ": ("
if len(child_lines) > 0:
main_str += "\n " + "\n ".join(child_lines) + "\n"
main_str += ")"
return main_str
if __name__ == "__main__":
import doctest
def _shallow_repr(self):
shallow_repr = "(GRAPH:" + self._name + ":" + self.__class__.__name__ + ")"
return shallow_repr
doctest.testmod(raise_on_error=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册