Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6b45c5a1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6b45c5a1
编写于
8月 13, 2018
作者:
X
Xin Pan
提交者:
GitHub
8月 13, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12605 from panyx0718/ir
code clean up and renaming
上级
24283d9f
626abfc3
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
152 addition
and
180 deletion
+152
-180
doc/fluid/design/ir/overview.md
doc/fluid/design/ir/overview.md
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-4
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc
...fluid/framework/details/multi_devices_graph_check_pass.cc
+2
-2
paddle/fluid/framework/details/multi_devices_graph_check_pass.h
.../fluid/framework/details/multi_devices_graph_check_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+88
-2
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_graph_print_pass.cc
...fluid/framework/details/multi_devices_graph_print_pass.cc
+2
-2
paddle/fluid/framework/details/multi_devices_graph_print_pass.h
.../fluid/framework/details/multi_devices_graph_print_pass.h
+2
-2
paddle/fluid/framework/details/multi_devices_helper.cc
paddle/fluid/framework/details/multi_devices_helper.cc
+20
-0
paddle/fluid/framework/details/multi_devices_helper.h
paddle/fluid/framework/details/multi_devices_helper.h
+0
-27
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+0
-107
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+25
-25
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-1
未找到文件。
doc/fluid/design/ir/
draft
.md
→
doc/fluid/design/ir/
overview
.md
浏览文件 @
6b45c5a1
...
...
@@ -177,8 +177,8 @@ graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device
s
_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device
s
_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
6b45c5a1
...
...
@@ -100,7 +100,7 @@ else()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_
builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_
pass multi_devices_graph_print_pass multi_devices_graph_check_pass
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
6b45c5a1
...
...
@@ -5,9 +5,9 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_build
er.cc DEPS graph graph_helper
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_build
er
)
cc_library
(
ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_build
er
)
cc_library
(
multi_devices_helper SRCS multi_devices_help
er.cc DEPS graph graph_helper
)
cc_library
(
multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_help
er
)
cc_library
(
multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_help
er
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
...
...
@@ -28,7 +28,7 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
multi_devices_graph_
builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_build
er computation_op_handle
cc_library
(
multi_devices_graph_
pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_help
er computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
...
...
paddle/fluid/framework/details/
ssa_graph_checker
.cc
→
paddle/fluid/framework/details/
multi_devices_graph_check_pass
.cc
浏览文件 @
6b45c5a1
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
ssa_graph_checker
.h"
#include "paddle/fluid/framework/details/
multi_devices_graph_check_pass
.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
...
...
@@ -86,7 +86,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
REGISTER_PASS
(
multi_device
s
_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
...
...
paddle/fluid/framework/details/
ssa_graph_checker
.h
→
paddle/fluid/framework/details/
multi_devices_graph_check_pass
.h
浏览文件 @
6b45c5a1
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include <string>
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithChecker
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
...
...
paddle/fluid/framework/details/multi_devices_graph_
builder
.cc
→
paddle/fluid/framework/details/multi_devices_graph_
pass
.cc
浏览文件 @
6b45c5a1
...
...
@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
builder
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
pass
.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...
...
@@ -33,6 +33,92 @@
namespace
paddle
{
namespace
framework
{
namespace
details
{
namespace
{
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
bool
has_dep
=
false
;
for
(
auto
*
r_out
:
read_op
->
Outputs
())
{
for
(
auto
*
w_in
:
write_op
->
Inputs
())
{
if
(
r_out
->
Node
()
==
w_in
->
Node
())
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
node
->
Var
())
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
node
->
Name
(),
place
);
}
else
{
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
}
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
}
return
var
;
}
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
// namespace
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
...
...
@@ -751,7 +837,7 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
REGISTER_PASS
(
multi_device
s
_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
...
...
paddle/fluid/framework/details/multi_devices_graph_
builder
.h
→
paddle/fluid/framework/details/multi_devices_graph_
pass
.h
浏览文件 @
6b45c5a1
...
...
@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
...
...
@@ -30,7 +30,7 @@ namespace framework {
class
Scope
;
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
class
MultiDevSSAGraphBuilder
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
...
...
paddle/fluid/framework/details/
ssa_graph_printer
.cc
→
paddle/fluid/framework/details/
multi_devices_graph_print_pass
.cc
浏览文件 @
6b45c5a1
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
ssa_graph_printer
.h"
#include "paddle/fluid/framework/details/
multi_devices_graph_print_pass
.h"
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
...
...
@@ -82,5 +82,5 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
REGISTER_PASS
(
multi_device
s
_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/
ssa_graph_printer
.h
→
paddle/fluid/framework/details/
multi_devices_graph_print_pass
.h
浏览文件 @
6b45c5a1
...
...
@@ -18,7 +18,7 @@
#include <iosfwd>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -35,7 +35,7 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
...
...
paddle/fluid/framework/details/multi_devices_helper.cc
0 → 100644
浏览文件 @
6b45c5a1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/
ssa_graph_build
er.h
→
paddle/fluid/framework/details/
multi_devices_help
er.h
浏览文件 @
6b45c5a1
...
...
@@ -52,33 +52,6 @@ const char kGraphOps[] = "ops";
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
*/
static
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
);
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder.cc
已删除
100644 → 0
浏览文件 @
24283d9f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include <utility>
namespace
paddle
{
namespace
framework
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
}
auto
it_new
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
bool
has_dep
=
false
;
for
(
auto
*
r_out
:
read_op
->
Outputs
())
{
for
(
auto
*
w_in
:
write_op
->
Inputs
())
{
if
(
r_out
->
Node
()
==
w_in
->
Node
())
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
node
->
Var
())
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
node
->
Name
(),
place
);
}
else
{
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
}
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
}
return
var
;
}
void
SSAGraphBuilder
::
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
6b45c5a1
...
...
@@ -14,7 +14,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/
ssa_graph_build
er.h"
#include "paddle/fluid/framework/details/
multi_devices_help
er.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
6b45c5a1
...
...
@@ -25,9 +25,9 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -57,39 +57,39 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
}
// Convert graph to run on multi-devices.
auto
multi_device_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_pass"
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
auto
multi_device
s
_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_pass"
);
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
multi_device
s
_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
multi_device
s
_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
multi_device_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
multi_device
s
_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
graph
=
multi_device_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
multi_device
s
_pass
->
Apply
(
std
::
move
(
graph
));
// Apply a graph print pass to record a graph with device info.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
multi_device_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_print_pass"
);
multi_device_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
auto
multi_device
s
_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_print_pass"
);
multi_device
s
_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy
.
debug_graphviz_path_
);
multi_device_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
multi_device
s
_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
details
::
GraphvizSSAGraphPrinter
);
graph
=
multi_device_print_pass
->
Apply
(
std
::
move
(
graph
));
graph
=
multi_device
s
_print_pass
->
Apply
(
std
::
move
(
graph
));
}
// Verify that the graph is correct for multi-device executor.
auto
multi_device_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_check_pass"
);
graph
=
multi_device_check_pass
->
Apply
(
std
::
move
(
graph
));
auto
multi_device
s
_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device
s
_check_pass"
);
graph
=
multi_device
s
_check_pass
->
Apply
(
std
::
move
(
graph
));
return
graph
;
}
...
...
@@ -354,6 +354,6 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
USE_PASS
(
multi_device
s
_pass
);
USE_PASS
(
multi_device
s
_check_pass
);
USE_PASS
(
multi_device
s
_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
6b45c5a1
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
builder
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_
pass
.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录