Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
268e9dc1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
268e9dc1
编写于
5月 27, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
上级
ceefbf32
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
22 addition
and
20 deletion
+22
-20
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+13
-9
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+8
-9
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+1
-2
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
268e9dc1
...
@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
...
@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
vector
<
std
::
string
>
send_vars
;
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"send_vars"
||
op
->
Type
()
==
"send"
)
{
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if
(
op
->
Type
()
==
"send_vars"
)
{
auto
op_vars
=
op
->
InputArgumentNames
();
auto
op_vars
=
op
->
InputArgumentNames
();
send_vars
.
reserve
(
send_vars
.
size
()
+
send_vars
.
reserve
(
send_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
...
@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
...
@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const
ProgramDesc
&
program
)
const
{
const
ProgramDesc
&
program
)
const
{
std
::
vector
<
std
::
string
>
recv_vars
;
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
if
(
op
->
Type
()
==
"recv"
||
op
->
Type
()
==
"send"
)
{
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if
(
op
->
Type
()
==
"recv"
)
{
auto
op_vars
=
op
->
OutputArgumentNames
();
auto
op_vars
=
op
->
OutputArgumentNames
();
recv_vars
.
reserve
(
recv_vars
.
size
()
+
recv_vars
.
reserve
(
recv_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
...
@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
auto
checker
=
[](
const
std
::
vector
<
std
::
string
>
&
opvars
,
auto
checker
=
[](
const
std
::
vector
<
std
::
string
>
&
opvars
,
const
std
::
vector
<
std
::
string
>
&
rpc_vars
)
->
bool
{
const
std
::
vector
<
std
::
string
>
&
rpc_vars
)
->
bool
{
for
(
auto
&
var
:
opvars
)
{
for
(
auto
&
var
:
opvars
)
{
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if
(
var
.
find
(
".block"
)
!=
std
::
string
::
npos
&&
if
(
var
.
find
(
".block"
)
!=
std
::
string
::
npos
&&
std
::
find
(
rpc_vars
.
begin
(),
rpc_vars
.
end
(),
var
)
!=
rpc_vars
.
end
())
{
std
::
find
(
rpc_vars
.
begin
(),
rpc_vars
.
end
(),
var
)
!=
rpc_vars
.
end
())
{
return
true
;
return
true
;
...
@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
...
@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return
false
;
return
false
;
};
};
if
(
op
.
Type
()
==
"split"
||
op
.
Type
()
==
"split_byref"
||
return
checker
(
op
.
OutputArgumentNames
(),
send_vars
)
||
op
.
Type
()
==
"split_selected_rows"
)
{
checker
(
op
.
InputArgumentNames
(),
recv_vars
);
return
checker
(
op
.
OutputArgumentNames
(),
send_vars
);
}
else
if
(
op
.
Type
()
==
"concat"
)
{
return
checker
(
op
.
InputArgumentNames
(),
recv_vars
);
}
return
false
;
return
false
;
}
}
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
268e9dc1
...
@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
...
@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
var_name_val
);
const
auto
ch
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
framework
::
AsyncIO
([
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch
,
this
]
{
this
]
{
...
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
...
@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
var_name_val
);
const
auto
ch
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
([
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
framework
::
AsyncIO
([
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{
this
]
{
...
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
,
ep_val
+
":"
+
in_var_name_val
);
const
auto
ch
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{
time_out
,
ch
,
this
]
{
...
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
...
@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
}
}
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
,
ep
);
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
}
}
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
,
ep
);
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
...
@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
...
@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
delete
c
;
delete
c
;
return
true
;
return
true
;
}
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
,
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
const
std
::
string
&
key
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
auto
it
=
channels_
.
find
(
key
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
...
@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
...
@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
auto
ch
=
auto
ch
=
grpc
::
CreateCustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
);
grpc
::
CreateCustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
);
channels_
[
key
]
=
ch
;
channels_
[
ep
]
=
ch
;
return
ch
;
return
ch
;
}
}
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
268e9dc1
...
@@ -191,8 +191,7 @@ class RPCClient {
...
@@ -191,8 +191,7 @@ class RPCClient {
private:
private:
bool
Proceed
();
bool
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
,
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
const
std
::
string
&
key
);
private:
private:
grpc
::
CompletionQueue
cq_
;
grpc
::
CompletionQueue
cq_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录