Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6debbcd9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6debbcd9
编写于
5月 23, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
connect fetch barrier and concat op
上级
147d54ba
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
50 addition
and
18 deletion
+50
-18
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+12
-5
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+1
-0
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+8
-1
paddle/fluid/operators/send_vars_op.cc
paddle/fluid/operators/send_vars_op.cc
+0
-2
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+29
-10
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
6debbcd9
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
...
@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -181,8 +182,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// always use the first device
// always use the first device
CreateRPCOp
(
&
result
,
*
op
);
CreateRPCOp
(
&
result
,
*
op
);
}
else
if
(
IsDistTrainOp
(
*
op
,
send_vars
,
recv_vars
))
{
}
else
if
(
IsDistTrainOp
(
*
op
,
send_vars
,
recv_vars
))
{
// CreateComputationalOps(&result, *op, 1);
CreateDistTrainOp
(
&
result
,
*
op
);
CreateComputationalOp
(
&
result
,
*
op
,
0
);
}
else
if
(
IsScaleLossOp
(
*
op
))
{
}
else
if
(
IsScaleLossOp
(
*
op
))
{
// user can customize loss@grad if not use_default_grad_scale_
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
if
(
strategy_
.
gradient_scale_
!=
...
@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...
@@ -247,9 +247,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
if
(
VLOG_IS_ON
(
10
))
{
if
(
VLOG_IS_ON
(
10
))
{
std
::
ostringstream
sout
;
std
::
ofstream
fout
(
"/tmp/graph.dot"
);
PrintGraphviz
(
*
graph
,
sout
);
PrintGraphviz
(
*
graph
,
fout
);
VLOG
(
10
)
<<
sout
.
str
();
}
}
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
...
@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
...
@@ -443,6 +442,14 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
CreateComputationalOp
(
result
,
op
,
0
);
if
(
op
.
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
ops_
.
back
().
get
(),
"fetch_barrier"
);
}
}
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
{
const
OpDesc
&
op
)
const
{
auto
&
p
=
places_
[
0
];
auto
&
p
=
places_
[
0
];
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
6debbcd9
...
@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -65,6 +65,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
bool
IsScaleLossOp
(
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateRPCOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
void
CreateDistTrainOp
(
SSAGraph
*
result
,
const
OpDesc
&
op
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
6debbcd9
...
@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase {
...
@@ -38,6 +38,7 @@ class RecvOp : public framework::OperatorBase {
auto
outs
=
Outputs
(
"Out"
);
auto
outs
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
auto
client_var_name
=
Output
(
"RPCClient"
);
int
sync_recv
=
Attr
<
int
>
(
"sync_recv"
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
...
@@ -54,8 +55,10 @@ class RecvOp : public framework::OperatorBase {
...
@@ -54,8 +55,10 @@ class RecvOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
}
if
(
sync_recv
)
{
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
}
}
};
};
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
RecvOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -75,6 +78,10 @@ This operator can get variables from server side.
...
@@ -75,6 +78,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"Server endpoints in the order of input "
"variables for mapping"
)
"variables for mapping"
)
.
SetDefault
({});
.
SetDefault
({});
AddAttr
<
int
>
(
"sync_recv"
,
"(int, default 0)"
"sync recv or async recv."
)
.
SetDefault
(
0
);
}
}
};
};
...
...
paddle/fluid/operators/send_vars_op.cc
浏览文件 @
6debbcd9
...
@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase {
...
@@ -50,9 +50,7 @@ class SendVarsOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope."
,
"Can not find variable '%s' in the scope."
,
client_var_name
);
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
auto
*
client_var
=
scope
.
FindVar
(
client_var_name
);
VLOG
(
3
)
<<
"client var addr: "
<<
client_var
;
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
detail
::
RPCClient
*
rpc_client
=
client_var
->
GetMutable
<
detail
::
RPCClient
>
();
VLOG
(
3
)
<<
"rpc_client addr: "
<<
rpc_client
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
6debbcd9
...
@@ -357,12 +357,35 @@ class DistributeTranspiler:
...
@@ -357,12 +357,35 @@ class DistributeTranspiler:
ps_dispatcher
.
reset
()
ps_dispatcher
.
reset
()
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
eplist
=
ps_dispatcher
.
dispatch
(
recv_vars
)
#program.global_block().append_op(
# type="recv",
# inputs={},
# outputs={"Out": recv_vars,
# "RPCClient": rpc_client_var},
# attrs={"epmap": eplist})
#program.global_block().append_op(
# type="fetch_barrier",
# inputs={},
# outputs={"RPCClient": rpc_client_var},
# attrs={"endpoints": pserver_endpoints})
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
eps
=
[]
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
eps
.
append
(
eplist
[
index
])
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"recv"
,
type
=
"recv"
,
inputs
=
{},
inputs
=
{},
outputs
=
{
"Out"
:
recv_vars
,
outputs
=
{
"Out"
:
splited_var
,
"RPCClient"
:
rpc_client_var
},
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"epmap"
:
eplist
})
attrs
=
{
"epmap"
:
eps
})
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"fetch_barrier"
,
type
=
"fetch_barrier"
,
...
@@ -370,10 +393,6 @@ class DistributeTranspiler:
...
@@ -370,10 +393,6 @@ class DistributeTranspiler:
outputs
=
{
"RPCClient"
:
rpc_client_var
},
outputs
=
{
"RPCClient"
:
rpc_client_var
},
attrs
=
{
"endpoints"
:
pserver_endpoints
})
attrs
=
{
"endpoints"
:
pserver_endpoints
})
for
i
,
ep
in
enumerate
(
eplist
):
self
.
param_grad_ep_mapping
[
ep
][
"params"
].
append
(
recv_vars
[
i
])
self
.
param_grad_ep_mapping
[
ep
][
"grads"
].
append
(
send_vars
[
i
])
# step4: Concat the parameters splits together after recv.
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
for
varname
,
splited_var
in
param_var_mapping
.
iteritems
():
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
continue
continue
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录