Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
626abfc3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
626abfc3
编写于
8月 09, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code clean up and renaming
Reduce one level of inheritence.
上级
66be5326
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
152 addition
and
180 deletion
+152
-180
doc/fluid/design/ir/overview.md
doc/fluid/design/ir/overview.md
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-4
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
...fluid/framework/details/multi_devices_graph_check_pass.cc
+2
-2
paddle/fluid/framework/details/multi_devices_graph_check_pass.h
.../fluid/framework/details/multi_devices_graph_check_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+88
-2
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
...fluid/framework/details/multi_devices_graph_print_pass.cc
+2
-2
paddle/fluid/framework/details/multi_devices_graph_print_pass.h
.../fluid/framework/details/multi_devices_graph_print_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_helper.cc
paddle/fluid/framework/details/multi_devices_helper.cc
+20
-0
paddle/fluid/framework/details/multi_devices_helper.h
paddle/fluid/framework/details/multi_devices_helper.h
+0
-27
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+0
-107
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+25
-25
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-1
未找到文件。
doc/fluid/design/ir/
draft
.md
→
doc/fluid/design/ir/
overview
.md
浏览文件 @
626abfc3
...
...
@@ -177,8 +177,8 @@ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device
s
_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device
s
_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
626abfc3
...
...
@@ -100,7 +100,7 @@ else()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_
builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_
pass multi_devices_graph_print_pass multi_devices_graph_check_pass
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
626abfc3
...
...
@@ -5,9 +5,9 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_build
er.cc DEPS graph graph_helper
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_build
er
)
cc_library
(
ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_build
er
)
cc_library
(
multi_devices_helper SRCS multi_devices_help
er.cc DEPS graph graph_helper
)
cc_library
(
multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_help
er
)
cc_library
(
multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_help
er
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
...
...
@@ -28,7 +28,7 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
multi_devices_graph_
builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_build
er computation_op_handle
cc_library
(
multi_devices_graph_
pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_help
er computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
...
...
paddle/fluid/framework/details/
ssa_graph_checker
.cc
→
paddle/fluid/framework/details/
multi_devices_graph_check_pass
.cc
浏览文件 @
626abfc3
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
ssa_graph_checker
.h"
#include "paddle/fluid/framework/details/
multi_devices_graph_check_pass
.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
...
...
@@ -86,7 +86,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
REGISTER_PASS
(
multi_device
s
_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
...
...
paddle/fluid/framework/details/
ssa_graph_checker
.h
→
paddle/fluid/framework/details/
multi_devices_graph_check_pass
.h
浏览文件 @
626abfc3
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include <string>
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithChecker
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
...
...
paddle/fluid/framework/details/multi_devices_graph_
builder
.cc
→
paddle/fluid/framework/details/multi_devices_graph_
pass
.cc
浏览文件 @
626abfc3
...
...
@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
builder
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
pass
.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...
...
@@ -33,6 +33,92 @@
namespace
paddle
{
namespace
framework
{
namespace
details
{
namespace
{
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
bool
has_dep
=
false
;
for
(
auto
*
r_out
:
read_op
->
Outputs
())
{
for
(
auto
*
w_in
:
write_op
->
Inputs
())
{
if
(
r_out
->
Node
()
==
w_in
->
Node
())
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
node
->
Var
())
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
node
->
Name
(),
place
);
}
else
{
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
}
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
}
return
var
;
}
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
// namespace
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
...
...
@@ -751,7 +837,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
REGISTER_PASS
(
multi_device
s
_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
...
...
paddle/fluid/framework/details/multi_devices_graph_
builder
.h
→
paddle/fluid/framework/details/multi_devices_graph_
pass
.h
浏览文件 @
626abfc3
...
...
@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
...
...
@@ -30,7 +30,7 @@ namespace framework {
class
Scope
;
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
class
MultiDevSSAGraphBuilder
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
...
...
paddle/fluid/framework/details/
ssa_graph_printer
.cc
→
paddle/fluid/framework/details/
multi_devices_graph_print_pass
.cc
浏览文件 @
626abfc3
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
ssa_graph_printer
.h"
#include "paddle/fluid/framework/details/
multi_devices_graph_print_pass
.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
...
...
@@ -82,5 +82,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
REGISTER_PASS
(
multi_device
s
_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/
ssa_graph_printer
.h
→
paddle/fluid/framework/details/
multi_devices_graph_print_pass
.h
浏览文件 @
626abfc3
...
...
@@ -18,7 +18,7 @@
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -35,7 +35,7 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
...
...
paddle/fluid/framework/details/multi_devices_helper.cc
0 → 100644
浏览文件 @
626abfc3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/
ssa_graph_build
er.h
→
paddle/fluid/framework/details/
multi_devices_help
er.h
浏览文件 @
626abfc3
...
...
@@ -52,33 +52,6 @@ const char kGraphOps[] = "ops";
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
*/
static
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
);
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder.cc
已删除
100644 → 0
浏览文件 @
66be5326
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <utility>
namespace
paddle
{
namespace
framework
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
bool
has_dep
=
false
;
for
(
auto
*
r_out
:
read_op
->
Outputs
())
{
for
(
auto
*
w_in
:
write_op
->
Inputs
())
{
if
(
r_out
->
Node
()
==
w_in
->
Node
())
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
node
->
Var
())
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
node
->
Name
(),
place
);
}
else
{
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
}
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
}
return
var
;
}
void
SSAGraphBuilder
::
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
626abfc3
...
...
@@ -14,7 +14,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
626abfc3
...
...
@@ -25,9 +25,9 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -57,39 +57,39 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
}
// Convert graph to run on multi-devices.
auto
multi_device_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_pass"
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
auto
multi_device
s
_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_pass"
);
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device
s
_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
multi_device_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
multi_device
s
_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
graph
=
multi_device_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
multi_device
s
_pass
->
Apply
(
std
::
move
(
graph
));
// Apply a graph print pass to record a graph with device info.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
multi_device_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_print_pass"
);
multi_device_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
auto
multi_device
s
_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_print_pass"
);
multi_device
s
_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy
.
debug_graphviz_path_
);
multi_device_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
multi_device
s
_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
details
::
GraphvizSSAGraphPrinter
);
graph
=
multi_device_print_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
multi_device
s
_print_pass
->
Apply
(
std
::
move
(
graph
));
}
// Verify that the graph is correct for multi-device executor.
auto
multi_device_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_check_pass"
);
graph
=
multi_device_check_pass
->
Apply
(
std
::
move
(
graph
));
auto
multi_device
s
_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_check_pass"
);
graph
=
multi_device
s
_check_pass
->
Apply
(
std
::
move
(
graph
));
return
graph
;
}
...
...
@@ -354,6 +354,6 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
USE_PASS
(
multi_device
s
_pass
);
USE_PASS
(
multi_device
s
_check_pass
);
USE_PASS
(
multi_device
s
_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
626abfc3
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
builder
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
pass
.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录