Commit af62d402

Pass optimizer attributes to push nodes.

Authored Jul 22, 2020 by ZPaC
Parent: 03193542

Showing 14 changed files with 117 additions and 201 deletions.
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h                     +2  -1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc       +3  -5
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h        +2  -1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc       +17 -5
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h        +2  -1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc  +3  -5
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h   +2  -1
mindspore/ccsrc/frontend/parallel/ps/common.h                                       +5  -0
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h                             +22 -4
mindspore/ccsrc/frontend/parallel/ps/util.cc                                        +14 -0
mindspore/ccsrc/frontend/parallel/ps/util.h                                         +2  -0
mindspore/nn/optim/__init__.py                                                      +4  -4
mindspore/nn/optim/adam.py                                                          +20 -99
mindspore/nn/optim/ftrl.py                                                          +19 -75
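
In sum: the worker now attaches optimizer hyper-parameters to its Push primitive as node attributes (use_nesterov for Adam; lr, l1, l2 and lr_power for FTRL), and the parameter server looks up the matching Push CNode in the graph and reads those attributes back when it initializes its optimizer kernels. That makes the standalone PSAdam and PSFTRL optimizers redundant, so they are deleted. A minimal worker-side sketch, using only the P.Push/P.Pull primitives and add_prim_attr calls that appear in this diff (the hyper-parameter values shown are FTRL's defaults, not part of the commit):

from mindspore.ops import operations as P

# "Ftrl" selects the server-side kernel; [0, 1, 2] is passed exactly as in this diff.
ps_push = P.Push("Ftrl", [0, 1, 2])
ps_push.add_prim_attr("lr", 0.001)       # read back server-side via AnfAlgo::GetNodeAttr<float>(cnode, "lr")
ps_push.add_prim_attr("l1", 0.0)
ps_push.add_prim_attr("l2", 0.0)
ps_push.add_prim_attr("lr_power", -0.5)
ps_pull = P.Pull()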
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h

@@ -31,8 +31,9 @@ class PServerKernel {
   ~PServerKernel() = default;
   PServerKernel(const PServerKernel &) = delete;
   PServerKernel &operator=(const PServerKernel &) = delete;

   virtual void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
+  virtual void InitKernel(const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
   virtual void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
   virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                        const std::vector<AddressPtr> &outputs) = 0;
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc

@@ -23,7 +23,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyAdamPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> &var_shape = *(shape_vec[0]);
   std::vector<size_t> &m_shape = *(shape_vec[1]);
@@ -55,11 +55,9 @@ void SparseApplyAdamPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
   }
-  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
-    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
-  }
+  /*
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
+  }
+  */
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
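
Note that the cnode-based use_nesterov lookup arrives commented out: the kernel now receives the Push node, but the flag itself is attached on the worker side in this same commit (see the adam.py hunk below). The worker half, as added to Adam.__init__:

from mindspore.ops import operations as P

ps_push = P.Push("Adam", [0, 1, 2])
ps_push.add_prim_attr("use_nesterov", False)  # Adam.__init__ passes its use_nesterov argument here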
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h

@@ -30,7 +30,8 @@ class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerK
   SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyAdamPSKernel() override = default;

-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                const std::vector<AddressPtr> &outputs) override;
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc

@@ -20,7 +20,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyFtrlPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> var_shape = *(shape_vec[0]);
   std::vector<size_t> accum_shape = *(shape_vec[1]);
@@ -46,10 +46,22 @@ void SparseApplyFtrlPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
   }
-  lr_ = 0.01;
-  l1_ = 1e-8;
-  l2_ = 1e-8;
-  lr_power_ = -0.5;
+  lr_ = AnfAlgo::GetNodeAttr<float>(cnode, "lr");
+  if (lr_ <= 0) {
+    MS_LOG(EXCEPTION) << "lr should be a positive scalar";
+  }
+  l1_ = AnfAlgo::GetNodeAttr<float>(cnode, "l1");
+  if (l1_ < 0) {
+    MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
+  }
+  l2_ = AnfAlgo::GetNodeAttr<float>(cnode, "l2");
+  if (l2_ < 0) {
+    MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
+  }
+  lr_power_ = AnfAlgo::GetNodeAttr<float>(cnode, "lr_power");
+  if (lr_power_ > 0) {
+    MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
+  }
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
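
These four GetNodeAttr reads are the server-side counterpart of the add_prim_attr calls this commit adds to FTRL.__init__ (ftrl.py below); the hard-coded placeholder values they replace are gone for good. Restating the new validation rules in Python, for reference (the kernel raises through MS_LOG(EXCEPTION) where this sketch raises ValueError):

def check_ftrl_attrs(lr, l1, l2, lr_power):
    """Mirror of the checks added to SparseApplyFtrlPSKernel::InitKernel."""
    if lr <= 0:
        raise ValueError("lr should be a positive scalar")
    if l1 < 0:
        raise ValueError("l1 should be a non-negative scalar")
    if l2 < 0:
        raise ValueError("l2 should be a non-negative scalar")
    if lr_power > 0:
        raise ValueError("lr_power should be a non-positive scalar")

check_ftrl_attrs(0.001, 0.0, 0.0, -0.5)  # FTRL's default hyper-parameters pass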
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h

@@ -30,7 +30,8 @@ class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerK
   SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyFtrlPSKernel() override = default;

-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc

@@ -23,7 +23,7 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void SparseApplyLazyAdamPSKernel::InitKernel(
-  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
+  const CNodePtr &cnode, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
   const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
   std::vector<size_t> &var_shape = *(shape_vec[0]);
   std::vector<size_t> &m_shape = *(shape_vec[1]);
@@ -55,11 +55,9 @@ void SparseApplyLazyAdamPSKernel::InitKernel(
   if (grad_shape[0] != indices_size_) {
     MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
   }
-  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
-    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
-  }
+  /*
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
+  }
+  */
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
   workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
   workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h

@@ -30,7 +30,8 @@ class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public
   SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
   ~SparseApplyLazyAdamPSKernel() override = default;

-  void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
+  void InitKernel(const CNodePtr &cnode,
+                  const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                const std::vector<AddressPtr> &outputs) override;
mindspore/ccsrc/frontend/parallel/ps/common.h

@@ -57,15 +57,20 @@ constexpr char kMomentum[] = "momentum";
 constexpr char kApplyMomentum[] = "ApplyMomentum";
 constexpr char kSparseAdam[] = "Adam";
 constexpr char kSparseFtrl[] = "Ftrl";
+constexpr char kApplyMomentumOp[] = "Momentum";
+constexpr char kSparseAdamOp[] = "Adam";
+constexpr char kSparseFtrlOp[] = "FTRL";

 constexpr int kInitWeightsCmd = 10;
 constexpr int kInitWeightToOptimIdCmd = 11;
 constexpr int kInitOptimInputsShapeCmd = 12;
+constexpr int kInitKeyToPushNodeIdCmd = 13;
 constexpr int kInitEmbeddingsCmd = 20;
 constexpr int kEmbeddingLookupCmd = 30;
 constexpr int kFinalizeCmd = 40;

 constexpr size_t kInvalidKey = UINT64_MAX;
+constexpr int kInvalidID = -1;

 using Key = ::ps::Key;
 using Keys = ::ps::SArray<Key>;
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
(file mode changed: 100755 → 100644)

@@ -28,6 +28,7 @@
 #include <thread>
 #include <cmath>
 #include <random>
+#include <list>
 #include "ir/func_graph.h"
 #include "backend/session/session_basic.h"
 #include "backend/session/anf_runtime_algorithm.h"
@@ -116,6 +117,7 @@ class ParameterServer {
   bool ReadyForUpdateWeights();
   bool ReadyForAccumGrads();
   void ResetGradAccumCount();
+  const CNodePtr GetCNode(const std::string &name) const;

   size_t pserver_num_;
   size_t worker_num_;
@@ -132,6 +134,7 @@ class ParameterServer {
   std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
   std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
   std::unordered_map<Key, std::string> weight_key_to_optims_;
+  std::unordered_map<Key, std::string> weight_key_to_optim_op_;
   std::unordered_map<Key, WeightPtr> weights_;
   std::unordered_map<Key, WeightPtr> grads_;
   std::unordered_map<Key, size_t> grads_accum_counter_;
@@ -277,7 +280,6 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
   handler_->Init();
   InitOptimInfoBuilders();
   ps_->set_request_handle(*handler_);
-
   thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
   return true;
@@ -299,6 +301,7 @@ void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_
     return;
   }
   weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
+  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
 }

 template <typename T>
@@ -321,27 +324,42 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
   }
   if (weight_key_to_optims_.count(key) > 0) {
     const std::string &optim_name = weight_key_to_optims_[key];
+    const std::string &optim_op_name = weight_key_to_optim_op_[key];
     if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
+      const CNodePtr cnode = GetCNode(optim_op_name);
+      MS_EXCEPTION_IF_NULL(cnode);
       if (optim_name == kSparseAdam) {
         std::shared_ptr<PServerKernel> optimizer =
           std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kApplyMomentum) {
         std::shared_ptr<PServerKernel> optimizer =
           std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kSparseFtrl) {
         std::shared_ptr<PServerKernel> optimizer =
           std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
-        optimizer->InitKernel(optim_inputs_shape_[key]);
+        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       }
     }
   }
 }

+template <typename T>
+const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
+  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
+  for (CNodePtr cnode : cnodes) {
+    std::string fullname = cnode->fullname_with_scope();
+    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
+      return cnode;
+    }
+  }
+  return nullptr;
+}
+
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_LOG(INFO) << "Initializing weight for key " << key;
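
GetCNode is deliberately loose: it walks the graph's CNodes in order and returns the first one whose full name contains both the optimizer op name and the substring "Push". An illustration of that matching rule in plain Python; the node names are invented, only the contains-both test comes from the code above:

def find_push_node(fullnames, op_name):
    """Return the first full name mentioning both op_name and 'Push', else None."""
    for fullname in fullnames:
        if op_name in fullname and "Push" in fullname:
            return fullname
    return None

names = ["Default/network/Pull-op12", "Default/network/Push-FTRL-op11"]  # hypothetical
assert find_push_node(names, "FTRL") == "Default/network/Push-FTRL-op11"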
mindspore/ccsrc/frontend/parallel/ps/util.cc

@@ -33,6 +33,13 @@ std::unordered_map<int, std::string> Util::id_to_optimizers{
   {1, kSparseAdam},
   {2, kSparseFtrl},
 };

+std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
+  {0, kApplyMomentumOp},
+  {1, kSparseAdamOp},
+  {2, kSparseFtrlOp},
+};
+
 bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }

 bool Util::IsRoleOfWorker() {
@@ -112,6 +119,13 @@ std::string Util::optimizer_name(int id) {
   return "";
 }

+std::string Util::optimizer_node_name(int id) {
+  if (id_to_optimizer_nodes.count(id) > 0) {
+    return id_to_optimizer_nodes[id];
+  }
+  return "";
+}
+
 bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }

 int Util::LocalShard(int first_dim, int rank_id, int server_num) {
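
id_to_optimizer_nodes parallels the existing id_to_optimizers table: one numeric optimizer id now resolves both to a PS kernel name and to the graph op name that GetCNode matches against. A Python restatement (entry 0 of id_to_optimizers presumably maps to kApplyMomentum; it sits just above the hunk shown):

id_to_optimizers = {0: "ApplyMomentum", 1: "Adam", 2: "Ftrl"}   # PS kernel names
id_to_optimizer_nodes = {0: "Momentum", 1: "Adam", 2: "FTRL"}   # graph op names

def optimizer_node_name(optim_id):
    return id_to_optimizer_nodes.get(optim_id, "")  # same empty-string fallback as util.cc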
mindspore/ccsrc/frontend/parallel/ps/util.h

@@ -34,12 +34,14 @@ class Util {
   static void SetInternalEnvVar();
   static int optimizer_id(std::string name);
   static std::string optimizer_name(int id);
+  static std::string optimizer_node_name(int id);
   static bool is_optimizer(std::string name);
   static int LocalShard(int first_dim, int rank_id, int server_num);

  private:
   static std::unordered_map<std::string, int> optimizer_to_ids;
   static std::unordered_map<int, std::string> id_to_optimizers;
+  static std::unordered_map<int, std::string> id_to_optimizer_nodes;
 };
 }  // namespace ps
 }  // namespace parallel
mindspore/nn/optim/__init__.py

@@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients.
 """
 from .optimizer import Optimizer
 from .momentum import Momentum
-from .adam import Adam, PSAdam, AdamWeightDecay
+from .adam import Adam, AdamWeightDecay
 from .lamb import Lamb
 from .sgd import SGD
 from .lars import LARS
-from .ftrl import FTRL, PSFTRL
+from .ftrl import FTRL
 from .rmsprop import RMSProp
 from .proximal_ada_grad import ProximalAdagrad
 from .lazyadam import LazyAdam

-__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam',
-           'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad']
+__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam',
+           'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad']
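
With PSAdam and PSFTRL gone from the public API, PS-mode training goes through the plain Adam and FTRL classes, which branch per parameter on ps_parameter inside their multitype graphs. A hypothetical before/after for user code; the network is a stand-in, and the constructor arguments are unchanged:

import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Dense(4, 2)

    def construct(self, x):
        return self.fc(x)

net = TinyNet()
# before this commit: opt = nn.PSAdam(net.trainable_params(), learning_rate=1e-3)
opt = nn.Adam(net.trainable_params(), learning_rate=1e-3)  # now covers both modes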
mindspore/nn/optim/adam.py

@@ -27,7 +27,6 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer

 _adam_opt = C.MultitypeFuncGraph("adam_opt")
-_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt")

 @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
@@ -85,22 +84,20 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     return gradient

-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "IndexedSlices",
-                    "Tensor", "Tensor", "Tensor", "Bool")
-def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                         moment1, moment2, ps_parameter):
+@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool")
+def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
+                         gradient, params, moment1, moment2, ps_parameter):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     indices = gradient.indices()
     values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Adam", [0, 1, 2])
         shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                   op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                   op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
-        success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
-                                                       eps, values, indices), shapes), params))
+        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
+                                               eps, values, indices), shapes), params))
     else:
         success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
@@ -108,54 +105,21 @@ def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2
     return success

-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Tensor", "Tensor", "Bool")
-def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                             moment1, moment2, ps_parameter):
+@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _run_opt_with_one_number(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient,
+                             params, moment1, moment2, ps_parameter):
     """Apply adam optimizer to the weight parameter using Tensor."""
     success = True
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Adam", [0, 1, 2])
-        success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
-                                                      (op_shape(params), op_shape(moment1), op_shape(moment2))),
-                                             params))
+        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
+                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
     else:
         success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                         eps, gradient))
     return success

-
-@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                              "IndexedSlices", "Tensor", "Tensor", "Tensor")
-def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                                   moment1, moment2):
-    """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse."""
-    success = True
-    op_shape = P.Shape()
-    values = gradient.values()
-    indices = gradient.indices()
-    shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
-              op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-              op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
-    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, values, indices), shapes), params))
-    return success
-
-
-@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                              "Tensor", "Tensor", "Tensor", "Tensor")
-def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
-                                       moment1, moment2):
-    """Apply adam optimizer by push and pull to the weight parameter using Tensor."""
-    success = True
-    op_shape = P.Shape()
-    success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
-                                          (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
-    return success

 def _check_param_value(beta1, beta2, eps, prim_name):
     """Check the type of inputs."""
     validator.check_value_type("beta1", beta1, [float], prim_name)
@@ -285,6 +249,10 @@ class Adam(Optimizer):
         self.opt = P.Adam(use_locking, use_nesterov)
         self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
+        self._ps_pull = P.Pull()
+        self._ps_push = P.Push("Adam", [0, 1, 2])
+        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

     def construct(self, gradients):
         params = self.parameters
         moment1 = self.moment1
@@ -298,63 +266,16 @@ class Adam(Optimizer):
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
         if self.is_group_lr:
-            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps),
+            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                                 lr, gradients, params, moment1, moment2, self.ps_parameters)
         else:
-            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps, lr),
+            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                                 gradients, params, moment1, moment2, self.ps_parameters)
         return success

-
-class PSAdam(Optimizer):
-    '''The same usage as Adam optimizer except the parameters are set PS mode.'''
-    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
-                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
-        super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
-        _check_param_value(beta1, beta2, eps, self.cls_name)
-        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
-        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
-        self.beta1 = Tensor(beta1, mstype.float32)
-        self.beta2 = Tensor(beta2, mstype.float32)
-        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
-        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
-        self.eps = Tensor(eps, mstype.float32)
-        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
-        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
-        self.hyper_map = C.HyperMap()
-        self.push = P.Push("Adam", [0, 1, 2])
-        self.push.add_prim_attr("primitive_target", "CPU")
-        self.pull = P.Pull()
-        self.pull.add_prim_attr("primitive_target", "CPU")
-
-    def construct(self, gradients):
-        params = self.parameters
-        moment1 = self.moment1
-        moment2 = self.moment2
-        gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
-        lr = self.get_lr()
-        beta1_power = self.beta1_power * self.beta1
-        self.beta1_power = beta1_power
-        beta2_power = self.beta2_power * self.beta2
-        self.beta2_power = beta2_power
-        if self.is_group_lr:
-            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps),
-                                lr, gradients, params, moment1, moment2)
-        else:
-            success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power,
-                                          self.beta1, self.beta2, self.eps, lr),
-                                gradients, params, moment1, moment2)
-        return success
-

 class AdamWeightDecay(Optimizer):
     """
     Implements Adam algorithm weight decay fix.
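
The construct() changes read most easily as partial application: F.partial fixes the leading, parameter-independent arguments of _adam_opt, now including the shared _ps_push/_ps_pull primitives, and self.map_ zips the per-parameter trailing arguments across the weight list. A pure-Python analogy of that calling convention, with functools.partial standing in for F.partial and map for self.map_:

from functools import partial

def adam_opt(opt, sparse_opt, push, pull, beta1_power, beta2_power,
             beta1, beta2, eps, lr, gradient, param, m1, m2, ps_parameter):
    # Stand-in for the multitype graph: route by the per-parameter flag.
    return (param, "push/pull" if ps_parameter else "local kernel")

apply_one = partial(adam_opt, "opt", "sparse_opt", "push", "pull",
                    0.9, 0.999, 0.9, 0.999, 1e-8, 1e-3)
results = list(map(apply_one, ["g0", "g1"], ["w0", "w1"],
                   ["m0", "m1"], ["v0", "v1"], [True, False]))
assert results == [("w0", "push/pull"), ("w1", "local kernel")]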
mindspore/nn/optim/ftrl.py

@@ -21,68 +21,40 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, _apply_decay, _grad_scale

 _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
-_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")

-@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "IndexedSlices",
-                    "Tensor", "Tensor", "Bool")
-def _tensor_run_opt_with_sparse(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment,
-                                ps_parameter):
+@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
+                    "IndexedSlices", "Tensor", "Tensor", "Bool")
+def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
+                                gradient, weight, moment, ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
     indices = gradient.indices()
     values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Ftrl", [0, 1, 2])
         shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
-        success = F.depend(success, _ps_pull(_ps_push((values, indices), shapes), weight))
+        success = F.depend(success, pull(push((values, indices), shapes), weight))
     else:
         success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
     return success

-@_ftrl_opt.register("Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Bool")
-def _tensor_run_opt(opt, spars_opt, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment, ps_parameter):
+@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
+                    "Tensor", "Tensor", "Tensor", "Bool")
+def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, gradient, weight, moment,
+                    ps_parameter):
     """Apply ftrl optimizer to the weight parameter."""
     success = True
     if ps_parameter:
         op_shape = P.Shape()
-        _ps_pull = P.Pull()
-        _ps_push = P.Push("Ftrl", [0, 1, 2])
-        success = F.depend(success, _ps_pull(_ps_push((gradient, learning_rate, l1, l2, lr_power),
-                                                      (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
+        success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
+                                              (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
     else:
         success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success

-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor",
-                              "IndexedSlices", "Tensor", "Tensor")
-def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
-                                          weight, moment):
-    success = True
-    op_shape = P.Shape()
-    values = gradient.values()
-    indices = gradient.indices()
-    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
-    success = F.depend(success, pull(push((values, indices), shapes), weight))
-    return success
-
-
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor",
-                              "Tensor", "Tensor", "Tensor")
-def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
-                                              weight, moment):
-    success = True
-    op_shape = P.Shape()
-    success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
-                                          (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
-    return success

 def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):
     """Check param."""
     validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
@@ -188,6 +160,12 @@ class FTRL(Optimizer):
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.sparse_opt = P.FusedSparseFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
+        self._ps_pull = P.Pull()
+        self._ps_push = P.Push("Ftrl", [0, 1, 2])
+        self._ps_push.add_prim_attr("lr", learning_rate)
+        self._ps_push.add_prim_attr("l1", l1)
+        self._ps_push.add_prim_attr("l2", l2)
+        self._ps_push.add_prim_attr("lr_power", lr_power)

     def construct(self, grads):
         params = self.parameters
@@ -197,41 +175,7 @@ class FTRL(Optimizer):
         grads = self.scale_grad(grads)
         lr = self.get_lr()
-        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self.l1, self.l2, self.lr_power, lr),
+        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
+                                      self.l1, self.l2, self.lr_power, lr),
                             linear, grads, params, moments, self.ps_parameters)
         return success
-
-
-class PSFTRL(Optimizer):
-    def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
-                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
-        super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale)
-        if self.is_group:
-            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
-        _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
-        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
-        self.linear = self.parameters.clone(prefix="linear", init='zeros')
-        self.l1 = l1
-        self.l2 = l2
-        self.lr_power = lr_power
-        self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
-        self.hyper_map = C.HyperMap()
-        self.push = P.Push("Ftrl", [0, 1, 2])
-        self.push.add_prim_attr("primitive_target", "CPU")
-        self.pull = P.Pull()
-        self.pull.add_prim_attr("primitive_target", "CPU")
-
-    def construct(self, grads):
-        params = self.parameters
-        moments = self.moments
-        linear = self.linear
-        lr = self.learning_rate
-        if self.weight_decay > 0.0:
-            grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads)
-        grads = self.scale_grad(grads)
-        success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2,
-                                      self.lr_power), linear, grads, params, moments)
-        return success
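
One plausible reading of why Push/Pull moved from inside the multitype functions (and the deleted PSFTRL) into FTRL.__init__: a primitive built fresh on every call has no stable identity to carry per-instance attributes, while one owned by the optimizer instance can be tagged once with that instance's hyper-parameters and later found by the server's GetCNode. A sketch under that assumption:

from mindspore.ops import operations as P

def make_ftrl_push(lr, l1, l2, lr_power):
    """Build a Push primitive tagged with one FTRL configuration."""
    push = P.Push("Ftrl", [0, 1, 2])
    push.add_prim_attr("lr", lr)
    push.add_prim_attr("l1", l1)
    push.add_prim_attr("l2", l2)
    push.add_prim_attr("lr_power", lr_power)
    return push

push_a = make_ftrl_push(0.001, 0.0, 0.0, -0.5)   # two optimizer instances,
push_b = make_ftrl_push(0.01, 1e-4, 1e-4, -0.5)  # two attribute sets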