diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index 5f887780a11316e0ae9aca7eb8e06616fe885748..5fc67809841199af00df78d9524978d5982d5697 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -20,23 +20,6 @@ from .._imperative_rt import core2, ops
 from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 
-""" Some notes:
-    1. Initialize the optimizer:
-        for each trainable parameter:
-            call wrt(param, callback)
-        Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data
-    2. Tracer has one member: node, which is a VariableNode
-    3. VariableNode has a OpNode member: opnode
-    4. OpNode has four members:
-        a. id
-        b. inputs, which is made of VariableNode
-        c. outputs, which are weakref's to VariableNode
-        d. backward: call back function
-        e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist
-        f. backward_allow_noinput: whether backward allow noinput
-
-"""
-
 _grad_count = 0
 _grad_manager_dict = weakref.WeakValueDictionary()
 
@@ -97,6 +80,64 @@ class Grad:
 
 
 class Function(ops.PyOpBase):
+    """
+    Defines a block of operations with customizable differentiation.
+
+    The computation should be defined in the ``forward`` method, with gradient
+    computation defined in the ``backward`` method.
+
+    Each instance of ``Function`` should be used only once during the forward pass.
+
+    Examples:
+
+    .. code-block::
+
+        class Sigmoid(Function):
+            def forward(self, x):
+                y = 1 / (1 + F.exp(-x))
+                self.y = y
+                return y
+
+            def backward(self, dy):
+                y = self.y
+                return dy * y * (1 - y)
+    """
+
+    def forward(self, *args, **kwargs):
+        """
+        Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.
+
+        :param input: input tensors.
+        :return: a tuple of Tensor or a single Tensor.
+
+        .. note::
+
+            This method should return a tuple of Tensor or a single Tensor representing the output
+            of the function.
+        """
+        raise NotImplementedError
+
+    def backward(self, *output_grads):
+        """
+        Computes the gradient of the forward function. It must be overridden by all subclasses.
+
+        :param output_grads: gradients of outputs that are returned by :meth:`forward`.
+
+        .. note::
+
+            When some output tensors are not related to the loss function, the corresponding
+            values in ``output_grads`` will be ``None``.
+
+        .. note::
+
+            This method should return a tuple containing the gradients of all inputs, in the same
+            order as the ``inputs`` argument of :meth:`forward`. A ``Tensor`` could be returned
+            instead if there is only one input. To stop the propagation of some gradients,
+            the corresponding returned values should be set to ``None``.
+
+        """
+        raise NotImplementedError
+
     def _default_rule(self, *args):
         ret = self.forward(*args)
         self.__single_output = isinstance(ret, core2.Tensor)
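
Not part of the patch: a minimal usage sketch of the ``Function`` API documented above. It assumes a MegEngine 1.x environment where ``megengine.autodiff.GradManager``, ``megengine.functional`` (aliased as ``F``) and the ``.grad`` attribute of attached tensors behave as in the public docs; the ``Sigmoid`` subclass is copied from the docstring example added by this diff, and the import path ``megengine.core.autodiff.grad`` is simply the file being patched.

    # Usage sketch (illustrative, not part of this patch): exercising the
    # Function subclass documented above with MegEngine's GradManager.
    # Public aliases used here may differ slightly between versions.
    import numpy as np

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager
    from megengine.core.autodiff.grad import Function


    class Sigmoid(Function):
        def forward(self, x):
            # Save the forward result; backward() reads it to form the gradient.
            y = 1 / (1 + F.exp(-x))
            self.y = y
            return y

        def backward(self, dy):
            # dL/dx = dL/dy * sigmoid(x) * (1 - sigmoid(x))
            y = self.y
            return dy * y * (1 - y)


    x = mge.Tensor(np.array([0.0, 1.0, -2.0], dtype="float32"))
    gm = GradManager().attach([x])
    with gm:
        # A Function instance is used only once during the forward pass.
        y = Sigmoid()(x)
        gm.backward(y.sum())
    print(x.grad)  # gradient of sum(sigmoid(x)) with respect to x

Per the new ``backward`` docstring, returning ``None`` for an input position would stop gradient propagation to that input.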