Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7e70802b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7e70802b
编写于
8月 21, 2020
作者:
S
seiriosPlus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize init from pserver
上级
9ba2ae10
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
71 addition
and
8 deletion
+71
-8
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+2
-2
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+69
-6
未找到文件。
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
7e70802b
...
@@ -193,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() {
...
@@ -193,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() {
auto
&
var_name
=
iter
.
first
;
auto
&
var_name
=
iter
.
first
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
recv_functor
(
iter
.
second
,
*
recv_scope_
,
false
);
recv_functor
(
iter
.
second
,
*
recv_scope_
);
};
};
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
}
}
...
@@ -700,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
...
@@ -700,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
auto
&
ctx
=
recv_varname_to_ctx_
.
at
(
varname
);
auto
&
ctx
=
recv_varname_to_ctx_
.
at
(
varname
);
auto
recv
=
distributed
::
ParameterRecv
<
float
>
();
auto
recv
=
distributed
::
ParameterRecv
<
float
>
();
recv
(
ctx
,
*
pserver_scope_
,
true
);
recv
(
ctx
,
*
pserver_scope_
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
var_psrever
->
IsInitialized
(),
true
,
var_psrever
->
IsInitialized
(),
true
,
...
...
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
7e70802b
...
@@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows;
...
@@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows;
using
DDim
=
framework
::
DDim
;
using
DDim
=
framework
::
DDim
;
template
<
typename
T
>
template
<
typename
T
>
void
RecvSelectedRows
(
const
CommContext
&
rpc_ctx
,
void
RecvSparseLodTensor
(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
std
::
unique_ptr
<
framework
::
Scope
>
local_scope
=
scope
.
NewTmpScope
();
std
::
vector
<
const
float
*>
tensors
;
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_varnames
.
size
();
i
++
)
{
auto
&
recv_var_name
=
rpc_ctx
.
splited_varnames
[
i
];
auto
*
local_var
=
local_scope
->
Var
(
recv_var_name
);
VLOG
(
4
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
// sparse param in recv_scope is LoDTensor
rets
.
push_back
(
rpc_client
->
AsyncGetVarNoBarrier
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
*
local_scope
.
get
(),
recv_var_name
,
recv_var_name
,
recv_var_name
));
const
auto
*
value
=
local_var
->
Get
<
framework
::
LoDTensor
>
().
data
<
float
>
();
tensors
.
push_back
(
value
);
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
rets
[
i
]
->
Wait
(),
0U
,
platform
::
errors
::
ExecutionTimeout
(
"internal error in RPCClient"
));
}
auto
*
merged_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
if
(
merged_var
==
nullptr
||
!
merged_var
->
IsInitialized
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s must initialized at first."
));
}
auto
dims1
=
merged_var
->
Get
<
framework
::
LoDTensor
>
().
dims
()[
1
];
int64_t
height
=
0
;
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_varnames
.
size
();
i
++
)
{
auto
*
splited_var
=
local_scope
->
FindVar
(
rpc_ctx
.
splited_varnames
[
i
]);
height
+=
splited_var
->
Get
<
framework
::
LoDTensor
>
().
dims
()[
0
];
}
PADDLE_ENFORCE_EQ
(
merged_var
->
Get
<
framework
::
LoDTensor
>
().
dims
()[
0
],
height
,
"recved var must has same dims with local var"
);
auto
*
merged_t
=
merged_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
merged_d
=
merged_t
->
mutable_data
<
float
>
(
place
)();
auto
pserver_num
=
rpc_ctx
.
splited_varnames
.
size
();
for
(
int
x
=
0
;
x
<
height
;
++
x
)
{
auto
id
=
x
%
pserver_num
;
auto
idx
=
x
/
pserver_num
;
std
::
memcpy
(
merged_d
+
x
*
dims1
,
tensors
[
id
]
+
idx
*
dims1
,
sizeof
(
float
)
*
dims1
);
}
}
template
<
typename
T
>
void
RecvGeoSparseRecords
(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
auto
&
cpu_ctx
=
*
pool
.
Get
(
cpu_place
);
...
@@ -151,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
...
@@ -151,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
template
<
typename
T
>
template
<
typename
T
>
void
ParameterRecv
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
void
ParameterRecv
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
,
bool
barrier
)
{
const
framework
::
Scope
&
scope
,
bool
geo_records
)
{
VLOG
(
3
)
<<
"ParameterRecv in "
<<
rpc_ctx
.
var_name
;
VLOG
(
3
)
<<
"ParameterRecv in "
<<
rpc_ctx
.
var_name
;
PADDLE_ENFORCE_GE
(
rpc_ctx
.
origin_varnames
.
size
(),
1
,
PADDLE_ENFORCE_GE
(
rpc_ctx
.
origin_varnames
.
size
(),
1
,
...
@@ -159,18 +219,21 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
...
@@ -159,18 +219,21 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
"origin_varnames.size() >= 1 is permitted"
));
"origin_varnames.size() >= 1 is permitted"
));
if
(
rpc_ctx
.
is_sparse
)
{
if
(
rpc_ctx
.
is_sparse
)
{
RecvSelectedRows
<
T
>
(
rpc_ctx
,
scope
);
if
(
geo_records
)
{
RecvGeoSparseRecords
()
<
T
>
(
rpc_ctx
,
scope
);
}
else
{
RecvSparseLodTensor
()
<
T
>
(
rpc_ctx
,
scope
);
}
}
else
{
}
else
{
RecvLodTensor
<
T
>
(
rpc_ctx
,
scope
);
RecvLodTensor
<
T
>
(
rpc_ctx
,
scope
);
}
}
VLOG
(
3
)
<<
"ParameterRecv out "
<<
rpc_ctx
.
var_name
;
VLOG
(
3
)
<<
"ParameterRecv out "
<<
rpc_ctx
.
var_name
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
ParameterRecv
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
void
ParameterRecv
<
T
>::
operator
()(
const
CommContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
this
->
operator
()(
rpc_ctx
,
scope
,
tru
e
);
this
->
operator
()(
rpc_ctx
,
scope
,
fals
e
);
}
}
template
struct
ParameterRecv
<
float
>;
template
struct
ParameterRecv
<
float
>;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录