magicwindyyd / mindspore
Commit 92ded575
Forked from MindSpore / mindspore
Authored 4 years ago by mindspore-ci-bot; committed by Gitee 4 years ago.
!5106 fix grad api docs
Merge pull request !5106 from riemann_penn/fix_api_doc_of_gradoperation
Parents: 622f5f69, 98f79945
Showing 1 changed file with 186 additions and 8 deletions.

mindspore/ops/composite/base.py (+186, -8)

The diff removes the old description ("An metafuncgraph object which is used to get the gradient of output of a network(function). The GradOperation will convert the network(function) into a back propagation graph.") and the abbreviated "w.r.t" wording in the Args section, replacing them with the expanded documentation shown below.
...
...
@@ -92,18 +92,196 @@ def core(fn=None, **flags):

class GradOperation(GradOperation_):
"""
A higher-order function which is used to generate the gradient function for the input function.

The gradient function generated by the `GradOperation` higher-order function can be customized by its construction arguments.

Given an input function `net = Net()` that takes `x` and `y` as inputs and has a parameter `z`,
see `Net` in Examples.
To generate a gradient function that returns gradients with respect to the first input
(see `GradNetWrtX` in Examples).
1. Construct a `GradOperation` higher-order function with default arguments:
`grad_op = GradOperation()`.
2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
3. Call the gradient function with input function's inputs to get the gradients with respect to the first input:
`gradient_function(x, y)`.
To generate a gradient function that returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True`, which
indicates getting gradients with respect to all inputs; in the example function `Net()`, they are `x` and `y`:
`grad_op = GradOperation(get_all=True)`.
2. Call it with input function as argument to get the gradient function: `gradient_function = grad_op(net)`.
3. Call the gradient function with input function's inputs to get the gradients with respect to all inputs:
`gradient_function(x, y)`.
To generate a gradient function that returns gradients with respect to given parameters
(see `GradNetWithWrtParams` in Examples).
1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
`grad_op = GradOperation(get_by_list=True)`.
2. Construct a `ParameterTuple` that will be passed, along with the input function, when generating the
gradient function; it will be used as a parameter filter that determines which gradients to return:
`params = ParameterTuple(net.trainable_params())`.
3. Call it with input function and `params` as arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.
4. Call the gradient function with input function's inputs to get the gradients with
respect to given parameters: `gradient_function(x, y)`.
To generate a gradient function that returns gradients with respect to all inputs and given parameters
in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
`grad_op = GradOperation(get_all=True, get_by_list=True)`.
2. Construct a `ParameterTuple` that will be passed, along with the input function, when generating the
gradient function: `params = ParameterTuple(net.trainable_params())`.
3. Call it with input function and `params` as arguments to get the gradient function:
`gradient_function = grad_op(net, params)`.
4. Call the gradient function with input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.
We can configure the sensitivity (gradient with respect to output) by setting `sens_param=True` and
passing in an extra sensitivity input to the gradient function. The sensitivity input should have
the same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples).
1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
`grad_op = GradOperation(get_all=True, sens_param=True)`.
2. Define `grad_wrt_output` as the sens_param, which works as the gradient with respect to the output:
`grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.
3. Call it with input function as argument to get the gradient function:
`gradient_function = grad_op(net)`.
4. Call the gradient function with input function's inputs and sens_param to
get the gradients with respect to all inputs:
`gradient_function(x, y, grad_wrt_output)`.
Args:
    get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
    get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
        If get_all and get_by_list are both False, get the gradient with respect to the first input.
        If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter
        variables at the same time, in the form of ((gradients with respect to inputs),
        (gradients with respect to parameters)). Default: False.
    sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. If sens_param
        is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
Returns:
    The higher-order function which takes a function as argument and returns a gradient function for it.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.matmul = P.MatMul()
>>> self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
>>> def construct(self, x, y):
>>> x = x * self.z
>>> out = self.matmul(x, y)
>>> return out
>>>
>>> class GradNetWrtX(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtX, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation()
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
>>> GradNetWrtX(Net())(x, y)
Tensor(shape=[2, 3], dtype=Float32,
[[1.4100001 1.5999999 6.6 ]
[1.4100001 1.5999999 6.6 ]])
>>>
>>> class GradNetWrtXY(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtXY, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation(get_all=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXY(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[4.5099998 2.7 3.6000001]
[4.5099998 2.7 3.6000001]]), Tensor(shape=[3, 3], dtype=Float32,
[[2.6 2.6 2.6 ]
[1.9 1.9 1.9 ]
[1.3000001 1.3000001 1.3000001]]))
>>>
>>> class GradNetWrtXYWithSensParam(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtXYWithSensParam, self).__init__()
>>> self.net = net
>>> self.grad_op = GradOperation(get_all=True, sens_param=True)
>>> self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net)
>>> return gradient_function(x, y, self.grad_wrt_output)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtXYWithSensParam(Net())(x, y)
(Tensor(shape=[2, 3], dtype=Float32,
[[2.211 0.51 1.4900001]
[5.588 2.68 4.07 ]]), Tensor(shape=[3, 3], dtype=Float32,
[[1.52 2.82 2.14 ]
[1.1 2.05 1.55 ]
[0.90000004 1.55 1.25 ]]))
>>>
>>> class GradNetWithWrtParams(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWithWrtParams, self).__init__()
>>> self.net = net
>>> self.params = ParameterTuple(net.trainable_params())
>>> self.grad_op = GradOperation(get_by_list=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net, self.params)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
>>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWithWrtParams(Net())(x, y)
(Tensor(shape=[1], dtype=Float32, [21.536]),)
>>>
>>> class GradNetWrtInputsAndParams(nn.Cell):
>>> def __init__(self, net):
>>> super(GradNetWrtInputsAndParams, self).__init__()
>>> self.net = net
>>> self.params = ParameterTuple(net.trainable_params())
>>> self.grad_op = GradOperation(get_all=True, get_by_list=True)
>>> def construct(self, x, y):
>>> gradient_function = self.grad_op(self.net, self.params)
>>> return gradient_function(x, y)
>>>
>>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
>>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
>>> GradNetWrtInputsAndParams(Net())(x, y)
((Tensor(shape=[2, 3], dtype=Float32,
[[3.52 3.9 2.6 ]
[3.52 3.9 2.6 ]]), Tensor(shape=[3, 3], dtype=Float32,
[[0.6 0.6 0.6 ]
[1.9 1.9 1.9 ]
[1.3000001 1.3000001 1.3000001]])), (Tensor(shape=[1], dtype=Float32, [12.902]),))
"""
def __init__(self, get_all=False, get_by_list=False, sens_param=False):
...
...
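For readers who want to run the docstring's first example as a standalone script, the following is a minimal, self-contained sketch of the basic `GradOperation` pattern (gradient with respect to the first input). The exact import paths (`mindspore.ops.composite.GradOperation`, `mindspore.ops.operations`, `mindspore.context`) and the `context.set_context(mode=context.GRAPH_MODE, device_target="CPU")` call are assumptions based on the file location and typical MindSpore usage of this era, not something stated in the diff itself.

import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
# Import path assumed from the file location (mindspore/ops/composite/base.py).
from mindspore.ops.composite import GradOperation

# Assumption: run in graph mode on CPU; adjust device_target for your installation.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Net(nn.Cell):
    """The toy network from the Examples: out = (x * z) @ y."""
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return out


class GradNetWrtX(nn.Cell):
    """Wraps `net` and returns the gradient with respect to the first input, following steps 1-3 above."""
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)


x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
# Expected output: a 2x3 Float32 tensor matching the values shown in the Examples section.
print(GradNetWrtX(Net())(x, y))

The other patterns in the docstring (`get_all`, `get_by_list`, `sens_param`) run the same way; only the `GradOperation` arguments and the wrapper Cell change, exactly as in the Examples above.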