Commit 49da4e79
Authored July 13, 2020 by mindspore-ci-bot; committed by Gitee on July 13, 2020
!3026 Add front end PS optimizer expression
Merge pull request !3026 from ZPaC/add-front-end-ps-optim-expression
Parents: 8f15d1f6, f8c7ae76

Showing 11 changed files with 241 additions and 11 deletions (+241, -11).
mindspore/ccsrc/common/utils.h                                      +8   -0
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc   +1   -0
mindspore/ccsrc/utils/context/ms_context.cc                         +6   -0
mindspore/ccsrc/utils/utils.h                                       +23  -5
mindspore/common/parameter.py                                       +4   -0
mindspore/nn/cell.py                                                +14  -0
mindspore/nn/optim/__init__.py                                      +4   -4
mindspore/nn/optim/adam.py                                          +71  -0
mindspore/nn/optim/ftrl.py                                          +55  -0
mindspore/ops/operations/__init__.py                                +4   -2
mindspore/ops/operations/other_ops.py                               +51  -0
mindspore/ccsrc/common/utils.h
@@ -38,6 +38,14 @@ static inline std::string GetEnv(const std::string &envvar) {
  return std::string(value);
}

static inline int SetEnv(const char *envname, const char *envvar, int overwrite = 1) {
#if defined(_WIN32)
  return 0;
#else
  return ::setenv(envname, envvar, overwrite);
#endif
}
}  // namespace common
}  // namespace mindspore
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
@@ -72,6 +72,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(kSpaceToBatchOpName, {1});
  Register(kBatchToSpaceOpName, {1});
  Register(kPadOpName, {1});
  Register(kPushOpName, {1});
}

ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() {
mindspore/ccsrc/utils/context/ms_context.cc
@@ -30,6 +30,7 @@
#include "transform/df_graph_manager.h"
#endif
#include "ir/tensor.h"
#include "common/utils.h"

namespace mindspore {
#ifdef ENABLE_GE

@@ -168,6 +169,11 @@ bool MsContext::OpenTsd() {
    return true;
  }

  auto role = common::GetEnv("MS_ROLE");
  if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
    return true;
  }

  unsigned int device_id;
  unsigned int rank_size = 1;
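The MS_ROLE check above makes scheduler and parameter-server processes return from OpenTsd() before a device is claimed. A minimal sketch of how a launch script might set that variable before MindSpore initializes; only the two role strings come from this diff, the launcher itself is an assumption:

import os

# Illustrative launcher snippet: only "MS_SCHED" and "MS_PSERVER" are taken
# from the C++ check above; everything else here is an assumption.
os.environ["MS_ROLE"] = "MS_PSERVER"   # or "MS_SCHED" for the scheduler process

import mindspore.context as context    # imported after MS_ROLE is set
context.set_context(mode=context.GRAPH_MODE)
# Scheduler and parameter-server processes now skip TSD/device initialization
# in MsContext::OpenTsd(); worker processes would use a different MS_ROLE value.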
mindspore/ccsrc/utils/utils.h
@@ -173,6 +173,10 @@ constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
constexpr auto kPushOpName = "Push";
constexpr auto kPullOpName = "Pull";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";

// attr key name
constexpr auto kAttrInputNames = "input_names";

@@ -234,6 +238,8 @@ constexpr auto kAttrSizeSplits = "size_splits";
constexpr auto kAttrOutputDefault = "output_default";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";

@@ -286,12 +292,24 @@ const std::set<std::string> kOpFormatList = {
  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                               kApplyMomentumOpName,
                                               kApplyAdadeltaOpName,
                                               kApplyAdagradOpName,
                                               kApplyAdagradDAName,
                                               kApplyAdamOpName,
                                               kApplyAdaMaxOpName,
                                               kApplyAddSignOpName,
                                               kApplyCenteredRMSPOpName,
                                               kApplyFtrlOpName,
                                               kApplyFtrlV2OpName,
                                               kApplyGradientDescentOpName,
                                               kApplyPowerSignOpName,
                                               kApplyProximalAdagradOpName,
                                               kApplyProximalGradientDescentOpName,
                                               kApplyRMSPropOpName,
                                               kPushOpName,
                                               kPullOpName,
};
const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
mindspore/common/parameter.py
@@ -65,6 +65,7 @@ class Parameter:
        self.has_indexed_slices_grad = has_indexed_slices_grad
        self._is_init = False
        self._sliced = False
        self.is_param_ps = False
        if context.get_context("mode") == context.PYNATIVE_MODE:
            self.init_data()

@@ -75,6 +76,9 @@ class Parameter:
    def __parameter__(self):
        """For parse check."""

    def set_param_ps(self):
        self.is_param_ps = True

    @property
    def name(self):
        """Get the name of the parameter."""
mindspore/nn/cell.py
@@ -831,6 +831,20 @@ class Cell:
        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
        self.enable_hook = True

    def set_param_ps(self, recurse=True):
        """
        Set whether the trainable parameter is updated by parameter server.

        Note:
            This only works when running task in parameter server mode.

        Args:
            recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
        """
        params = self.trainable_params(recurse)
        for param in params:
            param.set_param_ps()


class GraphKernel(Cell):
    """
    Base class for GraphKernel.
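A hedged usage sketch of the Cell-level switch; LeNet5 is a placeholder network, and the assertion only illustrates that the flag propagates to every trainable parameter:

# Illustrative only: any Cell subclass works the same way.
net = LeNet5()                      # placeholder network
net.set_param_ps()                  # recurse=True by default, so subcell parameters are covered
assert all(p.is_param_ps for p in net.trainable_params())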
mindspore/nn/optim/__init__.py
@@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients.
 """
 from .optimizer import Optimizer
 from .momentum import Momentum
-from .adam import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
+from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
-from .ftrl import FTRL
+from .ftrl import FTRL, PSFTRL
 from .rmsprop import RMSProp
 from .proximal_ada_grad import ProximalAdagrad
 from .lazyadam import LazyAdam

-__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
-           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
+__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
+           'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
mindspore/nn/optim/adam.py
@@ -27,6 +27,7 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt")


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",

@@ -129,6 +130,31 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b
                                                   eps, gradient))
    return success


@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                              "Tuple", "Tensor", "Tensor", "Tensor")
def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                                   moment1, moment2):
    """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse."""
    success = True
    op_shape = P.Shape()
    shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
              op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
              op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                           eps, gradient[1], gradient[0]), shapes), params))
    return success


@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                              "Tensor", "Tensor", "Tensor", "Tensor")
def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                                       moment1, moment2):
    """Apply adam optimizer by push and pull to the weight parameter using Tensor."""
    success = True
    op_shape = P.Shape()
    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                          (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    return success


class Adam(Optimizer):
    r"""

@@ -274,6 +300,51 @@ class Adam(Optimizer):
                                gradients, params, moment1, moment2)
        return success


class PSAdam(Optimizer):
    '''The same usage as Adam optimizer except the parameters are set PS mode.'''
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)

        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        self.hyper_map = C.HyperMap()
        self.push = P.Push("Adam", [0, 1, 2])
        self.push.add_prim_attr("primitive_target", "CPU")
        self.pull = P.Pull()
        self.pull.add_prim_attr("primitive_target", "CPU")

    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        if self.is_group_lr:
            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
                                          self.beta1, self.beta2, self.eps),
                                lr, gradients, params, moment1, moment2)
        else:
            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
                                          self.beta1, self.beta2, self.eps, lr),
                                gradients, params, moment1, moment2)
        return success


class AdamWeightDecay(Optimizer):
    """
mindspore/nn/optim/ftrl.py
@@ -22,6 +22,7 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer, _apply_decay, _grad_scale

_ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")


@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",

@@ -41,6 +42,26 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra
    success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
    return success


@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
                              "Tensor", "Tensor")
def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
                                          weight, moment):
    success = True
    op_shape = P.Shape()
    shapes = (op_shape(weight), op_shape(moment), op_shape(linear),
              op_shape(gradient[1]), op_shape(gradient[0]))
    success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight))
    return success


@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor",
                              "Tensor", "Tensor")
def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
                                              weight, moment):
    success = True
    op_shape = P.Shape()
    success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
                                          (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
    return success


def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None):
    """Check param."""

@@ -131,3 +152,37 @@ class FTRL(Optimizer):
        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power),
                            linear, grads, params, moments)
        return success


class PSFTRL(Optimizer):
    def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
        super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name)
        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
        self.linear = self.parameters.clone(prefix="linear", init='zeros')
        self.l1 = l1
        self.l2 = l2
        self.lr_power = lr_power
        self.weight_decay = weight_decay
        self.decay_tf = tuple((lambda: True)() for x in self.parameters)

        self.hyper_map = C.HyperMap()
        self.push = P.Push("Ftrl", [0, 1, 2])
        self.push.add_prim_attr("primitive_target", "CPU")
        self.pull = P.Pull()
        self.pull.add_prim_attr("primitive_target", "CPU")

    def construct(self, grads):
        params = self.parameters
        moments = self.moments
        linear = self.linear
        lr = self.learning_rate
        if self.weight_decay > 0.0:
            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)

        grads = self.scale_grad(grads)
        success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2, self.lr_power),
                            linear, grads, params, moments)
        return success
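PSFTRL follows the same pattern but rejects parameter groups in __init__, so it must receive a flat parameter list; a hedged sketch with a placeholder network name:

from mindspore.nn.optim import PSFTRL

net = WideDeepModel()                # placeholder network
net.set_param_ps()
# Passing grouped parameters (a list of dicts) raises RuntimeError in __init__.
opt = PSFTRL(net.trainable_params(), initial_accum=0.1, learning_rate=0.001,
             lr_power=-0.5, l1=0.0, l2=0.0, loss_scale=1.0)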
mindspore/ops/operations/__init__.py
@@ -78,7 +78,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
-                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
+                        CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
 from .thor_ops import *

 __all__ = [

@@ -333,7 +333,9 @@ __all__ = [
     "Mod",
     "PopulationCount",
     "ParallelConcat",
-    "EmbeddingLookup"
+    "EmbeddingLookup",
+    "Push",
+    "Pull"
 ]

 __all__.sort()
mindspore/ops/operations/other_ops.py
@@ -488,3 +488,54 @@ class PopulationCount(PrimitiveWithInfer):
        args = {"x": x_dtype}
        validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name)
        return mstype.tensor_type(mstype.uint8)


class Push(PrimitiveWithInfer):
    """
    Pushing the inputs of the corresponding optimizer to parameter server.

    Args:
        optim_type (string): The optimizer type. Default: 'ApplyMomentum'.
        only_shape_indices (list): The indices of input of which only shape
            will be pushed to parameter server. Default: None.

    Inputs:
        - **optim_inputs** (tuple) - The inputs for this kind of optimizer.
        - **optim_input_shapes** (tuple) - The shapes of the inputs.

    Outputs:
        Tensor, the key of the weight which needs to be updated.
    """

    @prim_attr_register
    def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
        """init Push"""
        self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])

    def infer_shape(self, inputs, shapes):
        return [1]

    def infer_dtype(self, inputs, shapes):
        return mstype.uint64


class Pull(PrimitiveWithInfer):
    """
    Pulling weight from parameter server.

    Inputs:
        - **key** (Tensor) - The key of the weight.
        - **weight** (Tensor) - The weight to be updated.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """init Pull"""
        self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])

    def infer_shape(self, key_shape, weight_shape):
        return [1]

    def infer_dtype(self, key_dtype, weight_dtype):
        return mstype.float32
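Push and Pull are not meant to be called by user code directly; the PS optimizers above compose them inside construct(). A condensed restatement of the PSAdam/PSFTRL wiring, illustrative only:

from mindspore.ops import operations as P

push = P.Push("Adam", [0, 1, 2])               # optimizer type, plus indices whose inputs are pushed as shapes only
push.add_prim_attr("primitive_target", "CPU")  # PS communication runs on the CPU backend
pull = P.Pull()
pull.add_prim_attr("primitive_target", "CPU")
# Inside construct(): pull(push(optim_inputs, optim_input_shapes), weight)
# Push returns a uint64 key identifying the weight; Pull fetches the updated weight for that key.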