Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1d66467d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d66467d
编写于
7月 16, 2020
作者:
J
jinyaohui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
opt add ps logic
上级
8300802b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
76 addition
and
24 deletion
+76
-24
mindspore/nn/optim/adam.py
mindspore/nn/optim/adam.py
+32
-11
mindspore/nn/optim/ftrl.py
mindspore/nn/optim/ftrl.py
+25
-7
mindspore/nn/optim/momentum.py
mindspore/nn/optim/momentum.py
+15
-6
mindspore/nn/optim/optimizer.py
mindspore/nn/optim/optimizer.py
+2
-0
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+2
-0
未找到文件。
mindspore/nn/optim/adam.py
浏览文件 @
1d66467d
...
@@ -71,7 +71,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
...
@@ -71,7 +71,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
next_v
=
op_mul
(
beta2
,
v_fp32
)
+
op_mul
(
op_cast
(
F
.
tuple_to_array
((
1.0
,)),
mstype
.
float32
)
-
beta2
,
op_square
(
gradient_fp32
))
-
beta2
,
op_square
(
gradient_fp32
))
update
=
next_m
/
(
eps
+
op_sqrt
(
next_v
))
update
=
next_m
/
(
eps
+
op_sqrt
(
next_v
))
if
decay_flag
:
if
decay_flag
:
update
=
op_mul
(
weight_decay_tensor
,
param_fp32
)
+
update
update
=
op_mul
(
weight_decay_tensor
,
param_fp32
)
+
update
...
@@ -110,26 +109,45 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
...
@@ -110,26 +109,45 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
@
_adam_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Number"
,
"Tensor"
,
"Tuple"
,
@
_adam_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Number"
,
"Tensor"
,
"Tuple"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
)
def
_run_opt_with_sparse
(
opt
,
sparse_opt
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
def
_run_opt_with_sparse
(
opt
,
sparse_opt
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
moment1
,
moment2
):
moment1
,
moment2
,
ps_parameter
):
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
success
=
True
success
=
True
success
=
F
.
depend
(
success
,
sparse_opt
(
params
,
moment1
,
moment2
,
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
if
ps_parameter
:
eps
,
gradient
[
1
],
gradient
[
0
]))
op_shape
=
P
.
Shape
()
_ps_pull
=
P
.
Pull
()
_ps_push
=
P
.
Push
(
"Adam"
,
[
0
,
1
,
2
])
shapes
=
(
op_shape
(
params
),
op_shape
(
moment1
),
op_shape
(
moment2
),
op_shape
(
beta1_power
),
op_shape
(
beta2_power
),
op_shape
(
lr
),
op_shape
(
beta1
),
op_shape
(
beta2
),
op_shape
(
eps
),
op_shape
(
gradient
[
1
]),
op_shape
(
gradient
[
0
]))
success
=
F
.
depend
(
success
,
_ps_pull
(
_ps_push
((
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
eps
,
gradient
[
1
],
gradient
[
0
]),
shapes
),
params
))
else
:
success
=
F
.
depend
(
success
,
sparse_opt
(
params
,
moment1
,
moment2
,
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
eps
,
gradient
[
1
],
gradient
[
0
]))
return
success
return
success
@
_adam_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Number"
,
"Tensor"
,
"Tensor"
,
@
_adam_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Number"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
)
def
_run_opt_with_one_number
(
opt
,
sparse_opt
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
def
_run_opt_with_one_number
(
opt
,
sparse_opt
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
moment1
,
moment2
):
moment1
,
moment2
,
ps_parameter
):
"""Apply adam optimizer to the weight parameter using Tensor."""
"""Apply adam optimizer to the weight parameter using Tensor."""
success
=
True
success
=
True
success
=
F
.
depend
(
success
,
opt
(
params
,
moment1
,
moment2
,
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
if
ps_parameter
:
eps
,
gradient
))
op_shape
=
P
.
Shape
()
_ps_pull
=
P
.
Pull
()
_ps_push
=
P
.
Push
(
"Adam"
,
[
0
,
1
,
2
])
success
=
F
.
depend
(
success
,
_ps_pull
(
_ps_push
((
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
eps
,
gradient
),
(
op_shape
(
params
),
op_shape
(
moment1
),
op_shape
(
moment2
))),
params
))
else
:
success
=
F
.
depend
(
success
,
opt
(
params
,
moment1
,
moment2
,
beta1_power
,
beta2_power
,
lr
,
beta1
,
beta2
,
eps
,
gradient
))
return
success
return
success
@
_adam_push_pull_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
@
_adam_push_pull_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tuple"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Tuple"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
def
_run_push_pull_opt_with_sparse
(
push
,
pull
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
def
_run_push_pull_opt_with_sparse
(
push
,
pull
,
beta1_power
,
beta2_power
,
beta1
,
beta2
,
eps
,
lr
,
gradient
,
params
,
...
@@ -156,6 +174,7 @@ def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, bet
...
@@ -156,6 +174,7 @@ def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, bet
(
op_shape
(
params
),
op_shape
(
moment1
),
op_shape
(
moment2
))),
params
))
(
op_shape
(
params
),
op_shape
(
moment1
),
op_shape
(
moment2
))),
params
))
return
success
return
success
class
Adam
(
Optimizer
):
class
Adam
(
Optimizer
):
r
"""
r
"""
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
...
@@ -293,13 +312,14 @@ class Adam(Optimizer):
...
@@ -293,13 +312,14 @@ class Adam(Optimizer):
if
self
.
is_group_lr
:
if
self
.
is_group_lr
:
success
=
self
.
map_
(
F
.
partial
(
_adam_opt
,
self
.
opt
,
self
.
sparse_opt
,
beta1_power
,
beta2_power
,
success
=
self
.
map_
(
F
.
partial
(
_adam_opt
,
self
.
opt
,
self
.
sparse_opt
,
beta1_power
,
beta2_power
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
),
self
.
beta1
,
self
.
beta2
,
self
.
eps
),
lr
,
gradients
,
params
,
moment1
,
moment2
)
lr
,
gradients
,
params
,
moment1
,
moment2
,
self
.
ps_parameters
)
else
:
else
:
success
=
self
.
map_
(
F
.
partial
(
_adam_opt
,
self
.
opt
,
self
.
sparse_opt
,
beta1_power
,
beta2_power
,
success
=
self
.
map_
(
F
.
partial
(
_adam_opt
,
self
.
opt
,
self
.
sparse_opt
,
beta1_power
,
beta2_power
,
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
),
self
.
beta1
,
self
.
beta2
,
self
.
eps
,
lr
),
gradients
,
params
,
moment1
,
moment2
)
gradients
,
params
,
moment1
,
moment2
,
self
.
ps_parameters
)
return
success
return
success
class
PSAdam
(
Optimizer
):
class
PSAdam
(
Optimizer
):
'''The same usage as Adam optimizer except the parameters are set PS mode.'''
'''The same usage as Adam optimizer except the parameters are set PS mode.'''
def
__init__
(
self
,
params
,
learning_rate
=
1e-3
,
beta1
=
0.9
,
beta2
=
0.999
,
eps
=
1e-8
,
use_locking
=
False
,
def
__init__
(
self
,
params
,
learning_rate
=
1e-3
,
beta1
=
0.9
,
beta2
=
0.999
,
eps
=
1e-8
,
use_locking
=
False
,
...
@@ -346,6 +366,7 @@ class PSAdam(Optimizer):
...
@@ -346,6 +366,7 @@ class PSAdam(Optimizer):
gradients
,
params
,
moment1
,
moment2
)
gradients
,
params
,
moment1
,
moment2
)
return
success
return
success
class
AdamWeightDecay
(
Optimizer
):
class
AdamWeightDecay
(
Optimizer
):
"""
"""
Implements Adam algorithm weight decay fix.
Implements Adam algorithm weight decay fix.
...
...
mindspore/nn/optim/ftrl.py
浏览文件 @
1d66467d
...
@@ -26,22 +26,38 @@ _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
...
@@ -26,22 +26,38 @@ _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
@
_ftrl_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tuple"
,
"Tensor"
,
@
_ftrl_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tuple"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Bool"
)
def
_tensor_run_opt_with_sparse
(
opt
,
spars_opt
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
weight
,
moment
):
def
_tensor_run_opt_with_sparse
(
opt
,
spars_opt
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
weight
,
moment
,
ps_parameter
):
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
"""Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
success
=
True
success
=
True
success
=
F
.
depend
(
success
,
spars_opt
(
weight
,
moment
,
linear
,
gradient
[
1
],
gradient
[
0
]))
if
ps_parameter
:
op_shape
=
P
.
Shape
()
_ps_pull
=
P
.
Pull
()
_ps_push
=
P
.
Push
(
"Ftrl"
,
[
0
,
1
,
2
])
shapes
=
(
op_shape
(
weight
),
op_shape
(
moment
),
op_shape
(
linear
),
op_shape
(
gradient
[
1
]),
op_shape
(
gradient
[
0
]))
success
=
F
.
depend
(
success
,
_ps_pull
(
_ps_push
((
gradient
[
1
],
gradient
[
0
]),
shapes
),
weight
))
else
:
success
=
F
.
depend
(
success
,
spars_opt
(
weight
,
moment
,
linear
,
gradient
[
1
],
gradient
[
0
]))
return
success
return
success
@
_ftrl_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
@
_ftrl_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Bool"
)
def
_tensor_run_opt
(
opt
,
spars_opt
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
weight
,
moment
):
def
_tensor_run_opt
(
opt
,
spars_opt
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
weight
,
moment
,
ps_parameter
):
"""Apply ftrl optimizer to the weight parameter."""
"""Apply ftrl optimizer to the weight parameter."""
success
=
True
success
=
True
success
=
F
.
depend
(
success
,
opt
(
weight
,
moment
,
linear
,
gradient
,
learning_rate
,
l1
,
l2
,
lr_power
))
if
ps_parameter
:
op_shape
=
P
.
Shape
()
_ps_pull
=
P
.
Pull
()
_ps_push
=
P
.
Push
(
"Ftrl"
,
[
0
,
1
,
2
])
success
=
F
.
depend
(
success
,
_ps_pull
(
_ps_push
((
gradient
,
learning_rate
,
l1
,
l2
,
lr_power
),
(
op_shape
(
weight
),
op_shape
(
moment
),
op_shape
(
linear
))),
weight
))
else
:
success
=
F
.
depend
(
success
,
opt
(
weight
,
moment
,
linear
,
gradient
,
learning_rate
,
l1
,
l2
,
lr_power
))
return
success
return
success
@
_ftrl_push_pull_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tuple"
,
@
_ftrl_push_pull_opt
.
register
(
"Function"
,
"Function"
,
"Tensor"
,
"Number"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tuple"
,
"Tensor"
,
"Tensor"
)
"Tensor"
,
"Tensor"
)
def
_tensor_run_push_pull_opt_with_sparse
(
push
,
pull
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
def
_tensor_run_push_pull_opt_with_sparse
(
push
,
pull
,
learning_rate
,
l1
,
l2
,
lr_power
,
linear
,
gradient
,
...
@@ -63,6 +79,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2,
...
@@ -63,6 +79,7 @@ def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2,
(
op_shape
(
weight
),
op_shape
(
moment
),
op_shape
(
linear
))),
weight
))
(
op_shape
(
weight
),
op_shape
(
moment
),
op_shape
(
linear
))),
weight
))
return
success
return
success
def
_check_param
(
initial_accum
,
lr_power
,
l1
,
l2
,
use_locking
,
weight_decay
=
0.0
,
prim_name
=
None
):
def
_check_param
(
initial_accum
,
lr_power
,
l1
,
l2
,
use_locking
,
weight_decay
=
0.0
,
prim_name
=
None
):
"""Check param."""
"""Check param."""
validator
.
check_value_type
(
"initial_accum"
,
initial_accum
,
[
float
],
prim_name
)
validator
.
check_value_type
(
"initial_accum"
,
initial_accum
,
[
float
],
prim_name
)
...
@@ -150,9 +167,10 @@ class FTRL(Optimizer):
...
@@ -150,9 +167,10 @@ class FTRL(Optimizer):
grads
=
self
.
scale_grad
(
grads
)
grads
=
self
.
scale_grad
(
grads
)
success
=
self
.
map_
(
F
.
partial
(
_ftrl_opt
,
self
.
opt
,
self
.
sparse_opt
,
lr
,
self
.
l1
,
self
.
l2
,
self
.
lr_power
),
success
=
self
.
map_
(
F
.
partial
(
_ftrl_opt
,
self
.
opt
,
self
.
sparse_opt
,
lr
,
self
.
l1
,
self
.
l2
,
self
.
lr_power
),
linear
,
grads
,
params
,
moments
)
linear
,
grads
,
params
,
moments
,
self
.
ps_parameters
)
return
success
return
success
class
PSFTRL
(
Optimizer
):
class
PSFTRL
(
Optimizer
):
def
__init__
(
self
,
params
,
initial_accum
=
0.1
,
learning_rate
=
0.001
,
lr_power
=-
0.5
,
l1
=
0.0
,
l2
=
0.0
,
def
__init__
(
self
,
params
,
initial_accum
=
0.1
,
learning_rate
=
0.001
,
lr_power
=-
0.5
,
l1
=
0.0
,
l2
=
0.0
,
use_locking
=
False
,
loss_scale
=
1.0
,
weight_decay
=
0.0
):
use_locking
=
False
,
loss_scale
=
1.0
,
weight_decay
=
0.0
):
...
...
mindspore/nn/optim/momentum.py
浏览文件 @
1d66467d
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""momentum"""
"""momentum"""
from
mindspore.ops
import
functional
as
F
,
composite
as
C
from
mindspore.ops
import
functional
as
F
,
composite
as
C
,
operations
as
P
from
mindspore.ops
import
_selected_ops
from
mindspore.ops
import
_selected_ops
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.tensor
import
Tensor
...
@@ -25,11 +25,18 @@ from .optimizer import Optimizer
...
@@ -25,11 +25,18 @@ from .optimizer import Optimizer
_momentum_opt
=
C
.
MultitypeFuncGraph
(
"momentum_opt"
)
_momentum_opt
=
C
.
MultitypeFuncGraph
(
"momentum_opt"
)
@
_momentum_opt
.
register
(
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
@
_momentum_opt
.
register
(
"Function"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Tensor"
,
"Bool"
)
def
_tensor_run_opt_ext
(
opt
,
momentum
,
learning_rate
,
gradient
,
weight
,
moment
):
def
_tensor_run_opt_ext
(
opt
,
momentum
,
learning_rate
,
gradient
,
weight
,
moment
,
ps_parameter
):
"""Apply momentum optimizer to the weight parameter using Tensor."""
"""Apply momentum optimizer to the weight parameter using Tensor."""
success
=
True
success
=
True
success
=
F
.
depend
(
success
,
opt
(
weight
,
moment
,
learning_rate
,
gradient
,
momentum
))
if
ps_parameter
:
op_shape
=
P
.
Shape
()
_ps_pull
=
P
.
Pull
()
_ps_push
=
P
.
Push
(
"Momentum"
,
[])
shapes
=
(
op_shape
(
learning_rate
),
op_shape
(
gradient
),
op_shape
(
momentum
))
success
=
F
.
depend
(
success
,
_ps_pull
(
_ps_push
((
learning_rate
,
gradient
,
momentum
),
shapes
),
weight
))
else
:
success
=
F
.
depend
(
success
,
opt
(
weight
,
moment
,
learning_rate
,
gradient
,
momentum
))
return
success
return
success
...
@@ -127,7 +134,9 @@ class Momentum(Optimizer):
...
@@ -127,7 +134,9 @@ class Momentum(Optimizer):
gradients
=
self
.
scale_grad
(
gradients
)
gradients
=
self
.
scale_grad
(
gradients
)
lr
=
self
.
get_lr
()
lr
=
self
.
get_lr
()
if
self
.
is_group_lr
:
if
self
.
is_group_lr
:
success
=
self
.
hyper_map
(
F
.
partial
(
_momentum_opt
,
self
.
opt
,
self
.
momentum
),
lr
,
gradients
,
params
,
moments
)
success
=
self
.
hyper_map
(
F
.
partial
(
_momentum_opt
,
self
.
opt
,
self
.
momentum
),
lr
,
gradients
,
params
,
moments
,
self
.
ps_parameters
)
else
:
else
:
success
=
self
.
hyper_map
(
F
.
partial
(
_momentum_opt
,
self
.
opt
,
self
.
momentum
,
lr
),
gradients
,
params
,
moments
)
success
=
self
.
hyper_map
(
F
.
partial
(
_momentum_opt
,
self
.
opt
,
self
.
momentum
,
lr
),
gradients
,
params
,
moments
,
self
.
ps_parameters
)
return
success
return
success
mindspore/nn/optim/optimizer.py
浏览文件 @
1d66467d
...
@@ -152,6 +152,8 @@ class Optimizer(Cell):
...
@@ -152,6 +152,8 @@ class Optimizer(Cell):
self
.
weight_decay
=
weight_decay
*
loss_scale
self
.
weight_decay
=
weight_decay
*
loss_scale
decay_filter
=
lambda
x
:
'beta'
not
in
x
.
name
and
'gamma'
not
in
x
.
name
decay_filter
=
lambda
x
:
'beta'
not
in
x
.
name
and
'gamma'
not
in
x
.
name
self
.
decay_flags
=
tuple
(
decay_filter
(
x
)
for
x
in
self
.
parameters
)
self
.
decay_flags
=
tuple
(
decay_filter
(
x
)
for
x
in
self
.
parameters
)
ps_filter
=
lambda
x
:
x
.
is_param_ps
self
.
ps_parameters
=
tuple
(
ps_filter
(
x
)
for
x
in
self
.
parameters
)
self
.
reciprocal_scale
=
1.0
/
loss_scale
self
.
reciprocal_scale
=
1.0
/
loss_scale
self
.
exec_weight_decay
=
any
(
self
.
decay_flags
)
self
.
exec_weight_decay
=
any
(
self
.
decay_flags
)
self
.
param_length
=
len
(
self
.
parameters
)
self
.
param_length
=
len
(
self
.
parameters
)
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
1d66467d
...
@@ -511,6 +511,7 @@ class Push(PrimitiveWithInfer):
...
@@ -511,6 +511,7 @@ class Push(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
optim_type
=
'ApplyMomentum'
,
only_shape_indices
=
None
):
def
__init__
(
self
,
optim_type
=
'ApplyMomentum'
,
only_shape_indices
=
None
):
"""init Push"""
"""init Push"""
self
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
self
.
init_prim_io_names
(
inputs
=
[
'optim_inputs'
,
'optim_input_shapes'
],
outputs
=
[
'key'
])
self
.
init_prim_io_names
(
inputs
=
[
'optim_inputs'
,
'optim_input_shapes'
],
outputs
=
[
'key'
])
def
infer_shape
(
self
,
inputs
,
shapes
):
def
infer_shape
(
self
,
inputs
,
shapes
):
...
@@ -534,6 +535,7 @@ class Pull(PrimitiveWithInfer):
...
@@ -534,6 +535,7 @@ class Pull(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
):
def
__init__
(
self
):
"""init Pull"""
"""init Pull"""
self
.
add_prim_attr
(
"primitive_target"
,
"CPU"
)
self
.
init_prim_io_names
(
inputs
=
[
'key'
,
'weight'
],
outputs
=
[
'output'
])
self
.
init_prim_io_names
(
inputs
=
[
'key'
,
'weight'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
key_shape
,
weight_shape
):
def
infer_shape
(
self
,
key_shape
,
weight_shape
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录