PaddlePaddle / Paddle

Commit 4140fe11 (unverified)
Authored Jul 27, 2019 by chengduo; committed by GitHub on Jul 27, 2019
Open fuse optimization ops (#18741)
* open fuse optimization ops test=develop
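For orientation, this is how the switch this commit opens surfaces in user code. A minimal sketch of the Python side (the fluid 1.5-era API is assumed; main_program and loss stand for a program and loss variable built elsewhere), mirroring the updated unit test at the bottom of this diff:

import paddle.fluid as fluid

# BuildStrategy controls which IR passes ParallelExecutorPassBuilder appends.
# After this commit, fuse_all_optimizer_ops defaults to True (see
# build_strategy.h below), so the fuse_adam/fuse_sgd/fuse_momentum passes run
# unless the flag is disabled or a conflicting option force-disables it.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = True

# The passes are applied when the program is compiled for data-parallel
# execution; main_program and loss are assumed to be defined elsewhere.
compiled_prog = fluid.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)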
Parent: 582cc297

Showing 14 changed files with 202 additions and 198 deletions (+202 −198)
paddle/fluid/framework/details/build_strategy.cc (+153 −176)
paddle/fluid/framework/details/build_strategy.h (+1 −1)
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc (+1 −3)
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc (+1 −3)
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc (+1 −3)
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc (+1 −3)
paddle/fluid/framework/ir/graph_printer.h (+1 −1)
paddle/fluid/framework/ir/graph_viz_pass.cc (+3 −4)
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc (+7 −2)
paddle/fluid/framework/ir/pass.cc (+5 −0)
paddle/fluid/framework/ir/pass.h (+8 −0)
paddle/fluid/framework/ir/pass_builder.cc (+3 −0)
paddle/fluid/framework/ir/sync_batch_norm_pass.cc (+1 −1)
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py (+16 −1)
paddle/fluid/framework/details/build_strategy.cc

...
@@ -21,12 +21,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_printer.h"
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
-#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"

 DECLARE_bool(use_mkldnn);
...
@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
 public:
  explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
      : ir::PassBuilder(), strategy_(strategy) {
-    // Add a graph viz pass to record a graph.
-    if (!strategy_.debug_graphviz_path_.empty()) {
-      VLOG(1) << "Add graph_viz_pass";
-      auto viz_pass = AppendPass("graph_viz_pass");
-      const std::string graph_path = string::Sprintf(
-          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
-      viz_pass->Set<std::string>("graph_viz_path",
-                                 new std::string(graph_path));
-    }
+    ResolveOptionConfliction();

-    // Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
-    VLOG(1) << "Add record_skip_memory_opt_vars_pass";
+    AppendPrintGraphPass("graph_viz_pass", "_original_graph");
+    // Note(zcd): record_skip_memory_opt_vars_pass should
+    // be the first pass.
     AppendPass("record_skip_memory_opt_vars_pass");
+    AppendPassWithCheck(strategy_.enable_sequential_execution_,
+                        "sequential_execution_pass");
+    AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");

-#ifdef PADDLE_WITH_MKLDNN
-    if (FLAGS_use_mkldnn) {
-      VLOG(1) << "Add mkldnn_placement_pass";
-      AppendPass("mkldnn_placement_pass");
-    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
-      LOG(WARNING)
-          << "mkldnn_enabled_op_types specify the operator type list to "
-             "use MKLDNN acceleration. It is null in default, means "
-             "that all the operators supported by MKLDNN will be "
-             "accelerated. And it should not be set when "
-             "FLAGS_use_mkldnn=false.";
-    }
-#else
-    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
-                   "Please compile with MKLDNN first to use MKLDNN");
-#endif
+    AppendOpFusePasses();
+    AppendPrintGraphPass("graph_viz_pass", "_fused_graph");

-    if (strategy_.enable_sequential_execution_) {
-      VLOG(1) << "Add sequential_execution_pass";
-      AppendPass("sequential_execution_pass");
-    }
+    // TODO(dev-paddle): memory optimize pass should be placed last.
+    AppendMemoryOptimizePasses();

-    // Add op fusion.
-    if (strategy.sync_batch_norm_) {
-      AppendPass("sync_batch_norm_pass");
-    }
+    AppendMultiDevPass();
+    AppendMultiGraphOptPasses();

-    // Add op fusion.
-    if (strategy.fuse_relu_depthwise_conv_) {
-      VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
-      AppendPass("fuse_relu_depthwise_conv_pass");
-    }
+    AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
+    // runtime_context_cache pass should be the last pass to enable the attr of
+    // all original and fused operators. But no operators can be enabled this
+    // attr if putting it after MultiDevPass.
+    AppendPassWithCheck(strategy_.cache_runtime_context_,
+                        "runtime_context_cache_pass");
+    AppendPassWithCheck(strategy_.remove_unnecessary_lock_,
+                        "modify_op_lock_and_record_event_pass");
+    // Note: This pass is used to check whether the multi_device_graph is right.
+    AppendPass("multi_devices_check_pass");
+    SetCollectiveContext();
+  }

-    // TODO(zjl): refactor MemoryOptimizePass to fit
-    // new strategy, which does not need to set
-    // var.persistable = True
-    if (strategy_.use_legacy_memory_optimize_strategy_) {
-      if (strategy_.enable_inplace_) {
-        VLOG(5) << "Add inplace_pass";
-        AppendPass("inplace_pass");
-      }
-    }
+  void ResolveOptionConfliction() {
+    // Specifies the restrictions between different pass.
+    if (strategy_.enable_parallel_graph_) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops doesn't works under "
+             "parallel_graph.";
+      strategy_.fuse_all_optimizer_ops_ = false;
+    }
+    if (strategy_.is_distribution_) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops only works under "
+             "Non-distributed mode.";
+      strategy_.fuse_all_optimizer_ops_ = false;
+    }
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+          << "Currently, fuse_all_optimizer_ops only works under AllReduce "
+             "mode.";
+      strategy_.fuse_all_optimizer_ops_ = false;
+      VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
+          << "fuse_all_optimizer_ops only work in Reducer mode.";
+      strategy_.fuse_all_reduce_ops_ = false;
+    }
+  }

-    if (strategy_.fuse_elewise_add_act_ops_) {
-      VLOG(1) << "Add fuse_elewise_add_act_pass";
-      AppendPass("fuse_elewise_add_act_pass");
-    }
+  void AppendMultiGraphOptPasses() {
+    // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
+    // first, if the number is zero, fuse_all_reduce_ops will do nothing.
+    AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
+                        "fuse_all_reduce_op_pass");
+    AppendPrintGraphPass("multi_devices_print_pass", "_multi_devices_graph");

-    // for single card training, fuse_all_reduce_ops is unnecessary.
-    // coalesce_grad_tensor_pass should be before of MultiDevPass.
-    if (strategy_.fuse_all_reduce_ops_) {
-      VLOG(1) << "Add coalesce_grad_tensor_pass";
-      AppendPass("coalesce_grad_tensor_pass");
-    }
+    // experimental shows that the program will be faster if append
+    // all_reduce_deps_pass here.
+    bool append_all_reduce_deps_pass =
+        !strategy_.enable_parallel_graph_ &&
+        (SeqOnlyAllReduceOps(strategy_) ||
+         strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce);
+    AppendPassWithCheck(append_all_reduce_deps_pass, "all_reduce_deps_pass");

-    // Fuse all the optimization operators.
-    if (strategy_.is_distribution_) {
-      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
-                 "Non-distributed mode.";
-      strategy_.fuse_all_optimizer_ops_ = false;
-    }
-    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
-        strategy_.is_distribution_) {
-      VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
-                 "mode.";
-      strategy_.fuse_all_optimizer_ops_ = false;
-    }
-    if (strategy_.fuse_all_optimizer_ops_) {
-      // NOTE: fuse_all_xx_ops will count the number of xx operator first,
-      // if the number is zero, fuse_all_reduce_ops will do nothing.
-      // Currently, only one type of optimization algorithm can be fused.
-      VLOG(1) << "Add fuse_adam_op_pass";
-      AppendPass("fuse_adam_op_pass");
-      VLOG(1) << "Add fuse_sgd_op_pass";
-      AppendPass("fuse_sgd_op_pass");
-      VLOG(1) << "Add fuse_momentum_op_pass";
-      AppendPass("fuse_momentum_op_pass");
-    }
+    bool append_backward_optimizer_op_deps_pass =
+        strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
+        !strategy_.is_distribution_ &&
+        strategy_.enable_backward_optimizer_op_deps_;
+    AppendPassWithCheck(append_backward_optimizer_op_deps_pass,
+                        "backward_optimizer_op_deps_pass");
+  }

-    // Add a graph viz pass to record a graph.
-    if (!strategy.debug_graphviz_path_.empty()) {
-      auto viz_pass = AppendPass("graph_viz_pass");
-      const std::string graph_path = string::Sprintf(
-          "%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
-      viz_pass->Set<std::string>("graph_viz_path",
-                                 new std::string(graph_path));
-    }
+  void AppendOpFusePasses() {
+    AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
+                        "fuse_relu_depthwise_conv_pass");
+    AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
+                        "fuse_elewise_add_act_pass");
+    // for single card training, fuse_all_reduce_ops is unnecessary.
+    // coalesce_grad_tensor_pass should be before of MultiDevPass.
+    AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
+                        "coalesce_grad_tensor_pass");
+    // Fuse all the optimization operators.
+    // NOTE: fuse_all_xx_ops will count the number of xx operator first,
+    // if the number is zero, fuse_all_reduce_ops will do nothing.
+    // Currently, only one type of optimization algorithm can be fused.
+    if (strategy_.fuse_all_optimizer_ops_) {
+      AppendPass("fuse_adam_op_pass");
+      AppendPass("fuse_sgd_op_pass");
+      AppendPass("fuse_momentum_op_pass");
+    }
+  }

-    CollectiveContext *context = CollectiveContext::GetInstance();
-    context->endpoints_ = strategy_.trainers_endpoints_;
-    context->trainer_id_ = strategy_.trainer_id_;
-    PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
-    if (strategy_.trainer_id_ > 0 &&
-        strategy_.trainers_endpoints_.size() > 0) {
-      PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
-                         strategy_.trainers_endpoints_.size(),
-                     "trainer_id_ < endpoints_ size");
-    }
-    VLOG(1) << "CollectiveContext:" << context->String();
+  void AppendMemoryOptimizePasses() {
+    // Append Memory Optimize Pass
+    // TODO(zjl): refactor MemoryOptimizePass to fit
+    // new strategy, which does not need to set
+    // var.persistable = True
+    if (strategy_.use_legacy_memory_optimize_strategy_) {
+      AppendPassWithCheck(strategy_.enable_inplace_, "inplace_pass");
+    }

     // NOTE(dzh): memory optimize should be a runtime pass.
     // However, after multi_devices_pass, VarHandle, OpHandle is
     // the de-fact IR, any reuse on Graph is meaningless.
     // A side-effect of that, memory optimize cannot forsee the fetched vars
     // , so fetchlist should be set persistable before call the Run interface.
     if (strategy_.use_legacy_memory_optimize_strategy_) {
-      if (strategy_.memory_optimize_) {
-        VLOG(5) << "Add memory_optimize_pass";
-        AppendPass("memory_optimize_pass");
-      }
+      AppendPassWithCheck(strategy_.memory_optimize_, "memory_optimize_pass");
     }
+  }

-    // runtime_context_cache pass should be the last pass to enable the attr of
-    // all original and fused operators. But no operators can be enabled this
-    // attr if putting it after MultiDevPass.
-    if (strategy_.cache_runtime_context_) {
-      VLOG(1) << "Add runtime_context_cache_pass";
-      AppendPass("runtime_context_cache_pass");
-    }
-
-    AppendMultiDevPass(strategy_);
-
-    if (strategy_.fuse_all_reduce_ops_) {
-      // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
-      // first, if the number is zero, fuse_all_reduce_ops will do nothing.
-      VLOG(1) << "Add fuse_all_reduce_op_pass";
-      AppendPass("fuse_all_reduce_op_pass");
-    }
-
-    // Add a graph print pass to record a graph with device info.
-    if (!strategy_.debug_graphviz_path_.empty()) {
-      VLOG(1) << "Add multi_devices_print_pass";
-      auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
-      const std::string graph_path =
-          string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
-                          "_multi_devices_graph");
-      multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
-                                                 new std::string(graph_path));
-      multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
-          "graph_printer", new ir::GraphvizSSAGraphPrinter);
-    }
-
-    // experimental shows that the program will be faster if append
-    // all_reduce_deps_pass here.
-    if (!strategy_.enable_parallel_graph_ &&
-        (SeqOnlyAllReduceOps(strategy_) ||
-         strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
-      VLOG(1) << "Add all_reduce_deps_pass";
-      AppendPass("all_reduce_deps_pass");
-    }
-
-    if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
-        !strategy_.is_distribution_ &&
-        strategy_.enable_backward_optimizer_op_deps_) {
-      VLOG(1) << "Add backward_op_deps_pass";
-      AppendPass("backward_optimizer_op_deps_pass");
-    }
-
-    if (strategy_.remove_unnecessary_lock_) {
-      VLOG(1) << "Add modify_op_lock_and_record_event_pass";
-      AppendPass("modify_op_lock_and_record_event_pass");
-    }
-
-    // Verify that the graph is correct for multi-device executor.
-    VLOG(1) << "Add multi_devices_check_pass";
-    AppendPass("multi_devices_check_pass");
+  void SetCollectiveContext() const {
+    CollectiveContext *context = CollectiveContext::GetInstance();
+    context->endpoints_ = strategy_.trainers_endpoints_;
+    context->trainer_id_ = strategy_.trainer_id_;
+    PADDLE_ENFORCE_GE(strategy_.trainer_id_, 0, "trainer_id_ >= 0");
+    if (strategy_.trainer_id_ > 0 &&
+        strategy_.trainers_endpoints_.size() > 0) {
+      PADDLE_ENFORCE_LT(static_cast<size_t>(strategy_.trainer_id_),
+                        strategy_.trainers_endpoints_.size(),
+                        "trainer_id_ < endpoints_ size");
+    }
+    VLOG(1) << "CollectiveContext:" << context->String();
   }

   // Convert graph to run on multi-devices.
-  void AppendMultiDevPass(const BuildStrategy &strategy) {
+  void AppendMultiDevPass() {
     ir::Pass *multi_devices_pass = nullptr;
+
     if (strategy_.async_mode_) {
       VLOG(1) << "Add async_multi_devices_pass";
       multi_devices_pass = AppendPass("async_multi_devices_pass").get();
     } else if (strategy_.is_distribution_) {
       VLOG(1)
           << "Add dist_multi_devices_pass, multi device parameter server mode";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
-      if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
-        VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
-        multi_devices_pass =
-            AppendPass("all_reduce_mode_multi_devices_pass").get();
-      } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-        VLOG(1) << "Add reduce_mode_multi_devices_pass";
-        multi_devices_pass =
-            AppendPass("reduce_mode_multi_devices_pass").get();
-      } else {
-        PADDLE_THROW("Unknown reduce strategy.");
+      switch (strategy_.reduce_) {
+        case BuildStrategy::ReduceStrategy::kAllReduce:
+          multi_devices_pass =
+              AppendPass("all_reduce_mode_multi_devices_pass").get();
+          break;
+        case BuildStrategy::ReduceStrategy::kReduce:
+          multi_devices_pass =
+              AppendPass("reduce_mode_multi_devices_pass").get();
+          break;
+        default:
+          PADDLE_THROW("Unknown reduce strategy.");
       }
     }
     multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
                                                          &strategy_);
   }

+  void AppendPrintGraphPass(const std::string &pass_name,
+                            const std::string &debug_file_suffix) {
+    if (!strategy_.debug_graphviz_path_.empty()) {
+      auto viz_pass = AppendPass(pass_name);
+      const std::string graph_path = string::Sprintf(
+          "%s%s", strategy_.debug_graphviz_path_.c_str(), debug_file_suffix);
+      viz_pass->Set<std::string>(ir::kGraphvizPath,
+                                 new std::string(graph_path));
+    }
+  }
+
+  void AppendPassWithCheck(bool append_pass, const std::string &pass_name) {
+    if (append_pass) {
+      AppendPass(pass_name);
+    }
+  }
+
+  void AppendPassToSetMkldnnAttr(const std::string &pass_name) {
+#ifdef PADDLE_WITH_MKLDNN
+    if (FLAGS_use_mkldnn) {
+      AppendPass(pass_name);
+    } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
+      LOG(WARNING)
+          << "mkldnn_enabled_op_types specify the operator type list to "
+             "use MKLDNN acceleration. It is null in default, means "
+             "that all the operators supported by MKLDNN will be "
+             "accelerated. And it should not be set when "
+             "FLAGS_use_mkldnn=false.";
+    }
+#else
+    PADDLE_ENFORCE(!FLAGS_use_mkldnn,
+                   "Please compile with MKLDNN first to use MKLDNN");
+#endif
+  }
+
  private:
   BuildStrategy strategy_;
 };
...
@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       pass->Erase(kNCCLCtxs);
       pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
 #endif
-    } else if (pass->Type() == "fuse_all_reduce_op_pass") {
+    } else if (pass->Type() == "coalesce_grad_tensor_pass" ||
+               pass->Type() == "fuse_adam_op_pass" ||
+               pass->Type() == "fuse_sgd_op_pass" ||
+               pass->Type() == "fuse_momentum_op_pass" ||
+               pass->Type() == "fuse_all_reduce_op_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
       pass->Erase(kLocalScopes);
       pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                     &local_scopes);
+      if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
-      pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
-      pass->Erase(kUseHierarchicalAllReduce);
-      pass->Set<bool>(kUseHierarchicalAllReduce,
-                      new bool(use_hierarchical_allreduce_));
+        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
+        pass->Erase(kNCCLCtxs);
+        pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
+        pass->Erase(kUseHierarchicalAllReduce);
+        pass->Set<bool>(kUseHierarchicalAllReduce,
+                        new bool(use_hierarchical_allreduce_));
 #endif
+      }
-    } else if (pass->Type() == "coalesce_grad_tensor_pass") {
-      pass->Erase(kPlaces);
-      pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
     }
...
paddle/fluid/framework/details/build_strategy.h
...
@@ -88,7 +88,7 @@ struct BuildStrategy {
   bool fuse_elewise_add_act_ops_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
   // should not be sparse types
-  bool fuse_all_optimizer_ops_{false};
+  bool fuse_all_optimizer_ops_{true};
   bool fuse_all_reduce_ops_{false};
   // fuse_relu_depthwise_conv can fuse the `relu ->
   // depthwise_conv`
...
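Since this flips the default to true, a job that hits one of the documented restrictions (sparse gradients, Reduce mode, distributed or parallel-graph execution) can still opt out from Python. A minimal sketch, assuming the same fluid API exercised by the updated test at the bottom of this diff:

import paddle.fluid as fluid

# Opt out of the new default, e.g. when gradients may be sparse; note that
# ResolveOptionConfliction() in build_strategy.cc above also force-disables
# the fusion under parallel_graph, distributed, and Reduce modes.
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = False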
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
...
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
 }  // namespace paddle

 REGISTER_PASS(coalesce_grad_tensor_pass,
-              paddle::framework::ir::CoalesceGradTensorPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+              paddle::framework::ir::CoalesceGradTensorPass);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
...
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
...
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
...
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
-    .RequirePassAttr(paddle::framework::details::kPlaces)
-    .RequirePassAttr(paddle::framework::details::kLocalScopes);
+REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass);
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h → paddle/fluid/framework/ir/graph_printer.h

...
@@ -26,7 +26,7 @@ namespace paddle {
 namespace framework {
 namespace ir {

-constexpr char kGraphvizPath[] = "debug_graphviz_path";
+constexpr char kGraphvizPath[] = "graph_viz_path";

 class SSAGraphPrinter {
  public:
...
paddle/fluid/framework/ir/graph_viz_pass.cc
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <algorithm>
 #include <unordered_map>
 #include <unordered_set>
+#include "paddle/fluid/framework/ir/graph_printer.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/string/printf.h"
...
@@ -25,8 +26,6 @@ namespace framework {
 namespace ir {

 using inference::analysis::Dot;
 namespace {
-const char kGraphVizPath[] = "graph_viz_path";
-
 std::string FormatName(const Node* node) {
   if (!node->IsOp() || !node->Op() ||
       !node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
...
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
 }  // namespace

 void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
-  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
+  const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
   VLOG(3) << "draw IR graph viz to " << graph_viz_path;
   std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
   PADDLE_ENFORCE(fout->good());
...
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
 }  // namespace paddle

 REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
-    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
+    .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
...
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_printer.h"

 namespace paddle {
 namespace framework {
...
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(Get<std::string>(kGraphvizPath)));
     PADDLE_ENFORCE(fout->good());
-    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
+    if (Has("graph_printer")) {
+      Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
+    } else {
+      GraphvizSSAGraphPrinter printer;
+      printer.Print(*graph, *fout);
+    }
   }
 };
...
paddle/fluid/framework/ir/pass.cc
...
@@ -24,6 +24,7 @@ namespace framework {
 namespace ir {

 Graph* Pass::Apply(Graph* graph) const {
+  CheckPrevPass();
   PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
   for (const std::string& attr : required_pass_attrs_) {
     PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
...
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
   PADDLE_ENFORCE(VarDescIsConsistency(*graph),
                  "The VarDescs of persistable variable are not consistency.");
   applied_ = true;
+  if (!graph->Has(kPassRecorder)) {
+    graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
+  }
+  graph->Get<PassRecorder>(kPassRecorder).insert(Type());
   return graph;
 }
...
paddle/fluid/framework/ir/pass.h
...
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/program_desc.h"
...
@@ -31,6 +32,9 @@ namespace ir {
 template <typename PassType>
 struct PassRegistrar;

+typedef std::unordered_set<std::string> PassRecorder;
+constexpr char kPassRecorder[] = "pass_recorder";
+
 class Pass {
  public:
   Pass() = default;
...
@@ -104,6 +108,10 @@ class Pass {
     LOG(FATAL) << "Calling virtual Pass not implemented.";
   }

+  // Some Pass must be placed before this Pass, and some
+  // Pass must be placed after this Pass.
+  virtual void CheckPrevPass() const {}
+
  private:
   template <typename PassType>
   friend struct PassRegistrar;
...
paddle/fluid/framework/ir/pass_builder.cc
...
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/ir/pass_builder.h"
+#include <memory>
+#include <utility>

 namespace paddle {
 namespace framework {
 namespace ir {

 std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
+  VLOG(3) << "Append " << pass_type;
   auto pass = ir::PassRegistry::Instance().Get(pass_type);
   passes_.emplace_back(pass.release());
   return passes_.back();
...
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
...
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
   void ApplyImpl(ir::Graph *graph) const override {
     VLOG(3) << "Use synchronous batch norm";
     for (const Node *n : graph->Nodes()) {
-      if (n->IsOp()) {
+      if (n->IsOp() && n->Op()) {
         auto *op = n->Op();
         if (op->Type() == "batch_norm") {
           op->SetType("sync_batch_norm");
...
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
...
@@ -32,6 +32,7 @@ feed_dict = {
 class InplaceTestBase(unittest.TestCase):
     def initParameter(self):
         self.use_cuda = True
+        self.fuse_all_optimizer_ops = False

     def setUp(self):
         self.initParameter()
...
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
             self.device_count = fluid.core.get_cuda_device_count()
         else:
             self.device_count = 4
-
         assert batch_size % self.device_count == 0

     def build_program_and_scope(self):
...
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
                     loss_name=loss.name,
                     build_strategy=build_strategy,
...
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.memory_optimize = memory_optimize
                 build_strategy.enable_inplace = enable_inplace
+                build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
                 compiled_program = fluid.CompiledProgram(prog).with_data_parallel(
                     loss_name=loss.name,
...
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
 class CPUInplaceTest(InplaceTestBase):
     def initParameter(self):
         self.use_cuda = False
+        self.fuse_all_optimizer_ops = False
+
+
+class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = True
+        self.fuse_all_optimizer_ops = True
+
+
+class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
+    def initParameter(self):
+        self.use_cuda = True
+        self.fuse_all_optimizer_ops = True

 if __name__ == '__main__':
...
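The new subclasses reuse every test in InplaceTestBase and only flip the flags in initParameter. A minimal sketch of running just the fusion variants (standard unittest loading; the module is assumed importable from the unittests directory):

import unittest

from test_buffer_shared_memory_reuse_pass import (
    CPUInplaceTestWithFuseOptimizationOps,
    CUDAInplaceTestWithFuseOptimizationOps)

if __name__ == '__main__':
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    # Each class inherits its test methods from InplaceTestBase; only the
    # initParameter flags (use_cuda / fuse_all_optimizer_ops) differ.
    suite.addTests(loader.loadTestsFromTestCase(
        CPUInplaceTestWithFuseOptimizationOps))
    suite.addTests(loader.loadTestsFromTestCase(
        CUDAInplaceTestWithFuseOptimizationOps))
    unittest.TextTestRunner(verbosity=2).run(suite)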