提交 e6715910 编写于 作者: M Megvii Engine Team

docs(mge/imperative): add docstring for GradManager

GitOrigin-RevId: 4c326206b83f7fb40f43ba3487bc76316440e1f8
上级 ce55fbf6
......@@ -14,6 +14,51 @@ def get_backwarding_grad_manager():
class GradManager:
r"""GradManager manages auto differentiation and all resources required to perform it.
Our auto differentiation framework requires that the user explicitly indicates when
the forward operations start and when all resources should be released. A typical usage of
GradManager is as follows:
.. codeblock::
gm = GradManager()
gm.attach(model.parameters())
with gm:
# forward operations
...
# backward gradients
gm.backward(loss)
You can also use `record()` and `release()` method instead of `with` context:
.. codeblock::
gm = GradManager()
gm.attach(model.parameters())
gm.record()
# forward operations
...
# backward gradients
gm.backward(loss)
gm.release()
Typically, in data parallel, we would like to average the gradients across
processes. Users will finally get the averaged gradients if an "AllReduce"
callback is registered as follows:
.. codeblock::
import megengine.distributed as dist
gm = GradManager()
gm.attach(model.parameters(), callback=dist.make_allreduce_cb("MEAN"))
"""
def __init__(self):
self._call_back_dict = defaultdict(list)
self._param_dict = dict()
......@@ -23,6 +68,18 @@ class GradManager:
self._gradients = dict()
def attach(self, params, callbacks=None):
r"""Registers parameters that gradients should be calculated with respect to.
Callback Functions should have a signature like this:
.. codeblock::
def cb(param: Tensor, grad: Tensor) -> Tensor:
# do something
return grad
:param params: registered parameters
:param callbacks: list of callback functions
"""
if callbacks is None:
callbacks = []
if isinstance(callbacks, Callable):
......@@ -38,6 +95,11 @@ class GradManager:
return self
def backward(self, ys, dys=None):
r"""Performs back-propagation and computes gradients.
:param ys: outputs of forward operators, e.g., the loss tensor
:param dys: derivatives of ys
"""
global backwarding_grad_manager
cache = backwarding_grad_manager
backwarding_grad_manager = self
......@@ -71,6 +133,8 @@ class GradManager:
backwarding_grad_manager = cache
def record(self):
r"""Starts recording forward operations.
"""
if self._recording:
raise RuntimeError("already recording")
grad = Grad()
......@@ -90,6 +154,8 @@ class GradManager:
grad.__enter__()
def release(self):
r"""Stops recording and releases resources for gradients calculation.
"""
if not self._recording:
raise RuntimeError("not recording")
self._stop_record()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册