Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7233d650
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7233d650
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3063 Enable to train in parameter server mode
Merge pull request !3063 from ZPaC/add-ps-training-mode
上级
25168309
52022c80
变更
29
隐藏空白更改
内联
并排
Showing
29 changed file
with
376 addition
and
141 deletion
+376
-141
cmake/external_libs/glog.cmake
cmake/external_libs/glog.cmake
+1
-1
cmake/options.cmake
cmake/options.cmake
+4
-0
mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
+11
-8
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
...ackend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
+1
-1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
.../kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
+3
-3
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc
...end/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc
...spore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc
+8
-0
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
+1
-1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
...end/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
...end/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
+1
-1
mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc
...ore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc
+0
-1
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+10
-0
mindspore/ccsrc/backend/session/cpu_session.cc
mindspore/ccsrc/backend/session/cpu_session.cc
+29
-0
mindspore/ccsrc/backend/session/cpu_session.h
mindspore/ccsrc/backend/session/cpu_session.h
+1
-0
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+8
-0
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+92
-0
mindspore/ccsrc/backend/session/session_basic.h
mindspore/ccsrc/backend/session/session_basic.h
+5
-1
mindspore/ccsrc/frontend/parallel/CMakeLists.txt
mindspore/ccsrc/frontend/parallel/CMakeLists.txt
+8
-1
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
+8
-7
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
+3
-2
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
...pore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
+15
-10
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
...spore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
+1
-1
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
+17
-22
mindspore/ccsrc/frontend/parallel/ps/worker.h
mindspore/ccsrc/frontend/parallel/ps/worker.h
+6
-5
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
+59
-73
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
+6
-0
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+45
-1
mindspore/ccsrc/pipeline/jit/action.h
mindspore/ccsrc/pipeline/jit/action.h
+5
-0
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+26
-0
未找到文件。
cmake/external_libs/glog.cmake
浏览文件 @
7233d650
set
(
glog_CXXFLAGS
"-D_FORTIFY_SOURCE=2 -O2
${
SECURE_CXX_FLAGS
}
"
)
set
(
glog_CXXFLAGS
"-D_FORTIFY_SOURCE=2 -O2
${
SECURE_CXX_FLAGS
}
-D_GLIBCXX_USE_CXX11_ABI=0
"
)
set
(
glog_CFLAGS
"-D_FORTIFY_SOURCE=2 -O2"
)
mindspore_add_pkg
(
glog
VER 0.4.0
...
...
cmake/options.cmake
浏览文件 @
7233d650
...
...
@@ -119,3 +119,7 @@ endif()
if
(
ENABLE_DEBUGGER
)
add_compile_definitions
(
ENABLE_DEBUGGER
)
endif
()
if
(
ENABLE_TESTCASES
)
add_compile_definitions
(
ENABLE_TESTCASES
)
endif
()
\ No newline at end of file
mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
浏览文件 @
7233d650
...
...
@@ -26,14 +26,6 @@ if (ENABLE_CPU)
"cpu/*.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/push_kernel.cc"
"cpu/ps/pull_kernel.cc"
"cpu/ps/embedding_look_up_ps_kernel.cc"
"cpu/ps/embedding_look_up_proxy_kernel.cc"
"cpu/ps/apply_momentum_ps_kernel.cc"
"cpu/ps/sparse_apply_adam_ps_kernel.cc"
"cpu/ps/sparse_apply_ftrl_ps_kernel.cc"
)
if
(
NOT ENABLE_MPI
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/allgather_cpu_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/reduce_scatter_cpu_kernel.cc"
)
...
...
@@ -41,6 +33,17 @@ if (ENABLE_CPU)
endif
()
endif
()
if
(
${
CMAKE_SYSTEM_NAME
}
MATCHES
"Windows"
OR ENABLE_GE
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/apply_momentum_ps_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/embedding_look_up_proxy_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/embedding_look_up_ps_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/pserver_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/pull_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/push_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/sparse_apply_adam_ps_kernel.cc"
)
list
(
REMOVE_ITEM CPU_SRC_LIST
"cpu/ps/sparse_apply_ftrl_ps_kernel.cc"
)
endif
()
if
(
ENABLE_GPU
)
file
(
GLOB_RECURSE CUDA_SRC_LIST RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"gpu/*.cu"
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
浏览文件 @
7233d650
...
...
@@ -46,7 +46,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
)
override
;
pr
ivate
:
pr
otected
:
void
LookUpTable
(
const
std
::
vector
<
kernel
::
AddressPtr
>
&
inputs
,
size_t
dim0
,
size_t
dim1
,
size_t
dim2
,
float
**
output_addr
);
void
CheckParam
(
const
CNodePtr
&
kernel_node
);
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
浏览文件 @
7233d650
...
...
@@ -53,15 +53,15 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
size_t
output_size
=
outputs
[
0
]
->
size
;
size_t
size
=
input_size
/
sizeof
(
float
);
::
ps
::
SArray
<
floa
t
>
lookup_ids
(
size
,
0
);
::
ps
::
SArray
<
in
t
>
lookup_ids
(
size
,
0
);
::
ps
::
SArray
<
int
>
lengths
{
size
};
::
ps
::
SArray
<
float
>
lookup_result
;
::
ps
::
SArray
<
float
>
lookup_result
(
output_size
/
sizeof
(
float
),
0
)
;
auto
ret
=
memcpy_s
(
lookup_ids
.
data
(),
input_size
,
indices_addr
,
input_size
);
if
(
ret
!=
EOK
)
{
MS_LOG
(
EXCEPTION
)
<<
"Lookup id memcpy failed."
;
}
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
DoPSEmbeddingLookup
({
key_
},
lookup_ids
,
lengths
,
lookup_result
,
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
DoPSEmbeddingLookup
({
key_
},
lookup_ids
,
lengths
,
&
lookup_result
,
parallel
::
ps
::
kEmbeddingLookupCmd
);
auto
ret2
=
memcpy_s
(
output_addr
,
output_size
,
lookup_result
.
data
(),
output_size
);
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc
浏览文件 @
7233d650
...
...
@@ -50,7 +50,7 @@ void EmbeddingLookUpPSKernel::InitKernel(
split_num_
=
pserver_num_
;
// input shape should be sharded after computing offset_;
Shard
(
input_shape_
,
axis_
);
Shard
(
&
input_shape_
,
axis_
);
size_t
output_size
=
std
::
accumulate
(
output_shape_
.
begin
(),
output_shape_
.
end
(),
sizeof
(
float
),
std
::
multiplies
<
size_t
>
());
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc
浏览文件 @
7233d650
...
...
@@ -34,5 +34,13 @@ MS_REG_CPU_KERNEL_T(Push,
MS_REG_CPU_KERNEL_T
(
Push
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeUInt64
),
PushKernel
,
float
);
MS_REG_CPU_KERNEL_T
(
Push
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeUInt64
),
PushKernel
,
float
);
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
浏览文件 @
7233d650
...
...
@@ -43,7 +43,7 @@ class PushKernel : public CPUKernel {
sizes
.
push_back
(
SizeToInt
(
input
->
size
)
/
sizeof
(
T
));
}
parallel
::
ps
::
Worker
<
T
>::
GetInstance
().
Push
(
keys
,
addrs
,
sizes
);
memcpy
(
outputs
[
0
]
->
addr
,
&
key_
,
sizeof
(
size_t
));
memcpy
_s
(
outputs
[
0
]
->
addr
,
sizeof
(
size_t
)
,
&
key_
,
sizeof
(
size_t
));
return
true
;
}
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
浏览文件 @
7233d650
...
...
@@ -75,7 +75,7 @@ void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr<std::vector<std::shar
void
SparseApplyAdamPSKernel
::
ReInit
(
const
std
::
vector
<
AddressPtr
>
&
inputs
)
{
const
auto
&
indices_addr
=
inputs
[
10
];
indices_size_
=
indices_addr
->
size
;
indices_size_
=
indices_addr
->
size
/
sizeof
(
int
)
;
workspace_size_list_
[
0
]
=
indices_size_
*
var_outer_dim_size_
*
sizeof
(
float
);
workspace_size_list_
[
1
]
=
indices_size_
*
sizeof
(
int
);
}
...
...
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc
浏览文件 @
7233d650
...
...
@@ -64,7 +64,7 @@ void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shar
void
SparseApplyFtrlPSKernel
::
ReInit
(
const
std
::
vector
<
AddressPtr
>
&
inputs
)
{
const
auto
&
indices_addr
=
inputs
[
4
];
indices_size_
=
indices_addr
->
size
;
indices_size_
=
indices_addr
->
size
/
sizeof
(
int
)
;
workspace_size_list_
[
0
]
=
indices_size_
*
var_outer_dim_size_
*
sizeof
(
float
);
workspace_size_list_
[
1
]
=
indices_size_
*
sizeof
(
int
);
}
...
...
mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc
浏览文件 @
7233d650
...
...
@@ -71,7 +71,6 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
AbstractBasePtrList
abstract_list
;
AnfAlgo
::
CopyNodeAttr
(
kAttrPsKey
,
cnode
,
proxy_node
);
AnfAlgo
::
CopyNodeAttr
(
"reduce_scatter_flag"
,
cnode
,
proxy_node
);
AnfAlgo
::
CopyNodeAttr
(
"offset"
,
cnode
,
proxy_node
);
abstract_list
.
push_back
(
cnode
->
abstract
());
auto
abstract_tuple
=
std
::
make_shared
<
abstract
::
AbstractTuple
>
(
abstract_list
);
...
...
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
7233d650
...
...
@@ -353,6 +353,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
RootGraphExecutorValidate
(
NOT_NULL
(
root_graph
));
// adjust kernel
AdjustKernel
(
root_graph
);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Assign parameter keys.
AssignParamKey
(
root_graph
);
#endif
// assign stream
AssignStream
(
NOT_NULL
(
root_graph
));
// insert profiling point
...
...
@@ -511,6 +515,12 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
}
// load input data from user input
LoadInputData
(
kernel_graph
,
inputs
);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Initialize parameter server
if
(
!
ps_init_
)
{
InitPSParamAndOptim
(
kernel_graph
,
inputs
);
}
#endif
// convert inputs to model
predictmodel
::
StepConvertWeight
(
inputs
);
{
...
...
mindspore/ccsrc/backend/session/cpu_session.cc
浏览文件 @
7233d650
...
...
@@ -25,9 +25,15 @@
#include "predict/predict.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/util.h"
#endif
namespace
mindspore
{
namespace
session
{
...
...
@@ -49,12 +55,29 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf,
return
new_parameter
;
}
void
CPUSession
::
Optimize
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
{
auto
optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
std
::
string
pass_name
=
"replace_node_by_proxy"
;
pass_name
.
append
(
std
::
to_string
(
graph_sum_
));
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceNodeByProxy
>
(
pass_name
));
optimizer
->
AddPassManager
(
pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
}
GraphId
CPUSession
::
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
{
auto
graph_id
=
graph_sum_
;
auto
graph
=
ConstructKernelGraph
(
lst
,
outputs
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Set kernel info"
;
SetKernelInfo
(
graph
.
get
());
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
AssignParamKey
(
graph
);
if
(
parallel
::
ps
::
Util
::
IsRoleOfWorker
())
{
Optimize
(
graph
);
}
#endif
predictmodel
::
StepConvertGraph
(
graph
);
MS_LOG
(
INFO
)
<<
"Build kernel"
;
BuildKernel
(
graph
.
get
());
...
...
@@ -66,6 +89,12 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
void
CPUSession
::
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
{
auto
&
kernel_graph
=
graphs_
[
graph_id
];
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Initialize parameter server
if
(
!
ps_init_
)
{
InitPSParamAndOptim
(
kernel_graph
,
inputs
);
}
#endif
MS_LOG
(
INFO
)
<<
"Bind input output address"
;
std
::
vector
<
tensor
::
TensorPtr
>
need_sync_outputs
;
runtime_
.
BindInputOutput
(
kernel_graph
.
get
(),
inputs
,
outputs
,
&
need_sync_outputs
);
...
...
mindspore/ccsrc/backend/session/cpu_session.h
浏览文件 @
7233d650
...
...
@@ -37,6 +37,7 @@ class CPUSession : public SessionBasic {
protected:
ParameterPtr
CreateNewParameterFromParameter
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
)
override
;
void
Optimize
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
);
private:
void
SetKernelInfo
(
const
KernelGraph
*
kernel_graph
);
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
7233d650
...
...
@@ -177,6 +177,10 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
Optimize
(
graph
);
// Select kernel build info
SelectKernel
(
graph
);
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Assign parameter keys.
AssignParamKey
(
graph
);
#endif
// Convert kernel Graph to model
predictmodel
::
StepConvertGraph
(
graph
);
// Start gpu kernel runtime
...
...
@@ -214,6 +218,10 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
auto
&
kernel_graph
=
graphs_
[
graph_id
];
// Load input data from user input
LoadInputData
(
kernel_graph
,
inputs
);
// Initialize parameter server
if
(
!
ps_init_
)
{
InitPSParamAndOptim
(
kernel_graph
,
inputs
);
}
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// Convert inputs to model
predictmodel
::
StepConvertWeight
(
inputs
);
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
7233d650
...
...
@@ -35,6 +35,11 @@
#include "ir/dtype.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/worker.h"
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h"
#endif
namespace
mindspore
{
namespace
session
{
...
...
@@ -1097,5 +1102,92 @@ KernelGraphPtr SessionBasic::NewKernelGraph() {
graphs_
[
graph_sum_
++
]
=
graph
;
return
graph
;
}
AnfNodePtr
SessionBasic
::
FindPullNode
(
const
AnfNodePtr
&
push_node
,
const
std
::
vector
<
AnfNodePtr
>
&
node_list
)
{
MS_EXCEPTION_IF_NULL
(
push_node
);
for
(
auto
&
node
:
node_list
)
{
if
(
node
!=
nullptr
&&
node
->
isa
<
CNode
>
())
{
for
(
auto
input
:
node
->
cast
<
CNodePtr
>
()
->
inputs
())
{
if
(
push_node
==
AnfAlgo
::
VisitKernel
(
input
,
0
).
first
)
{
if
(
AnfAlgo
::
GetCNodeName
(
node
)
!=
kPullOpName
)
{
MS_LOG
(
EXCEPTION
)
<<
"The edge between Push and Pull node is invalid."
;
}
return
node
;
}
}
}
}
return
nullptr
;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
void
SessionBasic
::
AssignParamKey
(
const
KernelGraphPtr
&
kernel_graph
)
{
if
(
!
parallel
::
ps
::
Util
::
IsRoleOfWorker
())
{
MS_LOG
(
INFO
)
<<
"Not parameter server mode."
;
return
;
}
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
kernel_graph
->
get_return
());
for
(
auto
&
node
:
node_list
)
{
if
(
node
!=
nullptr
&&
node
->
isa
<
CNode
>
())
{
// Assign key for forward kernel EmbeddingLookup.
// The key will be assigned to embedding table ande Push kernel as well.
if
(
AnfAlgo
::
GetCNodeName
(
node
)
==
kEmbeddingLookupOpName
)
{
size_t
embedding_table_idx
=
0
;
auto
embedding_table
=
AnfAlgo
::
GetInputNode
(
node
->
cast
<
CNodePtr
>
(),
embedding_table_idx
);
size_t
key
=
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
SetParamKey
(
embedding_table
->
fullname_with_scope
());
AnfAlgo
::
SetNodeAttr
(
kAttrPsKey
,
MakeValue
(
key
),
node
);
}
else
if
(
AnfAlgo
::
GetCNodeName
(
node
)
==
kPushOpName
)
{
auto
pull_node
=
FindPullNode
(
node
,
node_list
);
if
(
!
pull_node
)
{
MS_LOG
(
EXCEPTION
)
<<
"Assigning parameter key failed: can't find Pull node of the Push node."
;
}
// Second input of Pull node is the trainable parameter.
size_t
parameter_index
=
1
;
auto
parameter_node
=
AnfAlgo
::
GetInputNode
(
pull_node
->
cast
<
CNodePtr
>
(),
parameter_index
);
size_t
key
=
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
SetParamKey
(
parameter_node
->
fullname_with_scope
());
AnfAlgo
::
SetNodeAttr
(
kAttrPsKey
,
MakeValue
(
key
),
node
);
AnfAlgo
::
SetNodeAttr
(
kAttrPsKey
,
MakeValue
(
key
),
pull_node
);
std
::
string
optimizer_name
=
AnfAlgo
::
GetNodeAttr
<
std
::
string
>
(
node
,
kAttrOptimizerType
);
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
SetKeyOptimId
(
key
,
optimizer_name
);
}
}
}
}
void
SessionBasic
::
InitPSParamAndOptim
(
const
KernelGraphPtr
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
)
{
if
(
!
parallel
::
ps
::
Util
::
IsRoleOfWorker
())
{
return
;
}
std
::
vector
<
tensor
::
TensorPtr
>
inputs
(
inputs_const
);
size_t
input_ctrl_size
=
1
;
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
if
(
kernel_graph
->
input_ctrl_tensors
())
{
input_ctrl_size
=
LoadCtrlInputTensor
(
kernel_graph
,
&
inputs
);
}
auto
input_nodes
=
kernel_graph
->
inputs
();
if
((
inputs
.
size
()
+
input_ctrl_size
)
-
1
!=
input_nodes
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Tensor input:"
<<
inputs
.
size
()
<<
" is not equal graph inputs:"
<<
input_nodes
.
size
()
<<
", input_ctrl_size:"
<<
input_ctrl_size
;
}
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
tensor
=
inputs
[
i
];
MS_EXCEPTION_IF_NULL
(
tensor
);
auto
input_node
=
input_nodes
[
i
];
MS_EXCEPTION_IF_NULL
(
input_node
);
if
(
input_node
->
isa
<
Parameter
>
()
&&
AnfAlgo
::
OutputAddrExist
(
input_node
,
0
))
{
auto
pk_node
=
input_node
->
cast
<
ParameterPtr
>
();
mindspore
::
parallel
::
ps
::
Worker
<
float
>::
GetInstance
().
InitPSParamAndOptim
(
pk_node
->
fullname_with_scope
(),
tensor
->
data_c
(),
LongToSize
(
tensor
->
data
().
nbytes
()));
}
}
ps_init_
=
true
;
}
#endif
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/backend/session/session_basic.h
浏览文件 @
7233d650
...
...
@@ -51,7 +51,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class
SessionBasic
{
public:
SessionBasic
()
:
context_
(
nullptr
),
summary_callback_
(
nullptr
),
device_id_
(
0
)
{
SessionBasic
()
:
context_
(
nullptr
),
summary_callback_
(
nullptr
),
device_id_
(
0
)
,
ps_init_
(
false
)
{
#ifdef ENABLE_DEBUGGER
debugger_
=
nullptr
;
#endif
...
...
@@ -104,6 +104,8 @@ class SessionBasic {
virtual
GraphId
GetFinalRunGraph
()
const
{
return
kInvalidGraphId
;
}
virtual
void
SetActive
(
GraphId
,
GraphId
)
{}
virtual
void
GetSummaryNodes
(
KernelGraph
*
graph
);
void
AssignParamKey
(
const
KernelGraphPtr
&
kernel_graph
);
void
InitPSParamAndOptim
(
const
KernelGraphPtr
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
);
#ifdef ENABLE_DEBUGGER
// set debugger
...
...
@@ -140,6 +142,7 @@ class SessionBasic {
AnfNodePtr
CreateNewParameterFromCNode
(
const
AnfNodePtr
&
anf
,
bool
valid_input
,
KernelGraph
*
graph
);
void
AddParameterToGraphInputs
(
const
std
::
vector
<
AnfNodePtr
>
&
parameters
,
KernelGraph
*
graph
);
void
InitInternalOutputParameter
(
const
AnfNodePtr
&
out_node
,
const
AnfNodePtr
&
parameter
);
AnfNodePtr
FindPullNode
(
const
AnfNodePtr
&
push_node
,
const
std
::
vector
<
AnfNodePtr
>
&
node_list
);
std
::
unordered_map
<
GraphId
,
std
::
shared_ptr
<
KernelGraph
>>
graphs_
;
std
::
unordered_map
<
GraphInfo
,
std
::
shared_ptr
<
KernelGraph
>>
run_op_graphs_
;
...
...
@@ -148,6 +151,7 @@ class SessionBasic {
CallBackFunc
summary_callback_
;
static
GraphId
graph_sum_
;
uint32_t
device_id_
;
bool
ps_init_
;
#ifdef ENABLE_DEBUGGER
std
::
shared_ptr
<
Debugger
>
debugger_
;
#endif
...
...
mindspore/ccsrc/frontend/parallel/CMakeLists.txt
浏览文件 @
7233d650
file
(
GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"*.cc"
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"ps/util.cc"
"ps/scheduler.cc"
"ps/optimizer_info.cc"
"ps/optimizer_info_builder.cc"
)
if
(
${
CMAKE_SYSTEM_NAME
}
MATCHES
"Windows"
OR ENABLE_GE
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"ps/optimizer_info_builder.cc"
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"ps/optimizer_info.cc"
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"ps/scheduler.cc"
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"ps/util.cc"
)
endif
()
if
(
ENABLE_DUMP_PROTO
)
list
(
REMOVE_ITEM _PARALLEL_SRC_FILES
"parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc"
)
endif
()
...
...
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
浏览文件 @
7233d650
...
...
@@ -118,11 +118,13 @@ const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; }
const
AddressPtr
&
MomentumOptimInfo
::
indices
()
{
return
inputs_
[
3
];
}
size_t
MomentumOptimInfo
::
grad_index
()
{
return
1
;
}
SparseAdamOptimInfo
::
SparseAdamOptimInfo
(
const
AddressPtr
&
weight
,
const
AddressPtr
&
m
,
const
AddressPtr
&
v
,
const
AddressPtr
&
beta1_power
,
const
AddressPtr
&
beta2_power
,
const
AddressPtr
&
learning_rate
,
const
AddressPtr
&
beta1
,
const
AddressPtr
&
beta2
,
const
AddressPtr
&
epsilon
,
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
,
size_t
grads_offset
,
size_t
indices_offset
)
{
const
AddressPtr
&
indices
)
{
inputs_
.
push_back
(
weight
);
inputs_
.
push_back
(
m
);
inputs_
.
push_back
(
v
);
...
...
@@ -134,8 +136,8 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
inputs_
.
push_back
(
epsilon
);
inputs_
.
push_back
(
grad
);
inputs_
.
push_back
(
indices
);
grads_offset_
=
grads_offset
;
indices_offset_
=
indices_offset
;
grads_offset_
=
0
;
indices_offset_
=
0
;
}
void
SparseAdamOptimInfo
::
Update
(
const
Values
&
values
,
const
Lengths
&
lens
)
{
...
...
@@ -159,15 +161,14 @@ size_t SparseAdamOptimInfo::grad_index() { return 6; }
size_t
SparseAdamOptimInfo
::
indices_index
()
{
return
7
;
}
SparseFtrlOptimInfo
::
SparseFtrlOptimInfo
(
const
AddressPtr
&
weight
,
const
AddressPtr
&
accum
,
const
AddressPtr
&
linear
,
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
,
size_t
grads_offset
,
size_t
indices_offset
)
{
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
)
{
inputs_
.
push_back
(
weight
);
inputs_
.
push_back
(
accum
);
inputs_
.
push_back
(
linear
);
inputs_
.
push_back
(
grad
);
inputs_
.
push_back
(
indices
);
grads_offset_
=
grads_offset
;
indices_offset_
=
indices_offset
;
grads_offset_
=
0
;
indices_offset_
=
0
;
}
const
AddressPtr
&
SparseFtrlOptimInfo
::
gradient
()
{
return
inputs_
[
3
];
}
...
...
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h
浏览文件 @
7233d650
...
...
@@ -81,6 +81,7 @@ class MomentumOptimInfo : public DenseOptimInfo {
const
AddressPtr
&
gradient
();
const
AddressPtr
&
indices
();
size_t
grad_index
()
override
;
};
class
SparseAdamOptimInfo
:
public
SparseOptimInfo
{
...
...
@@ -88,7 +89,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
SparseAdamOptimInfo
(
const
AddressPtr
&
weight
,
const
AddressPtr
&
m
,
const
AddressPtr
&
v
,
const
AddressPtr
&
beta1_power
,
const
AddressPtr
&
beta2_power
,
const
AddressPtr
&
learning_rate
,
const
AddressPtr
&
beta1
,
const
AddressPtr
&
beta2
,
const
AddressPtr
&
epsilon
,
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
,
size_t
grads_offset
,
size_t
indices_offset
);
const
AddressPtr
&
indices
);
~
SparseAdamOptimInfo
()
override
=
default
;
void
Update
(
const
Values
&
values
,
const
Lengths
&
lens
)
override
;
...
...
@@ -102,7 +103,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
class
SparseFtrlOptimInfo
:
public
SparseOptimInfo
{
public:
SparseFtrlOptimInfo
(
const
AddressPtr
&
weight
,
const
AddressPtr
&
accum
,
const
AddressPtr
&
linear
,
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
,
size_t
grads_offset
,
size_t
indices_offset
);
const
AddressPtr
&
grad
,
const
AddressPtr
&
indices
);
~
SparseFtrlOptimInfo
()
override
=
default
;
const
AddressPtr
&
gradient
();
...
...
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc
浏览文件 @
7233d650
...
...
@@ -48,20 +48,25 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
size_t
worker_num
)
{
AddressPtr
weight_addr
=
std
::
make_shared
<
kernel
::
Address
>
();
weight_addr
->
addr
=
weight
->
data
();
weight_addr
->
size
=
weight
->
size
();
weight_addr
->
size
=
weight
->
size
()
*
sizeof
(
float
)
;
void
*
data_ptr
=
values
.
data
();
void
*
copy_data_ptr
=
new
float
[
values
.
size
()];
auto
ret
=
memcpy_s
(
copy_data_ptr
,
values
.
size
()
*
sizeof
(
float
),
data_ptr
,
values
.
size
()
*
sizeof
(
float
));
if
(
ret
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"memcpy_s error, errorno("
<<
ret
<<
")"
;
}
AddressPtr
accumulate
=
std
::
make_shared
<
kernel
::
Address
>
();
accumulate
->
addr
=
new
float
[
weight
->
size
()];
accumulate
->
size
=
weight
->
size
();
accumulate
->
size
=
weight
->
size
()
*
sizeof
(
float
)
;
AddressPtr
learning_rate
=
std
::
make_shared
<
kernel
::
Address
>
();
learning_rate
->
addr
=
data_ptr
;
learning_rate
->
size
=
lens
[
0
];
learning_rate
->
addr
=
copy_
data_ptr
;
learning_rate
->
size
=
lens
[
0
]
*
sizeof
(
float
)
;
AddressPtr
gradient
=
std
::
make_shared
<
kernel
::
Address
>
();
gradient
->
addr
=
reinterpret_cast
<
float
*>
(
learning_rate
->
addr
)
+
lens
[
0
];
gradient
->
size
=
lens
[
1
];
gradient
->
size
=
lens
[
1
]
*
sizeof
(
float
)
;
AddressPtr
momentum
=
std
::
make_shared
<
kernel
::
Address
>
();
momentum
->
addr
=
reinterpret_cast
<
float
*>
(
gradient
->
addr
)
+
lens
[
1
];
momentum
->
size
=
lens
[
2
];
momentum
->
size
=
lens
[
2
]
*
sizeof
(
float
)
;
return
new
MomentumOptimInfo
(
weight_addr
,
accumulate
,
learning_rate
,
gradient
,
momentum
);
}
...
...
@@ -131,10 +136,10 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
if
(
ret3
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"memcpy_s error, errorno("
<<
ret3
<<
")"
;
}
indices
->
size
=
lens
[
7
]
*
sizeof
(
floa
t
);
indices
->
size
=
lens
[
7
]
*
sizeof
(
in
t
);
return
new
SparseAdamOptimInfo
(
weight_addr
,
m
,
v
,
beta1_power
,
beta2_power
,
learning_rate
,
beta1
,
beta2
,
epsilon
,
grad
,
indices
,
total_grad_size
,
total_indice_size
);
grad
,
indices
);
}
OptimizerInfo
*
SparseFtrlOptimInfoBuilder
::
BuildInputs
(
const
WeightPtr
&
weight
,
const
Keys
&
keys
,
const
Values
&
values
,
...
...
@@ -175,9 +180,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
if
(
ret2
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"memcpy_s error, errorno("
<<
ret2
<<
")"
;
}
indices
->
size
=
lens
[
1
]
*
sizeof
(
floa
t
);
indices
->
size
=
lens
[
1
]
*
sizeof
(
in
t
);
return
new
SparseFtrlOptimInfo
(
weight_addr
,
accum
,
linear
,
grad
,
indices
,
total_grad_size
,
total_indice_size
);
return
new
SparseFtrlOptimInfo
(
weight_addr
,
accum
,
linear
,
grad
,
indices
);
}
}
// namespace ps
}
// namespace parallel
...
...
mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
浏览文件 @
7233d650
...
...
@@ -19,7 +19,7 @@
#include <vector>
#include <memory>
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/ps/pserver_kernel.h"
#include "backend/kernel_compiler/
cpu/
ps/pserver_kernel.h"
#include "frontend/parallel/ps/optimizer_info.h"
namespace
mindspore
{
...
...
mindspore/ccsrc/frontend/parallel/ps/parameter_server.h
浏览文件 @
7233d650
...
...
@@ -40,12 +40,12 @@
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/context/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/ps/embedding_look_up_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
namespace
mindspore
{
namespace
parallel
{
...
...
@@ -118,7 +118,7 @@ class ParameterServer {
std
::
shared_ptr
<
session
::
KernelGraph
>
kernel_graph_
;
std
::
shared_ptr
<
session
::
SessionBasic
>
sess_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
PServerKernel
>>
optimizers_
;
std
::
unordered_map
<
Key
,
std
::
shared_ptr
<
PServerKernel
>>
optimizers_
;
std
::
unordered_map
<
Key
,
InputsShapePtr
>
optim_inputs_shape_
;
std
::
unordered_map
<
Key
,
std
::
shared_ptr
<
OptimizerInfo
>>
optim_infos_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
OptimizerInfoBuilder
>>
optim_info_builders_
;
...
...
@@ -249,10 +249,10 @@ template <typename T>
void
ParameterServer
<
T
>::
ServerHandler
::
HandleEmbeddingLookup
(
const
::
ps
::
KVMeta
&
req_meta
,
const
::
ps
::
KVPairs
<
T
>
&
req_data
,
::
ps
::
KVPairs
<
T
>
*
res
)
{
const
Key
&
key
=
req_data
.
keys
[
0
];
ps_
->
DoEmbeddingLookup
(
key
,
req_data
.
vals
,
res
);
for
(
size_t
i
=
0
;
i
<
req_data
.
vals
.
size
();
i
++
)
{
res
->
keys
->
push_back
(
req_data
.
vals
[
i
]);
res
->
keys
.
push_back
(
req_data
.
vals
[
i
]);
}
ps_
->
DoEmbeddingLookup
(
key
,
req_data
.
vals
,
res
);
}
template
<
typename
T
>
...
...
@@ -288,7 +288,7 @@ void ParameterServer<T>::InitOptimInfoBuilders() {
template
<
typename
T
>
void
ParameterServer
<
T
>::
InitWeightKeyToOptims
(
const
Key
&
key
,
const
int
&
optim_id
)
{
if
(
weight_key_to_optims_
.
count
(
key
)
>
0
||
Util
::
optimizer_name
(
key
)
==
""
)
{
if
(
weight_key_to_optims_
.
count
(
key
)
>
0
||
Util
::
optimizer_name
(
optim_id
)
==
""
)
{
return
;
}
weight_key_to_optims_
[
key
]
=
Util
::
optimizer_name
(
optim_id
);
...
...
@@ -314,22 +314,22 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
}
if
(
weight_key_to_optims_
.
count
(
key
)
>
0
)
{
const
std
::
string
&
optim_name
=
weight_key_to_optims_
[
key
];
if
(
optimizers_
.
count
(
optim_name
)
==
0
&&
optim_inputs_shape_
.
count
(
key
)
>
0
)
{
if
(
optimizers_
.
count
(
key
)
==
0
&&
optim_inputs_shape_
.
count
(
key
)
>
0
)
{
if
(
optim_name
==
kSparseAdam
)
{
std
::
shared_ptr
<
PServerKernel
>
optimizer
=
std
::
make_shared
<
kernel
::
ps
::
SparseApplyAdamPSKernel
>
(
rank_id_
,
pserver_num_
);
optimizer
->
InitKernel
(
optim_inputs_shape_
[
key
]);
optimizers_
[
optim_name
]
=
optimizer
;
optimizers_
[
key
]
=
optimizer
;
}
else
if
(
optim_name
==
kApplyMomentum
)
{
std
::
shared_ptr
<
PServerKernel
>
optimizer
=
std
::
make_shared
<
kernel
::
ps
::
ApplyMomentumPSKernel
>
(
rank_id_
,
pserver_num_
);
optimizer
->
InitKernel
(
optim_inputs_shape_
[
key
]);
optimizers_
[
optim_name
]
=
optimizer
;
optimizers_
[
key
]
=
optimizer
;
}
else
if
(
optim_name
==
kSparseFtrl
)
{
std
::
shared_ptr
<
PServerKernel
>
optimizer
=
std
::
make_shared
<
kernel
::
ps
::
SparseApplyFtrlPSKernel
>
(
rank_id_
,
pserver_num_
);
optimizer
->
InitKernel
(
optim_inputs_shape_
[
key
]);
optimizers_
[
optim_name
]
=
optimizer
;
optimizers_
[
key
]
=
optimizer
;
}
}
}
...
...
@@ -382,8 +382,7 @@ void ParameterServer<T>::UpdateWeights() {
std
::
shared_ptr
<
PServerKernel
>
optimizer
=
nullptr
;
if
(
weight_key_to_optims_
.
count
(
key
)
>
0
)
{
const
std
::
string
&
optim_name
=
weight_key_to_optims_
[
key
];
optimizer
=
optimizers_
[
optim_name
];
optimizer
=
optimizers_
[
key
];
}
MS_EXCEPTION_IF_NULL
(
optimizer
);
...
...
@@ -391,8 +390,6 @@ void ParameterServer<T>::UpdateWeights() {
if
(
optim_info
==
nullptr
)
{
continue
;
}
const
WeightPtr
&
weight
=
weights_
[
key
];
optim_info
->
UpdateWeight
(
weight
);
const
std
::
vector
<
kernel
::
AddressPtr
>
&
inputs
=
optim_info
->
inputs
();
const
std
::
vector
<
kernel
::
AddressPtr
>
&
workspaces
=
optim_info
->
workspaces
();
const
std
::
vector
<
kernel
::
AddressPtr
>
&
outputs
=
optim_info
->
outputs
();
...
...
@@ -416,7 +413,7 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
// Create or update the optimizer info
if
(
optim_info
==
nullptr
)
{
const
std
::
shared_ptr
<
OptimizerInfoBuilder
>
&
builder
=
optim_info_builders_
[
weight_key_to_optims_
[
key
]];
std
::
shared_ptr
<
kernel
::
ps
::
PServerKernel
>
pserver_kernel
=
optimizers_
[
weight_key_to_optims_
[
key
]
];
std
::
shared_ptr
<
kernel
::
ps
::
PServerKernel
>
pserver_kernel
=
optimizers_
[
key
];
if
(
pserver_kernel
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"no optimizer found for key "
<<
key
<<
" optim name "
<<
weight_key_to_optims_
[
key
];
}
...
...
@@ -427,10 +424,8 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
optim_infos_
[
key
]
=
optim_info
;
}
else
{
optim_info
->
Update
(
values
,
lengths
);
optim_info
->
Accumulate
(
values
,
lengths
);
}
MS_EXCEPTION_IF_NULL
(
optim_info
);
optim_info
->
Accumulate
(
values
,
lengths
);
grads_accum_counter_
[
key
]
+=
1
;
if
(
grads_accum_counter_
[
key
]
==
worker_num_
)
{
...
...
@@ -499,7 +494,7 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids,
table_lookup_op
->
Execute
(
inputs
,
workspaces
,
outputs
);
res
->
vals
=
*
addr
;
res
->
lens
.
push_back
(
res
.
vals
.
size
());
res
->
lens
.
push_back
(
res
->
vals
.
size
());
}
template
<
typename
T
>
...
...
mindspore/ccsrc/frontend/parallel/ps/worker.h
浏览文件 @
7233d650
...
...
@@ -48,7 +48,7 @@ class Worker {
void
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
);
void
InitPSEmbeddingTable
(
const
std
::
vector
<
size_t
>
&
keys
,
std
::
vector
<
size_t
>
shapes
,
const
std
::
vector
<
int
>
&
sizes
);
void
InitPSParamAndOptim
(
const
std
::
string
&
param_name
,
void
*
param_data
,
size_t
param_size
);
void
DoPSEmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
void
DoPSEmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
lookup_result
,
int
cmd
);
private:
...
...
@@ -98,7 +98,8 @@ void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> add
::
ps
::
SArray
<
T
>
total_buffer
(
total_size
,
0
);
size_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
sizes
.
size
();
i
++
)
{
memcpy
(
total_buffer
.
data
()
+
offset
/
sizeof
(
T
),
addrs
[
i
],
sizes
[
i
]
*
sizeof
(
T
));
memcpy_s
(
total_buffer
.
data
()
+
offset
/
sizeof
(
T
),
sizes
[
i
]
*
sizeof
(
T
),
reinterpret_cast
<
void
*>
(
addrs
[
i
]),
sizes
[
i
]
*
sizeof
(
T
));
offset
+=
sizes
[
i
]
*
sizeof
(
T
);
}
kv_worker_
->
PushData
(
::
ps
::
SArray
<::
ps
::
Key
>
(
keys
),
total_buffer
,
::
ps
::
SArray
<
int
>
(
sizes
));
...
...
@@ -108,13 +109,13 @@ template <typename T>
void
Worker
<
T
>::
Pull
(
const
size_t
key
,
void
*
dev_addr
,
const
size_t
size
)
{
::
ps
::
SArray
<
T
>
variables
(
size
/
sizeof
(
T
),
0
);
kv_worker_
->
Wait
(
kv_worker_
->
ZPull
({
key
},
&
variables
));
memcpy
(
dev_addr
,
variables
.
data
(),
size
);
memcpy
_s
(
dev_addr
,
size
,
variables
.
data
(),
size
);
}
template
<
typename
T
>
void
Worker
<
T
>::
DoPSEmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
void
Worker
<
T
>::
DoPSEmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
lookup_result
,
int
cmd
)
{
kv_worker_
->
EmbeddingLookup
(
keys
,
lookup_ids
,
lens
,
&
lookup_result
,
cmd
);
kv_worker_
->
EmbeddingLookup
(
keys
,
lookup_ids
,
lens
,
lookup_result
,
cmd
);
}
template
<
typename
T
>
...
...
mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
浏览文件 @
7233d650
...
...
@@ -22,6 +22,7 @@
#include <utility>
#include <memory>
#include <vector>
#include <unordered_set>
#include "ps/ps.h"
#include "frontend/parallel/ps/util.h"
...
...
@@ -34,24 +35,23 @@ class WorkerProxy : public ::ps::KVWorker<T> {
using
Worker
=
::
ps
::
KVWorker
<
T
>
;
using
Callback
=
std
::
function
<
void
()
>
;
using
SlicedKVs
=
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
;
using
Slicer
=
std
::
function
<
void
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
ranges
,
SlicedKVs
*
sliced
)
>
;
using
Slicer
=
std
::
function
<
void
(
int
ts
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
ranges
,
SlicedKVs
*
sliced
)
>
;
using
::
ps
::
SimpleApp
::
obj_
;
explicit
WorkerProxy
(
int
app_id
,
int
customer_id
,
int
lookup_customer_id
)
:
Worker
(
app_id
,
customer_id
)
{
using
_1
=
std
::
placeholders
::
_1
;
using
_2
=
std
::
placeholders
::
_2
;
using
_3
=
std
::
placeholders
::
_3
;
using
std
::
placeholders
::
_1
;
using
std
::
placeholders
::
_2
;
using
std
::
placeholders
::
_3
;
using
std
::
placeholders
::
_4
;
lookup_customer_
=
std
::
unique_ptr
<::
ps
::
Customer
>
(
new
::
ps
::
Customer
(
app_id
,
lookup_customer_id
,
std
::
bind
(
&
WorkerProxy
<
T
>::
ProcessLookupResult
,
this
,
_1
)));
lookup_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
LookupIdSlicer
,
this
,
_1
,
_2
,
_3
);
init_embedding_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
EmbeddingTableInitSlicer
,
this
,
_1
,
_2
,
_3
);
push_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
PushSlicer
,
this
,
_1
,
_2
,
_3
);
broadcast_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
BroadcastSlicer
,
this
,
_1
,
_2
,
_3
);
lookup_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
LookupIdSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
broadcast_slicer_
=
std
::
bind
(
&
WorkerProxy
<
T
>::
BroadcastSlicer
,
this
,
_1
,
_2
,
_3
,
_4
);
}
~
WorkerProxy
()
override
=
default
;
void
AddEmbeddingTable
(
const
::
ps
::
Key
&
key
,
const
size_t
&
row_count
);
void
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
void
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
outs
,
int
cmd
=
0
,
const
Callback
&
cb
=
nullptr
,
int
priority
=
0
);
int
InitEmbeddingTable
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
vals
,
...
...
@@ -61,15 +61,11 @@ class WorkerProxy : public ::ps::KVWorker<T> {
private:
template
<
typename
C
>
int
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
C
*
vals
,
int
cmd
,
int
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
C
*
vals
,
int
cmd
,
const
Callback
&
cb
);
void
LookupIdSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
void
LookupIdSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
EmbeddingTableInitSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
PushSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
BroadcastSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
void
BroadcastSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
);
void
ProcessLookupResult
(
const
::
ps
::
Message
&
msg
);
void
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
...
...
@@ -80,10 +76,9 @@ class WorkerProxy : public ::ps::KVWorker<T> {
std
::
unordered_map
<
int
,
std
::
vector
<::
ps
::
KVPairs
<
T
>>>
lookup_results_
;
std
::
mutex
mutex_
;
Slicer
lookup_slicer_
;
Slicer
init_embedding_slicer_
;
Slicer
push_slicer_
;
Slicer
broadcast_slicer_
;
std
::
unordered_map
<
int
,
Callback
>
lookup_callbacks_
;
std
::
unordered_map
<
int
,
int
>
expected_result_count_
;
};
template
<
typename
T
>
...
...
@@ -108,17 +103,21 @@ void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_c
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
void
WorkerProxy
<
T
>::
EmbeddingLookup
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
const
::
ps
::
SArray
<
int
>
&
lens
,
::
ps
::
SArray
<
T
>
*
outs
,
int
cmd
,
const
Callback
&
cb
,
int
priority
)
{
int
ts
=
AddLookupCB
(
keys
,
lookup_ids
,
outs
,
cmd
,
cb
);
::
ps
::
KVPairs
<
T
>
kvs
;
kvs
.
keys
=
keys
;
kvs
.
vals
=
lookup_ids
;
kvs
.
lens
=
lens
;
kvs
.
lens
=
lookup_ids
;
kvs
.
priority
=
priority
;
Send
(
lookup_customer_
.
get
(),
ts
,
true
,
true
,
cmd
,
kvs
,
broadcast_slicer_
);
expected_result_count_
[
ts
]
=
0
;
Send
(
lookup_customer_
.
get
(),
ts
,
true
,
true
,
cmd
,
kvs
,
lookup_slicer_
);
int
server_num
=
::
ps
::
NumServers
();
int
expect_rt_count
=
expected_result_count_
[
ts
];
lookup_customer_
->
AddResponse
(
ts
,
server_num
-
expect_rt_count
);
lookup_customer_
->
WaitRequest
(
ts
);
expected_result_count_
.
erase
(
ts
);
}
template
<
typename
T
>
...
...
@@ -130,7 +129,7 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
kvs
.
vals
=
vals
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
obj_
,
ts
,
true
,
false
,
kInitEmbeddingsCmd
,
kvs
,
init_embedding
_slicer_
);
Send
(
obj_
,
ts
,
true
,
false
,
kInitEmbeddingsCmd
,
kvs
,
broadcast
_slicer_
);
return
ts
;
}
...
...
@@ -143,13 +142,13 @@ void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S
kvs
.
vals
=
vals
;
kvs
.
lens
=
lens
;
kvs
.
priority
=
priority
;
Send
(
obj_
,
ts
,
true
,
false
,
cmd
,
kvs
,
push
_slicer_
);
Send
(
obj_
,
ts
,
true
,
false
,
cmd
,
kvs
,
broadcast
_slicer_
);
obj_
->
WaitRequest
(
ts
);
}
template
<
typename
T
>
template
<
typename
C
>
int
WorkerProxy
<
T
>::
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
T
>
&
lookup_ids
,
int
WorkerProxy
<
T
>::
AddLookupCB
(
const
::
ps
::
SArray
<::
ps
::
Key
>
&
keys
,
const
::
ps
::
SArray
<
int
>
&
lookup_ids
,
C
*
lookup_result
,
int
cmd
,
const
Callback
&
cb
)
{
int
ts
=
lookup_customer_
->
NewRequest
(
::
ps
::
kServerGroup
);
const
auto
&
callback
=
[
this
,
ts
,
keys
,
lookup_ids
,
lookup_result
,
cb
]()
mutable
{
...
...
@@ -158,18 +157,28 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
mutex_
.
unlock
();
size_t
total_len
=
0
;
const
auto
&
s
=
kvs
[
0
];
for
(
size_t
i
=
0
;
i
<
s
.
lens
.
size
();
i
++
)
{
total_len
+=
s
.
lens
[
i
];
std
::
unordered_map
<
Key
,
std
::
shared_ptr
<
std
::
pair
<
T
*
,
int
>>>
id_addr_map
;
for
(
const
auto
&
s
:
kvs
)
{
int
offset
=
0
;
int
len
=
s
.
vals
.
size
()
/
s
.
keys
.
size
();
for
(
size_t
i
=
0
;
i
<
s
.
keys
.
size
();
i
++
)
{
const
Key
&
key
=
s
.
keys
[
i
];
T
*
addr
=
s
.
vals
.
data
()
+
offset
;
offset
+=
len
;
total_len
+=
len
;
id_addr_map
[
key
]
=
std
::
make_shared
<
std
::
pair
<
T
*
,
int
>>
(
std
::
make_pair
(
addr
,
len
));
}
}
lookup_result
->
resize
(
total_len
,
0
);
T
*
result_addr
=
lookup_result
->
data
();
for
(
const
auto
&
s
:
kvs
)
{
size_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
s
.
vals
.
size
();
i
++
)
{
result_addr
[
offset
++
]
+=
s
.
vals
[
i
];
T
*
result_addr
=
lookup_result
->
data
();
int
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
lookup_ids
.
size
();
i
++
)
{
auto
&
pair
=
id_addr_map
[
static_cast
<
Key
>
(
lookup_ids
[
i
])];
auto
ret
=
memcpy_s
(
result_addr
+
offset
,
pair
->
second
,
pair
->
first
,
pair
->
second
);
if
(
ret
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"memcpy_s error, errorno("
<<
ret
<<
")"
;
}
offset
+=
pair
->
second
;
}
mutex_
.
lock
();
...
...
@@ -182,31 +191,30 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
LookupIdSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
void
WorkerProxy
<
T
>::
LookupIdSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
int
*
data
=
send
.
lens
.
data
();
size_t
size
=
send
.
lens
.
size
();
std
::
vector
<
int
>
lookup_ids
(
data
,
data
+
size
);
std
::
sort
(
lookup_ids
.
begin
(),
lookup_ids
.
end
());
int
*
lookup_ids
=
send
.
lens
.
data
();
size_t
id_size
=
send
.
lens
.
size
();
const
Key
&
key
=
send
.
keys
[
0
];
const
std
::
vector
<::
ps
::
Range
>
&
ranges
=
*
(
embedding_table_ranges_
[
key
]);
sliced
->
resize
(
ranges
.
size
());
size_t
index
=
0
;
for
(
size_t
i
=
0
;
i
<
ranges
.
size
();
i
++
)
{
const
::
ps
::
Range
&
range
=
ranges
[
i
];
const
auto
&
begin
=
range
.
begin
();
const
auto
&
end
=
range
.
end
();
std
::
unordered_set
<
int
>
unique_ids
;
auto
&
kvs
=
sliced
->
at
(
i
).
second
;
auto
lookup_id
=
static_cast
<
uint64_t
>
(
lookup_ids
[
index
]);
while
(
lookup_id
>=
begin
&&
lookup_id
<=
end
)
{
kvs
.
vals
.
push_back
(
lookup_id
);
if
(
++
index
>=
lookup_ids
.
size
())
{
break
;
for
(
size_t
j
=
0
;
j
<
id_size
;
j
++
)
{
auto
lookup_id
=
static_cast
<
uint64_t
>
(
lookup_ids
[
j
]);
if
(
lookup_id
>=
begin
&&
lookup_id
<=
end
)
{
unique_ids
.
insert
(
lookup_id
);
}
lookup_id
=
static_cast
<
uint64_t
>
(
lookup_ids
[
index
]);
}
for
(
const
auto
&
lookup_id
:
unique_ids
)
{
kvs
.
vals
.
push_back
(
lookup_id
);
}
kvs
.
keys
.
push_back
(
key
);
kvs
.
lens
.
push_back
(
kvs
.
vals
.
size
());
...
...
@@ -215,35 +223,13 @@ void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vec
sliced
->
at
(
i
).
first
=
false
;
}
else
{
sliced
->
at
(
i
).
first
=
true
;
expected_result_count_
[
timestamp
]
+=
1
;
}
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
EmbeddingTableInitSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
const
Key
&
key
=
send
.
keys
[
0
];
const
std
::
vector
<::
ps
::
Range
>
&
ranges
=
*
(
embedding_table_ranges_
[
key
]);
sliced
->
resize
(
ranges
.
size
());
for
(
size_t
i
=
0
;
i
<
ranges
.
size
();
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
PushSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
auto
server_num
=
::
ps
::
Postoffice
::
Get
()
->
num_servers
();
sliced
->
resize
(
server_num
);
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
sliced
->
at
(
i
).
first
=
true
;
sliced
->
at
(
i
).
second
=
send
;
}
}
template
<
typename
T
>
void
WorkerProxy
<
T
>::
BroadcastSlicer
(
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
void
WorkerProxy
<
T
>::
BroadcastSlicer
(
int
timestamp
,
const
::
ps
::
KVPairs
<
T
>
&
send
,
const
std
::
vector
<::
ps
::
Range
>
&
,
std
::
vector
<
std
::
pair
<
bool
,
::
ps
::
KVPairs
<
T
>>>
*
sliced
)
{
auto
server_num
=
::
ps
::
Postoffice
::
Get
()
->
num_servers
();
sliced
->
resize
(
server_num
);
...
...
@@ -268,7 +254,7 @@ void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
lookup_results_
[
ts
].
push_back
(
kvs
);
mutex_
.
unlock
();
}
if
(
lookup_customer_
->
NumResponse
(
ts
)
==
::
ps
::
Postoffice
::
Get
()
->
num_servers
()
-
1
)
{
if
(
lookup_customer_
->
NumResponse
(
ts
)
==
expected_result_count_
[
ts
]
-
1
)
{
const
auto
&
cb
=
lookup_callbacks_
[
ts
];
cb
();
lookup_callbacks_
.
erase
(
ts
);
...
...
@@ -279,7 +265,7 @@ template <typename T>
void
WorkerProxy
<
T
>::
Send
(
::
ps
::
Customer
*
customer
,
int
timestamp
,
bool
push
,
bool
pull
,
int
cmd
,
const
::
ps
::
KVPairs
<
T
>
&
kvs
,
const
Slicer
&
slicer
)
{
SlicedKVs
sliced
;
slicer
(
kvs
,
::
ps
::
Postoffice
::
Get
()
->
GetServerKeyRanges
(),
&
sliced
);
slicer
(
timestamp
,
kvs
,
::
ps
::
Postoffice
::
Get
()
->
GetServerKeyRanges
(),
&
sliced
);
for
(
size_t
i
=
0
;
i
<
sliced
.
size
();
i
++
)
{
const
auto
&
s
=
sliced
[
i
];
...
...
mindspore/ccsrc/minddata/dataset/CMakeLists.txt
浏览文件 @
7233d650
...
...
@@ -146,6 +146,12 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
target_link_libraries
(
_c_dataengine PRIVATE _c_mindrecord
${
MINDRECORD_LINK_OBJECT
}
mindspore::sqlite
)
else
()
target_link_libraries
(
_c_dataengine PRIVATE _c_mindrecord
)
if
(
NOT ENABLE_GE
)
target_link_libraries
(
_c_dataengine PRIVATE mindspore::pslite mindspore::protobuf
${
zeromq_DIRPATH
}
/zmq_install/lib/libzmq.a
)
if
(
${
ENABLE_IBVERBS
}
STREQUAL
"ON"
)
target_link_libraries
(
_c_dataengine PRIVATE ibverbs rdmacm
)
endif
()
endif
()
endif
()
if
(
USE_GLOG
)
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
7233d650
...
...
@@ -40,6 +40,11 @@
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/parameter_server.h"
#include "frontend/parallel/ps/scheduler.h"
#include "frontend/parallel/ps/worker.h"
#endif
namespace
mindspore
{
namespace
pipeline
{
...
...
@@ -374,6 +379,25 @@ bool ExecuteAction(const ResourcePtr &res) {
return
true
;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
// Pipeline action for a parameter-server worker process.
// Blocks inside Worker::Run(), which drives the PS worker loop; always
// reports success to the pipeline once Run() returns.
bool StartPSWorkerAction(const ResourcePtr &res) {
  auto &worker = parallel::ps::Worker<float>::GetInstance();
  worker.Run();
  return true;
}
bool
StartPSServerAction
(
const
ResourcePtr
&
res
)
{
FuncGraphPtr
func_graph
=
res
->
func_graph
();
auto
&
ps
=
parallel
::
ps
::
ParameterServer
<
float
>::
GetInstance
();
ps
.
Run
(
func_graph
);
return
true
;
}
// Pipeline action for the parameter-server scheduler process.
// The scheduler coordinates worker/server rendezvous; Run() blocks for the
// lifetime of the job. `res` is unused but kept for the ActionItem signature.
bool StartPSSchedulerAction(const ResourcePtr &res) {
  auto &scheduler = parallel::ps::Scheduler::GetInstance();
  scheduler.Run();
  return true;
}
#endif
// The parallel primitive related valuenode might be partitioned so that its value changes by device,
// that will result in a syncronization error due to different executing order.
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
...
...
@@ -481,7 +505,11 @@ std::vector<ActionItem> VmPipeline() {
actions
.
emplace_back
(
std
::
make_pair
(
"py_opt"
,
OptActionPyStub
));
actions
.
emplace_back
(
std
::
make_pair
(
"validate"
,
ValidateAction
));
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
if
(
parallel
::
ps
::
Util
::
IsRoleOfWorker
())
{
actions
.
emplace_back
(
std
::
make_pair
(
"worker"
,
StartPSWorkerAction
));
}
#endif
// compile the ANF graph
actions
.
emplace_back
(
std
::
make_pair
(
"task_emit"
,
TaskEmitAction
));
...
...
@@ -490,5 +518,21 @@ std::vector<ActionItem> VmPipeline() {
return
actions
;
}
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
std
::
vector
<
ActionItem
>
PServerPipeline
()
{
auto
actions
=
CommonPipeline
();
actions
.
emplace_back
(
std
::
make_pair
(
"optimize"
,
VmOptimizeAction
));
actions
.
emplace_back
(
std
::
make_pair
(
"validate"
,
ValidateAction
));
actions
.
emplace_back
(
std
::
make_pair
(
"pserver"
,
StartPSServerAction
));
return
actions
;
}
std
::
vector
<
ActionItem
>
PSchedulerPipeline
()
{
std
::
vector
<
ActionItem
>
actions
;
actions
.
emplace_back
(
std
::
make_pair
(
"scheduler"
,
StartPSSchedulerAction
));
return
actions
;
}
#endif
}
// namespace pipeline
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/action.h
浏览文件 @
7233d650
...
...
@@ -38,9 +38,14 @@ bool VmOptimizeAction(const ResourcePtr &res);
bool
PynativeOptimizeAction
(
const
ResourcePtr
&
res
);
bool
TaskEmitAction
(
const
ResourcePtr
&
res
);
bool
ExecuteAction
(
const
ResourcePtr
&
res
);
bool
StartPSWorkerAction
(
const
ResourcePtr
&
res
);
bool
StartPSServerAction
(
const
ResourcePtr
&
res
);
bool
StartPSSchedulerAction
(
const
ResourcePtr
&
res
);
std
::
vector
<
ActionItem
>
GePipeline
();
std
::
vector
<
ActionItem
>
VmPipeline
();
std
::
vector
<
ActionItem
>
PServerPipeline
();
std
::
vector
<
ActionItem
>
PSchedulerPipeline
();
abstract
::
AnalysisResult
AbstractAnalyze
(
const
ResourcePtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
const
abstract
::
AbstractBasePtrList
&
args_spec
,
bool
clear
=
false
);
FuncGraphPtr
ProgramSpecialize
(
const
ResourcePtr
&
res
,
const
FuncGraphPtr
&
func_graph
,
...
...
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
7233d650
...
...
@@ -41,6 +41,11 @@
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/common.h"
#include "frontend/parallel/ps/util.h"
#endif
#if (ENABLE_GE || ENABLE_D)
#include "pipeline/jit/pipeline_ge.h"
#include "transform/graph_ir/convert.h"
...
...
@@ -420,6 +425,26 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
use_vm
=
ChangeExportGeirUseVmFlag
(
use_vm
,
phase_s
);
std
::
string
backend
=
MsContext
::
GetInstance
()
->
backend_policy
();
#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
if
(
mindspore
::
parallel
::
ps
::
Util
::
IsParamServerMode
())
{
mindspore
::
parallel
::
ps
::
Util
::
SetInternalEnvVar
();
}
if
(
parallel
::
ps
::
Util
::
IsRoleOfPServer
())
{
resource
->
results
()[
kBackend
]
=
compile
::
CreateBackend
();
p_actions
=
PServerPipeline
();
}
else
if
(
parallel
::
ps
::
Util
::
IsRoleOfScheduler
())
{
p_actions
=
PSchedulerPipeline
();
}
else
if
(
use_vm
&&
backend
!=
"ge"
)
{
// Create backend and session
auto
backend_ptr
=
compile
::
CreateBackend
();
// Connect session to debugger
backend_ptr
->
SetDebugger
();
resource
->
results
()[
kBackend
]
=
backend_ptr
;
p_actions
=
VmPipeline
();
}
else
{
p_actions
=
GePipeline
();
}
#else
if
(
use_vm
&&
backend
!=
"ge"
)
{
// Create backend and session
auto
backend_ptr
=
compile
::
CreateBackend
();
...
...
@@ -430,6 +455,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
}
else
{
p_actions
=
GePipeline
();
}
#endif
std
::
shared_ptr
<
Pipeline
>
pip
=
std
::
make_shared
<
Pipeline
>
(
resource
,
FilterActions
(
p_actions
,
phase_s
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录