magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the fork source)
Commit 172728a6
Authored on June 28, 2020 by wangnan39@huawei.com
Parent: e83c5630

support weight decay for sparse optimizer
Showing 10 changed files with 26 additions and 17 deletions (+26 / -17)
mindspore/nn/optim/adam.py                           +1   -1
mindspore/nn/optim/ftrl.py                           +2   -2
mindspore/nn/optim/lazyadam.py                       +1   -2
mindspore/nn/optim/optimizer.py                      +14  -4
mindspore/nn/optim/proximal_ada_grad.py              +1   -1
tests/ut/python/nn/optim/test_adam.py                +1   -1
tests/ut/python/nn/optim/test_ftrl.py                +1   -1
tests/ut/python/nn/optim/test_lazyadam.py            +1   -1
tests/ut/python/nn/optim/test_proximal_ada_grad.py   +2   -2
tests/ut/python/nn/optim/test_rmsprop.py             +2   -2
mindspore/nn/optim/adam.py
@@ -164,7 +164,7 @@ class Adam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.

     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
...
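A minimal usage sketch of what this docstring change reflects, mirroring the test update in tests/ut/python/nn/optim/test_adam.py further down. NetWithSparseGatherV2 is the test fixture defined in that file (assumed to be importable here); its forward pass uses SparseGatherV2, so the gathered Parameter receives sparse gradients.

    import mindspore.nn as nn
    from mindspore.nn.optim import Adam

    # NetWithSparseGatherV2 comes from tests/ut/python/nn/optim/test_adam.py (assumed in scope).
    net = NetWithSparseGatherV2()
    net.set_train()

    # With this commit, a non-zero weight_decay is accepted on the sparse path as well,
    # instead of being documented as unsupported.
    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
    train_network = nn.TrainOneStepCell(net, optimizer)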
mindspore/nn/optim/ftrl.py
@@ -73,7 +73,7 @@ class FTRL(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.

     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
...
@@ -124,7 +124,7 @@ class FTRL(Optimizer):
         linear = self.linear
         lr = self.learning_rate
         if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
+            grads = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)

         grads = self.scale_grad(grads)
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
...
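For a dense parameter, the decay term that _apply_decay folds into the gradient here, before scale_grad and the FTRL update, amounts to adding weight_decay * weight onto the raw gradient (by analogy with the sparse branch added in optimizer.py below; the dense body is not shown in this diff). A small NumPy sketch of that arithmetic, illustrative only and not MindSpore API:

    import numpy as np

    weight = np.array([[0.5, -1.0], [2.0, 0.0]], dtype=np.float32)
    grad = np.array([[0.1, 0.1], [0.2, -0.2]], dtype=np.float32)
    weight_decay = 0.9

    # Assumed dense case of _apply_decay: g' = g + weight_decay * w, applied before scale_grad.
    decayed_grad = grad + weight_decay * weight
    print(decayed_grad)  # [[0.55, -0.8], [2.0, -0.2]]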
mindspore/nn/optim/lazyadam.py
@@ -94,8 +94,7 @@ class LazyAdam(Optimizer):
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
-        continuous development. The sparse behavior is currently performed on the CPU, weight decay is
-        not supported.
+        continuous development. The sparse behavior is currently performed on the CPU.

     Args:
         params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
...
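The note above stresses that the sparse behavior is not equivalent to the original Adam algorithm because only the parameters at the current gradient indices are updated. A rough NumPy illustration of that restriction (it uses a plain scaled update as a stand-in; the real LazyAdam kernel also maintains first and second moments for the touched rows):

    import numpy as np

    # Embedding-style parameter with 4 rows; this step's sparse gradient only names rows 0 and 2.
    weight = np.zeros((4, 2), dtype=np.float32)
    indices = np.array([0, 2])
    values = np.array([[0.1, 0.1], [0.3, 0.3]], dtype=np.float32)
    lr = 0.1

    # Only the gathered rows change; rows 1 and 3 keep their old values and stale moments,
    # which is why the result can differ from dense Adam on a scattered-to-dense gradient.
    weight[indices] -= lr * values
    print(weight)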
mindspore/nn/optim/optimizer.py
@@ -195,12 +195,12 @@ class Optimizer(Cell):
         params = self.parameters
         if self.is_group:
             if self.exec_weight_decay:
-                gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                                      params, gradients)
         else:
             if self.weight_decay > 0:
-                gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
-                                           params, gradients)
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                                      params, gradients)

         return gradients
...
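In plain Python terms, the decay_weight hunk above distinguishes two cases: with parameter groups, weight_decay and decay_flags are per-parameter sequences zipped alongside the gradients; without groups, a single scalar coefficient is bound up front (F.partial in the real code) and only the flags are zipped. A rough paraphrase of that control flow, for dense tensors only and assuming the additive decay behavior of _apply_decay (not the MindSpore graph primitives):

    from functools import partial

    def decay_weight_sketch(gradients, params, decay_flags, weight_decay, is_group):
        """Plain-Python paraphrase of Optimizer.decay_weight after this change (illustrative)."""
        def apply_decay(wd, flag, param, grad):
            # assumed dense behavior of _apply_decay: add wd * param where the flag is set
            return grad + wd * param if flag else grad

        if is_group:
            # grouped parameters: weight_decay is a per-parameter sequence
            return tuple(apply_decay(wd, flag, p, g)
                         for wd, flag, p, g in zip(weight_decay, decay_flags, params, gradients))
        if weight_decay > 0:
            # single scalar coefficient bound once, applied only where decay_flags is True
            decay = partial(apply_decay, weight_decay)
            return tuple(decay(flag, p, g) for flag, p, g in zip(decay_flags, params, gradients))
        return gradients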
@@ -479,10 +479,20 @@
 op_add = P.AddN()
+op_gather = P.GatherV2()

 _apply_decay = C.MultitypeFuncGraph("apply_decay")


+@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
+    """Get grad with weight_decay."""
+    if if_apply:
+        weight = op_gather(weight, gradient[0], 0)
+        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+    return gradient
+
+
 @_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
...
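The new "Tuple" registration above is the heart of the change: a sparse gradient arrives as an (indices, values, dense_shape) tuple, the weight rows named by the indices are gathered, and weight_decay times those rows is added onto the values while the indices and shape pass through unchanged. A NumPy sketch of that arithmetic, illustrative only and not the graph ops:

    import numpy as np

    weight = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], dtype=np.float32)
    weight_decay = 0.9

    # Sparse gradient as the (indices, values, dense_shape) tuple the optimizer receives.
    indices = np.array([0, 2])
    values = np.array([[0.1, 0.1], [0.2, 0.2]], dtype=np.float32)
    dense_shape = weight.shape

    # Mirror of _tensor_apply_decay_with_sparse:
    gathered = weight[indices]                          # op_gather(weight, gradient[0], 0)
    decayed_values = weight_decay * gathered + values   # op_add((weight * weight_decay, gradient[1]))
    decayed_grad = (indices, decayed_values, dense_shape)

    print(decayed_values)  # [[1.0, 1.0], [2.9, 2.9]]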
mindspore/nn/optim/proximal_ada_grad.py
@@ -60,7 +60,7 @@ class ProximalAdagrad(Optimizer):
     Note:
         The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
         `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse
-        behavior is currently performed on the CPU, weight decay is not supported.
+        behavior is currently performed on the CPU.

     Args:
         params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
...
tests/ut/python/nn/optim/test_adam.py
@@ -107,7 +107,7 @@ def test_sparse_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0)
+    optimizer = Adam(net.trainable_params(), learning_rate=0.1, loss_scale=1024.0, weight_decay=0.9)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
...
tests/ut/python/nn/optim/test_ftrl.py
@@ -71,6 +71,6 @@ def test_spares_ftrl_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = FTRL(net.trainable_params(), loss_scale=2.0)
+    optimizer = FTRL(net.trainable_params(), weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
tests/ut/python/nn/optim/test_lazyadam.py
@@ -75,7 +75,7 @@ def test_spares_lazy_adam_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, loss_scale=2.0)
+    optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
...
tests/ut/python/nn/optim/test_proximal_ada_grad.py
@@ -57,7 +57,7 @@ def test_proximal_ada_grad():
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = ProximalAdagrad(net.trainable_params())
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
...
@@ -70,6 +70,6 @@ def test_spares_proximal_ada_grad_compile():
     net = NetWithSparseGatherV2()
     net.set_train()
-    optimizer = ProximalAdagrad(net.trainable_params(), loss_scale=2.0)
+    optimizer = ProximalAdagrad(net.trainable_params(), weight_decay=0.9, loss_scale=1024.0)
     train_network = TrainOneStepCell(net, optimizer)
     _executor.compile(train_network, indices, label)
tests/ut/python/nn/optim/test_rmsprop.py
@@ -57,7 +57,7 @@ def test_rmsprop_compile():
 def test_rmsprop_e():
     net = Net()
     with pytest.raises(ValueError):
-        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=-0.1, learning_rate=0.1, weight_decay=0.9)
     with pytest.raises(TypeError):
-        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1)
+        RMSProp(net.get_parameters(), momentum=1, learning_rate=0.1, weight_decay=0.9)