Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
73005ee0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
73005ee0
编写于
2月 14, 2019
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cleanup code test=develop
上级
88d3dc94
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
19 addition
and
45 deletion
+19
-45
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+0
-4
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+8
-9
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+5
-11
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
+0
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+1
-1
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+0
-10
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+1
-3
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+4
-5
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
73005ee0
...
@@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -119,8 +119,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor.
// Verify that the graph is correct for multi-device executor.
auto
multi_devices_pass
=
AppendPass
(
"multi_devices_check_pass"
);
auto
multi_devices_pass
=
AppendPass
(
"multi_devices_check_pass"
);
multi_devices_pass
->
Set
<
bool
>
(
kEnablePG
,
new
bool
(
strategy
.
enable_parallel_graph_
));
if
(
SeqOnlyAllReduceOps
(
strategy
))
{
if
(
SeqOnlyAllReduceOps
(
strategy
))
{
AppendPass
(
"all_reduce_deps_pass"
);
AppendPass
(
"all_reduce_deps_pass"
);
...
@@ -194,8 +192,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -194,8 +192,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
&
local_scopes
);
&
local_scopes
);
pass
->
Erase
(
kNRanks
);
pass
->
Erase
(
kNRanks
);
pass
->
Set
<
size_t
>
(
kNRanks
,
new
size_t
(
nranks
));
pass
->
Set
<
size_t
>
(
kNRanks
,
new
size_t
(
nranks
));
pass
->
Erase
(
kEnablePG
);
pass
->
Set
<
bool
>
(
kEnablePG
,
new
bool
(
true
));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
73005ee0
...
@@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
...
@@ -201,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
auto
&
g_name
=
backward_vars
[
i
+
1
];
auto
&
g_name
=
backward_vars
[
i
+
1
];
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
VLOG
(
10
)
<<
"Bcast "
<<
g_name
<<
" for parameter "
<<
p_name
;
InsertCollectiveOp
(
&
result
,
node
,
p_name
,
g_name
);
InsertCollectiveOp
(
&
result
,
p_name
,
g_name
);
}
}
}
catch
(
boost
::
bad_get
e
)
{
}
catch
(
boost
::
bad_get
e
)
{
}
}
...
@@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
...
@@ -386,7 +386,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
}
}
void
MultiDevSSAGraphBuilderBase
::
CreateAllReduceOp
(
void
MultiDevSSAGraphBuilderBase
::
CreateAllReduceOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
const
std
::
string
&
og
)
const
{
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
OpHandleBase
*
op_handle
=
nullptr
;
OpHandleBase
*
op_handle
=
nullptr
;
auto
append_allreduce_op
=
[
&
](
auto
append_allreduce_op
=
[
&
](
...
@@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
...
@@ -510,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
}
}
void
AllReduceSSAGraphBuilder
::
InsertCollectiveOp
(
void
AllReduceSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
const
std
::
string
&
p_name
,
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
const
std
::
string
&
g_name
)
const
{
if
(
IsSparseGradient
(
g_name
))
{
if
(
IsSparseGradient
(
g_name
))
{
CreateReduceOp
(
result
,
g_name
,
0
);
CreateReduceOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
}
else
{
}
else
{
CreateAllReduceOp
(
result
,
node
,
g_name
);
CreateAllReduceOp
(
result
,
g_name
);
}
}
}
}
...
@@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
...
@@ -589,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
}
}
void
ReduceSSAGraphBuilder
::
InsertCollectiveOp
(
void
ReduceSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
const
std
::
string
&
p_name
,
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
const
std
::
string
&
g_name
)
const
{
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
size_t
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
result
,
g_name
,
cur_device_id
);
...
@@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -909,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return
op_dev_id
;
return
op_dev_id
;
}
}
void
DistSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
DistSSAGraphBuilder
::
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{
const
std
::
string
&
g_name
)
const
{
size_t
cur_device_id
=
0
;
size_t
cur_device_id
=
0
;
...
@@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
...
@@ -924,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, ir::Node *node,
CreateReduceOp
(
result
,
g_name
,
0
);
CreateReduceOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
CreateBroadcastOp
(
result
,
g_name
,
0
);
}
else
{
}
else
{
CreateAllReduceOp
(
result
,
node
,
g_name
);
CreateAllReduceOp
(
result
,
g_name
);
}
}
break
;
break
;
default:
default:
...
@@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
...
@@ -975,8 +975,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kPlaces) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kLocalScopes) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kStrategy) \
.RequirePassAttr(paddle::framework::details::kNRanks) \
.RequirePassAttr(paddle::framework::details::kNRanks)
.RequirePassAttr(paddle::framework::details::kEnablePG)
REGISTER_MULTI_DEVICES_PASS
(
reduce_mode_multi_devices_pass
,
REGISTER_MULTI_DEVICES_PASS
(
reduce_mode_multi_devices_pass
,
paddle
::
framework
::
details
::
ReduceSSAGraphBuilder
);
paddle
::
framework
::
details
::
ReduceSSAGraphBuilder
);
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
73005ee0
...
@@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places";
...
@@ -36,7 +36,6 @@ constexpr char kPlaces[] = "places";
constexpr
char
kLocalScopes
[]
=
"local_scopes"
;
constexpr
char
kLocalScopes
[]
=
"local_scopes"
;
constexpr
char
kStrategy
[]
=
"strategy"
;
constexpr
char
kStrategy
[]
=
"strategy"
;
constexpr
char
kNRanks
[]
=
"nranks"
;
constexpr
char
kNRanks
[]
=
"nranks"
;
constexpr
char
kEnablePG
[]
=
"enable_pg"
;
class
MultiDevSSAGraphBuilderBase
:
public
ir
::
Pass
{
class
MultiDevSSAGraphBuilderBase
:
public
ir
::
Pass
{
protected:
protected:
...
@@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -47,8 +46,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
virtual
std
::
vector
<
ir
::
Node
*>
SortOperations
(
const
ir
::
Graph
&
graph
)
const
;
virtual
std
::
vector
<
ir
::
Node
*>
SortOperations
(
const
ir
::
Graph
&
graph
)
const
;
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
=
0
;
const
std
::
string
&
g_name
)
const
=
0
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
=
0
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
=
0
;
...
@@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -77,8 +75,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
void
CreateAllReduceOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
const
std
::
string
&
og
)
const
;
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
...
@@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -109,8 +106,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
class
AllReduceSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
class
AllReduceSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
protected:
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
const
std
::
string
&
g_name
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
...
@@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
...
@@ -139,8 +135,7 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
protected:
protected:
virtual
void
Init
()
const
;
virtual
void
Init
()
const
;
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
const
std
::
string
&
g_name
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
...
@@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
...
@@ -169,8 +164,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
;
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
;
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
;
const
std
::
string
&
g_name
)
const
;
virtual
void
ResetState
()
const
;
virtual
void
ResetState
()
const
;
...
...
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
浏览文件 @
73005ee0
...
@@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -45,8 +45,6 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
private:
private:
// std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
ExecutionStrategy
strategy_
;
ExecutionStrategy
strategy_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
73005ee0
...
@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
}
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
);
InsertPendingVar
(
&
pending_vars
,
ready_vars
.
get
(),
var
);
}
}
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph_
))
{
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph_
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
);
ready_ops
.
insert
(
op
);
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
73005ee0
...
@@ -176,12 +176,6 @@ class Graph {
...
@@ -176,12 +176,6 @@ class Graph {
return
ret
;
return
ret
;
}
}
void
RemoveNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
!=
node_set_
.
end
());
node_set_
.
erase
(
node
);
nodes_
.
erase
(
node
);
}
// NOTE low performance, but simple and secure.
// NOTE low performance, but simple and secure.
Node
*
RetrieveNode
(
int
id
)
{
Node
*
RetrieveNode
(
int
id
)
{
for
(
auto
&
node
:
nodes_
)
{
for
(
auto
&
node
:
nodes_
)
{
...
@@ -200,10 +194,6 @@ class Graph {
...
@@ -200,10 +194,6 @@ class Graph {
return
node
;
return
node
;
}
}
bool
ContainNode
(
ir
::
Node
*
node
)
{
return
node_set_
.
find
(
node
)
!=
node_set_
.
end
();
}
void
ResolveHazard
(
void
ResolveHazard
(
const
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
&
var_nodes
);
const
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
&
var_nodes
);
...
...
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
73005ee0
...
@@ -64,9 +64,7 @@ template <typename T>
...
@@ -64,9 +64,7 @@ template <typename T>
std
::
vector
<
T
*>
FilterByNodeWrapper
(
const
Graph
&
graph
)
{
std
::
vector
<
T
*>
FilterByNodeWrapper
(
const
Graph
&
graph
)
{
std
::
vector
<
T
*>
ret
;
std
::
vector
<
T
*>
ret
;
for
(
ir
::
Node
*
n
:
graph
.
Nodes
())
{
for
(
ir
::
Node
*
n
:
graph
.
Nodes
())
{
if
(
n
->
IsWrappedBy
<
T
>
())
{
if
(
n
->
IsWrappedBy
<
T
>
())
ret
.
push_back
(
&
n
->
Wrapper
<
T
>
());
ret
.
push_back
(
&
n
->
Wrapper
<
T
>
());
}
}
}
return
ret
;
return
ret
;
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
73005ee0
...
@@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution(
...
@@ -478,12 +478,11 @@ bool ParallelExecutor::EnableParallelGraphExecution(
}
}
}
}
// if (!member_->use_all_reduce_ || !member_->use_cuda_)
if
(
!
member_
->
use_all_reduce_
||
!
member_
->
use_cuda_
)
if
(
!
member_
->
use_all_reduce_
)
enable_parallel_graph
=
false
;
if
(
build_strategy
.
enable_sequential_execution_
||
if
(
build_strategy
.
enable_sequential_execution_
||
exec_strategy
.
type_
==
ExecutionStrategy
::
ExecutorType
::
kExperimental
)
exec_strategy
.
type_
==
ExecutionStrategy
::
ExecutorType
::
kExperimental
)
enable_parallel_graph
=
false
;
enable_parallel_graph
=
false
;
return
enable_parallel_graph
;
return
enable_parallel_graph
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录