Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
255b36da
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
255b36da
编写于
3月 06, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
can run
上级
8c38aca9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
60 additions
and
14 deletions
+60
-14
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+11
-2
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+6
-0
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+1
-1
paddle/fluid/operators/distributed/rpc_common.h
paddle/fluid/operators/distributed/rpc_common.h
+32
-4
paddle/fluid/operators/distributed_ops/CMakeLists.txt
paddle/fluid/operators/distributed_ops/CMakeLists.txt
+2
-2
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+7
-4
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
255b36da
...
...
@@ -59,6 +59,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
auto
recv_var_name
=
node
->
Op
()
->
Input
(
"X"
)[
0
];
auto
recv_varnames
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
...
...
@@ -68,13 +70,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{});
graphs
[
i
]
->
RemoveNode
(
node
);
VLOG
(
3
)
<<
"find and remove an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
}
}
}
}
// init communicator here
operators
::
distributed
::
Communicator
::
Init
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
scope
);
if
(
send_varname_to_ctx
.
size
()
>
0
)
{
VLOG
(
3
)
<<
"this is distribute mode, will use "
;
operators
::
distributed
::
Communicator
::
Init
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
scope
);
}
}
AsyncSSAGraphExecutor
::
AsyncSSAGraphExecutor
(
...
...
@@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
for
(
auto
*
scope
:
local_scopes_
)
{
NewTempScopeAndInitVars
(
var_infos_
,
scope
);
}
ProcessGraph
(
graphs_
,
local_scopes_
[
0
]);
}
void
AsyncSSAGraphExecutor
::
StartOffPythonTrainLoop
()
{
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
255b36da
...
...
@@ -30,7 +30,7 @@ if(WITH_GRPC)
else
()
set
(
BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc
)
set_source_files_properties
(
${
BRPC_SRCS
}
parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
${
BRPC_SRCS
}
parameter_prefetch.cc parameter_send.cc parameter_recv.cc
communicator.cc
rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib
)
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
255b36da
...
...
@@ -63,6 +63,9 @@ static inline void MergeVars(const std::string &var_name,
}
}
// Storage for the Communicator static data members.
// communicator_ is the owning pointer returned (unowned) by GetInstance();
// it is null until something assigns it.
std::unique_ptr<Communicator> Communicator::communicator_{nullptr};
// once_flag companion to communicator_.
// NOTE(review): presumably guards one-time initialisation via std::call_once;
// the call site is not visible here, so confirm.
std::once_flag Communicator::init_flag_{};
void
Communicator
::
SendThread
()
{
while
(
running_
)
{
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
...
...
@@ -117,6 +120,7 @@ void Communicator::RecvThread() {
void
Communicator
::
Send
(
const
std
::
string
&
var_name
,
const
framework
::
Scope
&
scope
)
{
VLOG
(
3
)
<<
"communicator send "
<<
var_name
;
// push var into send queue by var_name
auto
*
grad_var
=
scope
.
FindVar
(
var_name
);
PADDLE_ENFORCE
(
grad_var
->
IsInitialized
(),
"grad var should be inited"
);
...
...
@@ -125,6 +129,8 @@ void Communicator::Send(const std::string &var_name,
send_varname_to_queue_
[
var_name
]
->
Push
(
tmp_grad_var
);
}
// Returns a non-owning pointer to the Communicator held in communicator_.
// The result is nullptr while communicator_ is empty; ownership stays with
// communicator_, so callers must not delete the returned pointer.
Communicator *Communicator::GetInstance() {
  Communicator *instance = communicator_.get();
  return instance;
}
void
Communicator
::
Start
()
{
running_
=
true
;
// start send and recv thread
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
255b36da
...
...
@@ -144,7 +144,7 @@ class Communicator {
InitImpl
(
send_varname_to_ctx
,
recv_varname_to_ctx
,
recv_scope
);
}
static
Communicator
*
GetInstance
()
{
return
communicator_
.
get
();
}
static
Communicator
*
GetInstance
()
;
private:
// Init is called by GetInstance.
...
...
paddle/fluid/operators/distributed/rpc_common.h
浏览文件 @
255b36da
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
...
...
@@ -22,15 +23,17 @@ namespace operators {
namespace
distributed
{
struct
RpcContext
{
RpcContext
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>&
names
,
const
std
::
vector
<
std
::
string
>&
emap
,
const
std
::
vector
<
int64_t
>&
sections
)
RpcContext
()
=
default
;
RpcContext
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
string
>
&
emap
,
const
std
::
vector
<
int64_t
>
&
sections
)
:
var_name
(
name
),
splited_var_names
(
names
),
epmap
(
emap
),
height_sections
(
sections
)
{}
RpcContext
(
const
RpcContext
&
ctx
)
{
RpcContext
(
const
RpcContext
&
ctx
)
{
var_name
=
ctx
.
var_name
;
splited_var_names
=
ctx
.
splited_var_names
;
epmap
=
ctx
.
epmap
;
...
...
@@ -43,6 +46,31 @@ struct RpcContext {
std
::
vector
<
int64_t
>
height_sections
;
};
// Writes a multi-line, brace-delimited dump of an RpcContext to `os`.
//
// Output layout (every list element is followed by ", ", including the last):
//   {var_name: <name>
//   splited_var_names: [a, b, ]
//   epmap: [ep0, ep1, ]
//   height_sections: [h0, h1, ]
//   }
//
// Returns `os` so the call can be chained with further insertions.
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
  os << "{"
     << "var_name: " << rpc_ctx.var_name << "\n";

  // Names of the split pieces of the variable.
  os << "splited_var_names: [";
  for (auto it = rpc_ctx.splited_var_names.begin();
       it != rpc_ctx.splited_var_names.end(); ++it) {
    os << *it << ", ";
  }
  os << "]\n";

  // Entries of epmap, in stored order.
  os << "epmap: [";
  for (auto it = rpc_ctx.epmap.begin(); it != rpc_ctx.epmap.end(); ++it) {
    os << *it << ", ";
  }
  os << "]\n";

  // Entries of height_sections, in stored order.
  os << "height_sections: [";
  for (auto it = rpc_ctx.height_sections.begin();
       it != rpc_ctx.height_sections.end(); ++it) {
    os << *it << ", ";
  }
  os << "]\n";

  os << "}";
  return os;
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed_ops/CMakeLists.txt
浏览文件 @
255b36da
...
...
@@ -2,9 +2,9 @@ include(operators)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv
communicator
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv
communicator
brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
255b36da
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
...
...
@@ -47,10 +48,12 @@ class SendOp : public framework::OperatorBase {
if
(
send_varnames
.
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
1
,
""
);
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
rpc_ctx
=
distributed
::
RpcContext
(
ins
[
0
],
send_varnames
,
epmap
,
height_sections
);
send_functor
(
rpc_ctx
,
scope
,
static_cast
<
bool
>
(
sync_send
));
// auto send_functor = distributed::ParameterSend<float>();
// auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames,
// epmap,
// height_sections);
// send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
[
0
],
scope
);
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录