Commit 92ded575 authored by mindspore-ci-bot, committed by Gitee

!5106 fix grad api docs

Merge pull request !5106 from riemann_penn/fix_api_doc_of_gradoperation
@@ -92,18 +92,196 @@ def core(fn=None, **flags):
class GradOperation(GradOperation_):
"""
A higher-order function which is used to generate the gradient function for the input function.
The gradient function generated by the `GradOperation` higher-order function can be customized by its construction arguments.
Given an input function `net = Net()` that takes `x` and `y` as inputs and has a parameter `z`,
see `Net` in Examples.
To generate a gradient function that returns gradients with respect to the first input
(see `GradNetWrtX` in Examples).
1. Construct a `GradOperation` higher-order function with default arguments:
`grad_op = GradOperation()`.
2. Call it with the input function as its argument to get the gradient function: `gradient_function = grad_op(net)`.
3. Call the gradient function with the input function's inputs to get the gradients with respect to the first input:
`gradient_function(x, y)`.
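Chaining these steps gives a minimal sketch of the workflow (`net`, `x` and `y` are assumed to be
defined as in Examples): `dx = GradOperation()(net)(x, y)`, where `dx` is the gradient with respect
to the first input `x`.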
To generate a gradient function that returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True`, which
indicates getting gradients with respect to all inputs, namely `x` and `y` in the example function `Net()`:
`grad_op = GradOperation(get_all=True)`.
2. Call it with the input function as its argument to get the gradient function: `gradient_function = grad_op(net)`.
3. Call the gradient function with the input function's inputs to get the gradients with respect to all inputs:
`gradient_function(x, y)`.
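As a compact sketch (same assumptions as above): `dx, dy = GradOperation(get_all=True)(net)(x, y)`;
the gradient function returns one gradient per input, packed in a tuple.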
To generate a gradient function that returns gradients with respect to given parameters
(see `GradNetWithWrtParams` in Examples).
1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
`grad_op = GradOperation(get_by_list=True)`.
2. Construct a `ParameterTuple` that will be passed, along with the input function, when calling the
`GradOperation` higher-order function; it will be used as a parameter filter that determines
which gradients to return: `params = ParameterTuple(net.trainable_params())`.
3. Call it with the input function and `params` as its arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.
4. Call the gradient function with the input function's inputs to get the gradients with
respect to the given parameters: `gradient_function(x, y)`.
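As a compact sketch: `(dz,) = GradOperation(get_by_list=True)(net, params)(x, y)`; the result is a
tuple with one gradient per parameter in `params`.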
To generate a gradient function that returns gradients with respect to all inputs and given parameters
in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
`grad_op = GradOperation(get_all=True, get_by_list=True)`.
2. Construct a `ParameterTuple` that will be passed, along with the input function, when calling the
`GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.
3. Call it with the input function and `params` as its arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.
4. Call the gradient function with the input function's inputs
to get the gradients with respect to all inputs and the given parameters: `gradient_function(x, y)`.
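As a compact sketch: `(dx, dy), (dz,) = GradOperation(get_all=True, get_by_list=True)(net, params)(x, y)`,
matching the ((dx, dy), (dz)) format described above.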
We can configure the sensitivity (gradient with respect to the output) by setting `sens_param=True` and
passing in an extra sensitivity input to the gradient function; the sensitivity input must have
the same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
`grad_op = GradOperation(get_all=True, sens_param=True)`.
2. Define `grad_wrt_output` as the sensitivity input, which works as the gradient with respect to the output:
`grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
3. Call it with the input function as its argument to get the gradient function:
`gradient_function = grad_op(net)`.
4. Call the gradient function with the input function's inputs and the sensitivity input to
get the gradients with respect to all inputs:
`gradient_function(x, y, grad_wrt_output)`.
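As a compact sketch: `dx, dy = GradOperation(get_all=True, sens_param=True)(net)(x, y, grad_wrt_output)`;
the provided sensitivity enters the returned gradients through the chain rule, so with `sens_param=False`
the gradient function behaves as if a `ones_like(outputs)` sensitivity had been supplied.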
Args:
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
If get_all and get_by_list are both False, get the gradient with respect to the first input.
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False.
sens_param (bool): Whether to append the sensitivity (gradient with respect to the output) as an input. If sens_param is False,
a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
Returns:
The higher-order function which takes a function as its argument and returns a gradient function for it.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.matmul = P.MatMul()
>>> self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
>>> def construct(self, x, y):
>>> x = x * self.z
>>> out = self.matmul(x, y)
>>> return out
>>>
>>> class GradNetWrtX(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtX, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation()
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> GradNetWrtX(Net())(x, y)
Tensor(shape=[2, 3], dtype=Float32,
[[1.4100001 1.5999999 6.6 ]
[1.4100001 1.5999999 6.6 ]])
>>>
>>> class GradNetWrtXY(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtXY, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation(get_all=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXY(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[4.5099998 2.7 3.6000001]
[4.5099998 2.7 3.6000001]]), Tensor(shape=[3, 3], dtype=Float32,
[[2.6 2.6 2.6 ]
[1.9 1.9 1.9 ]
[1.3000001 1.3000001 1.3000001]]))
>>>
>>> class GradNetWrtXYWithSensParam(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtXYWithSensParam, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation(get_all=True, sens_param=True)
>>> self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y, self.grad_wrt_output)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXYWithSensParam(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[2.211 0.51 1.4900001]
[5.588 2.68 4.07 ]]), Tensor(shape=[3, 3], dtype=Float32,
[[1.52 2.82 2.14 ]
[1.1 2.05 1.55 ]
[0.90000004 1.55 1.25 ]]))
>>>
>>> class GradNetWithWrtParams(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWithWrtParams, self).__init__()
>>> self.net = net
>>> self.params = ParameterTuple(net.trainable_params())
>>> self.grad_op = GradOperation(get_by_list=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net, self.params)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWithWrtParams(Net())(x, y)
(Tensor(shape=[1], dtype=Float32, [21.536]),)
>>>
>>> class GradNetWrtInputsAndParams(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtInputsAndParams, self).__init__()
>>> self.net = net
>>> self.params = ParameterTuple(net.trainable_params())
>>> self.grad_op = GradOperation(get_all=True, get_by_list=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net, self.params)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
>>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtInputsAndParams(Net())(x, y)
((Tensor(shape=[2, 3], dtype=Float32,
[[3.52 3.9 2.6 ]
[3.52 3.9 2.6 ]]), Tensor(shape=[3, 3], dtype=Float32,
[[0.6 0.6 0.6 ]
[1.9 1.9 1.9 ]
[1.3000001 1.3000001 1.3000001]])), (Tensor(shape=[1], dtype=Float32, [12.902]),))
"""
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
......