Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4140fe11
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4140fe11
编写于
7月 27, 2019
作者:
C
chengduo
提交者:
GitHub
7月 27, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Open fuse optimization ops (#18741)
* open fuse optimization ops test=develop
上级
582cc297
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
202 addition
and
198 deletion
+202
-198
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+153
-176
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+1
-1
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
...framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
...ework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
.../framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
+1
-3
paddle/fluid/framework/ir/graph_printer.h
paddle/fluid/framework/ir/graph_printer.h
+1
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+3
-4
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
...ulti_devices_graph_pass/multi_devices_graph_print_pass.cc
+7
-2
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+5
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+8
-0
paddle/fluid/framework/ir/pass_builder.cc
paddle/fluid/framework/ir/pass_builder.cc
+3
-0
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
+1
-1
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
...d/tests/unittests/test_buffer_shared_memory_reuse_pass.py
+16
-1
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
4140fe11
...
@@ -21,12 +21,12 @@ limitations under the License. */
...
@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool
(
use_mkldnn
);
DECLARE_bool
(
use_mkldnn
);
...
@@ -48,205 +48,153 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -48,205 +48,153 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
public:
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
// Add a graph viz pass to record a graph.
ResolveOptionConfliction
();
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
VLOG
(
1
)
<<
"Add graph_viz_pass"
;
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_original_graph"
);
VLOG
(
1
)
<<
"Add record_skip_memory_opt_vars_pass"
;
// Note(zcd): record_skip_memory_opt_vars_pass should
// be the first pass.
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
AppendPassWithCheck
(
strategy_
.
enable_sequential_execution_
,
"sequential_execution_pass"
);
AppendPassWithCheck
(
strategy_
.
sync_batch_norm_
,
"sync_batch_norm_pass"
);
AppendOpFusePasses
();
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_fused_graph"
);
// TODO(dev-paddle): memory optimize pass should be placed last.
AppendMemoryOptimizePasses
();
AppendMultiDevPass
();
AppendMultiGraphOptPasses
();
AppendPassToSetMkldnnAttr
(
"mkldnn_placement_pass"
);
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
AppendPassWithCheck
(
strategy_
.
cache_runtime_context_
,
"runtime_context_cache_pass"
);
AppendPassWithCheck
(
strategy_
.
remove_unnecessary_lock_
,
"modify_op_lock_and_record_event_pass"
);
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass
(
"multi_devices_check_pass"
);
#ifdef PADDLE_WITH_MKLDNN
SetCollectiveContext
();
if
(
FLAGS_use_mkldnn
)
{
VLOG
(
1
)
<<
"Add mkldnn_placement_pass"
;
AppendPass
(
"mkldnn_placement_pass"
);
}
else
if
(
!
strategy_
.
mkldnn_enabled_op_types_
.
empty
())
{
LOG
(
WARNING
)
<<
"mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false."
;
}
#else
PADDLE_ENFORCE
(
!
FLAGS_use_mkldnn
,
"Please compile with MKLDNN first to use MKLDNN"
);
#endif
if
(
strategy_
.
enable_sequential_execution_
)
{
VLOG
(
1
)
<<
"Add sequential_execution_pass"
;
AppendPass
(
"sequential_execution_pass"
);
}
}
// Add op fusion.
void
ResolveOptionConfliction
()
{
if
(
strategy
.
sync_batch_norm_
)
{
// Specifies the restrictions between different pass.
AppendPass
(
"sync_batch_norm_pass"
);
if
(
strategy_
.
enable_parallel_graph_
)
{
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
<<
"Currently, fuse_all_optimizer_ops doesn't works under "
"parallel_graph."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
}
if
(
strategy_
.
is_distribution_
)
{
// Add op fusion.
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
<<
"Currently, fuse_all_optimizer_ops only works under "
VLOG
(
1
)
<<
"Add fuse_relu_depthwise_conv_pass
"
;
"Non-distributed mode.
"
;
AppendPass
(
"fuse_relu_depthwise_conv_pass"
)
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
}
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
// TODO(zjl): refactor MemoryOptimizePass to fit
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
// new strategy, which does not need to set
<<
"Currently, fuse_all_optimizer_ops only works under AllReduce "
// var.persistable = True
"mode."
;
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
strategy_
.
fuse_all_optimizer_ops_
=
false
;
if
(
strategy_
.
enable_inplace_
)
{
VLOG_IF
(
3
,
strategy_
.
fuse_all_reduce_ops_
)
VLOG
(
5
)
<<
"Add inplace_pass
"
;
<<
"fuse_all_optimizer_ops only work in Reducer mode.
"
;
AppendPass
(
"inplace_pass"
)
;
strategy_
.
fuse_all_reduce_ops_
=
false
;
}
}
}
}
if
(
strategy_
.
fuse_elewise_add_act_ops_
)
{
void
AppendMultiGraphOptPasses
()
{
VLOG
(
1
)
<<
"Add fuse_elewise_add_act_pass"
;
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
AppendPass
(
"fuse_elewise_add_act_pass"
);
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
AppendPassWithCheck
(
strategy_
.
fuse_all_reduce_ops_
,
"fuse_all_reduce_op_pass"
);
AppendPrintGraphPass
(
"multi_devices_print_pass"
,
"_multi_devices_graph"
);
// experimental shows that the program will be faster if append
// all_reduce_deps_pass here.
bool
append_all_reduce_deps_pass
=
!
strategy_
.
enable_parallel_graph_
&&
(
SeqOnlyAllReduceOps
(
strategy_
)
||
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
);
AppendPassWithCheck
(
append_all_reduce_deps_pass
,
"all_reduce_deps_pass"
);
bool
append_backward_optimizer_op_deps_pass
=
strategy_
.
num_trainers_
>
1
&&
!
strategy_
.
async_mode_
&&
!
strategy_
.
is_distribution_
&&
strategy_
.
enable_backward_optimizer_op_deps_
;
AppendPassWithCheck
(
append_backward_optimizer_op_deps_pass
,
"backward_optimizer_op_deps_pass"
);
}
}
void
AppendOpFusePasses
()
{
AppendPassWithCheck
(
strategy_
.
fuse_relu_depthwise_conv_
,
"fuse_relu_depthwise_conv_pass"
);
AppendPassWithCheck
(
strategy_
.
fuse_elewise_add_act_ops_
,
"fuse_elewise_add_act_pass"
);
// for single card training, fuse_all_reduce_ops is unnecessary.
// for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
AppendPassWithCheck
(
strategy_
.
fuse_all_reduce_ops_
,
VLOG
(
1
)
<<
"Add coalesce_grad_tensor_pass"
;
"coalesce_grad_tensor_pass"
);
AppendPass
(
"coalesce_grad_tensor_pass"
);
}
// Fuse all the optimization operators.
// Fuse all the optimization operators.
if
(
strategy_
.
is_distribution_
)
{
VLOG
(
3
)
<<
"Currently, fuse_all_optimizer_ops only works under "
"Non-distributed mode."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
||
strategy_
.
is_distribution_
)
{
VLOG
(
3
)
<<
"Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
if
(
strategy_
.
fuse_all_optimizer_ops_
)
{
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
// Currently, only one type of optimization algorithm can be fused.
VLOG
(
1
)
<<
"Add fuse_adam_op_pass"
;
if
(
strategy_
.
fuse_all_optimizer_ops_
)
{
AppendPass
(
"fuse_adam_op_pass"
);
AppendPass
(
"fuse_adam_op_pass"
);
VLOG
(
1
)
<<
"Add fuse_sgd_op_pass"
;
AppendPass
(
"fuse_sgd_op_pass"
);
AppendPass
(
"fuse_sgd_op_pass"
);
VLOG
(
1
)
<<
"Add fuse_momentum_op_pass"
;
AppendPass
(
"fuse_momentum_op_pass"
);
AppendPass
(
"fuse_momentum_op_pass"
);
}
}
// Add a graph viz pass to record a graph.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_fused_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
}
}
CollectiveContext
*
context
=
CollectiveContext
::
GetInstance
();
void
AppendMemoryOptimizePasses
()
{
// Append Memory Optimize Pass
context
->
endpoints_
=
strategy_
.
trainers_endpoints_
;
// TODO(zjl): refactor MemoryOptimizePass to fit
context
->
trainer_id_
=
strategy_
.
trainer_id_
;
// new strategy, which does not need to set
PADDLE_ENFORCE
(
strategy_
.
trainer_id_
>=
0
,
"trainer_id_ >= 0"
);
// var.persistable = True
if
(
strategy_
.
trainer_id_
>
0
&&
strategy_
.
trainers_endpoints_
.
size
()
>
0
)
{
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
PADDLE_ENFORCE
((
unsigned
)(
strategy_
.
trainer_id_
)
<
AppendPassWithCheck
(
strategy_
.
enable_inplace_
,
"inplace_pass"
);
strategy_
.
trainers_endpoints_
.
size
(),
"trainer_id_ < endpoints_ size"
);
}
}
VLOG
(
1
)
<<
"CollectiveContext:"
<<
context
->
String
();
// NOTE(dzh): memory optimize should be a runtime pass.
// NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle, OpHandle is
// However, after multi_devices_pass, VarHandle, OpHandle is
// the de-fact IR, any reuse on Graph is meaningless.
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
// , so fetchlist should be set persistable before call the Run interface.
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
if
(
strategy_
.
memory_optimize_
)
{
AppendPassWithCheck
(
strategy_
.
memory_optimize_
,
"memory_optimize_pass"
);
VLOG
(
5
)
<<
"Add memory_optimize_pass"
;
AppendPass
(
"memory_optimize_pass"
);
}
}
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if
(
strategy_
.
cache_runtime_context_
)
{
VLOG
(
1
)
<<
"Add runtime_context_cache_pass"
;
AppendPass
(
"runtime_context_cache_pass"
);
}
AppendMultiDevPass
(
strategy_
);
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG
(
1
)
<<
"Add fuse_all_reduce_op_pass"
;
AppendPass
(
"fuse_all_reduce_op_pass"
);
}
// Add a graph print pass to record a graph with device info.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
VLOG
(
1
)
<<
"Add multi_devices_print_pass"
;
auto
multi_devices_print_pass
=
AppendPass
(
"multi_devices_print_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_multi_devices_graph"
);
multi_devices_print_pass
->
Set
<
std
::
string
>
(
ir
::
kGraphvizPath
,
new
std
::
string
(
graph_path
));
multi_devices_print_pass
->
Set
<
ir
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
ir
::
GraphvizSSAGraphPrinter
);
}
// experimental shows that the program will be faster if append
// all_reduce_deps_pass here.
if
(
!
strategy_
.
enable_parallel_graph_
&&
(
SeqOnlyAllReduceOps
(
strategy_
)
||
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
))
{
VLOG
(
1
)
<<
"Add all_reduce_deps_pass"
;
AppendPass
(
"all_reduce_deps_pass"
);
}
}
if
(
strategy_
.
num_trainers_
>
1
&&
!
strategy_
.
async_mode_
&&
!
strategy_
.
is_distribution_
&&
strategy_
.
enable_backward_optimizer_op_deps_
)
{
VLOG
(
1
)
<<
"Add backward_op_deps_pass"
;
AppendPass
(
"backward_optimizer_op_deps_pass"
);
}
}
if
(
strategy_
.
remove_unnecessary_lock_
)
{
void
SetCollectiveContext
()
const
{
VLOG
(
1
)
<<
"Add modify_op_lock_and_record_event_pass"
;
CollectiveContext
*
context
=
CollectiveContext
::
GetInstance
();
AppendPass
(
"modify_op_lock_and_record_event_pass"
);
context
->
endpoints_
=
strategy_
.
trainers_endpoints_
;
context
->
trainer_id_
=
strategy_
.
trainer_id_
;
PADDLE_ENFORCE_GE
(
strategy_
.
trainer_id_
,
0
,
"trainer_id_ >= 0"
);
if
(
strategy_
.
trainer_id_
>
0
&&
strategy_
.
trainers_endpoints_
.
size
()
>
0
)
{
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
strategy_
.
trainer_id_
),
strategy_
.
trainers_endpoints_
.
size
(),
"trainer_id_ < endpoints_ size"
);
}
}
VLOG
(
1
)
<<
"CollectiveContext:"
<<
context
->
String
();
// Verify that the graph is correct for multi-device executor.
VLOG
(
1
)
<<
"Add multi_devices_check_pass"
;
AppendPass
(
"multi_devices_check_pass"
);
}
}
// Convert graph to run on multi-devices.
// Convert graph to run on multi-devices.
void
AppendMultiDevPass
(
const
BuildStrategy
&
strategy
)
{
void
AppendMultiDevPass
()
{
ir
::
Pass
*
multi_devices_pass
=
nullptr
;
ir
::
Pass
*
multi_devices_pass
=
nullptr
;
if
(
strategy_
.
async_mode_
)
{
if
(
strategy_
.
async_mode_
)
{
VLOG
(
1
)
<<
"Add async_multi_devices_pass"
;
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
}
else
if
(
strategy_
.
is_distribution_
)
{
}
else
if
(
strategy_
.
is_distribution_
)
{
VLOG
(
1
)
<<
"Add dist_multi_devices_pass, multi device parameter server mode"
;
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
}
else
{
}
else
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
switch
(
strategy_
.
reduce_
)
{
VLOG
(
1
)
<<
"Add all_reduce_mode_multi_devices_pass"
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
multi_devices_pass
=
multi_devices_pass
=
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
}
else
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
break
;
VLOG
(
1
)
<<
"Add reduce_mode_multi_devices_pass"
;
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
multi_devices_pass
=
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
multi_devices_pass
=
}
else
{
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
break
;
default:
PADDLE_THROW
(
"Unknown reduce strategy."
);
PADDLE_THROW
(
"Unknown reduce strategy."
);
}
}
}
}
...
@@ -254,6 +202,41 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -254,6 +202,41 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&
strategy_
);
&
strategy_
);
}
}
void
AppendPrintGraphPass
(
const
std
::
string
&
pass_name
,
const
std
::
string
&
debug_file_suffix
)
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
AppendPass
(
pass_name
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
debug_file_suffix
);
viz_pass
->
Set
<
std
::
string
>
(
ir
::
kGraphvizPath
,
new
std
::
string
(
graph_path
));
}
}
void
AppendPassWithCheck
(
bool
append_pass
,
const
std
::
string
&
pass_name
)
{
if
(
append_pass
)
{
AppendPass
(
pass_name
);
}
}
void
AppendPassToSetMkldnnAttr
(
const
std
::
string
&
pass_name
)
{
#ifdef PADDLE_WITH_MKLDNN
if
(
FLAGS_use_mkldnn
)
{
AppendPass
(
pass_name
);
}
else
if
(
!
strategy_
.
mkldnn_enabled_op_types_
.
empty
())
{
LOG
(
WARNING
)
<<
"mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false."
;
}
#else
PADDLE_ENFORCE
(
!
FLAGS_use_mkldnn
,
"Please compile with MKLDNN first to use MKLDNN"
);
#endif
}
private:
private:
BuildStrategy
strategy_
;
BuildStrategy
strategy_
;
};
};
...
@@ -307,17 +290,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -307,17 +290,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
#endif
#endif
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
||
}
else
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Type
()
==
"fuse_adam_op_pass"
||
pass
->
Type
()
==
"fuse_sgd_op_pass"
||
pass
->
Type
()
==
"fuse_momentum_op_pass"
||
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
kLocalScopes
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
&
local_scopes
);
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
...
@@ -326,7 +304,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -326,7 +304,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
use_hierarchical_allreduce_
));
new
bool
(
use_hierarchical_allreduce_
));
#endif
#endif
}
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
4140fe11
...
@@ -88,7 +88,7 @@ struct BuildStrategy {
...
@@ -88,7 +88,7 @@ struct BuildStrategy {
bool
fuse_elewise_add_act_ops_
{
false
};
bool
fuse_elewise_add_act_ops_
{
false
};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
// should not be sparse types
bool
fuse_all_optimizer_ops_
{
fals
e
};
bool
fuse_all_optimizer_ops_
{
tru
e
};
bool
fuse_all_reduce_ops_
{
false
};
bool
fuse_all_reduce_ops_
{
false
};
// fuse_relu_depthwise_conv can fuse the `relu ->
// fuse_relu_depthwise_conv can fuse the `relu ->
// depthwise_conv`
// depthwise_conv`
...
...
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
浏览文件 @
4140fe11
...
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
coalesce_grad_tensor_pass
,
REGISTER_PASS
(
coalesce_grad_tensor_pass
,
paddle
::
framework
::
ir
::
CoalesceGradTensorPass
)
paddle
::
framework
::
ir
::
CoalesceGradTensorPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
浏览文件 @
4140fe11
...
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
...
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_adam_op_pass
,
paddle
::
framework
::
ir
::
FuseAdamOpPass
)
REGISTER_PASS
(
fuse_adam_op_pass
,
paddle
::
framework
::
ir
::
FuseAdamOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
浏览文件 @
4140fe11
...
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
...
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_momentum_op_pass
,
paddle
::
framework
::
ir
::
FuseMomentumOpPass
)
REGISTER_PASS
(
fuse_momentum_op_pass
,
paddle
::
framework
::
ir
::
FuseMomentumOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
浏览文件 @
4140fe11
...
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
...
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_sgd_op_pass
,
paddle
::
framework
::
ir
::
FuseSgdOpPass
)
REGISTER_PASS
(
fuse_sgd_op_pass
,
paddle
::
framework
::
ir
::
FuseSgdOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/
multi_devices_graph_pass/multi_devices_graph_print_pass
.h
→
paddle/fluid/framework/ir/
graph_printer
.h
浏览文件 @
4140fe11
...
@@ -26,7 +26,7 @@ namespace paddle {
...
@@ -26,7 +26,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
constexpr
char
kGraphvizPath
[]
=
"
debug_graph
viz_path"
;
constexpr
char
kGraphvizPath
[]
=
"
graph_
viz_path"
;
class
SSAGraphPrinter
{
class
SSAGraphPrinter
{
public:
public:
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
4140fe11
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
...
@@ -25,8 +26,6 @@ namespace framework {
...
@@ -25,8 +26,6 @@ namespace framework {
namespace
ir
{
namespace
ir
{
using
inference
::
analysis
::
Dot
;
using
inference
::
analysis
::
Dot
;
namespace
{
namespace
{
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
string
FormatName
(
const
Node
*
node
)
{
std
::
string
FormatName
(
const
Node
*
node
)
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
||
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
||
!
node
->
Op
()
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()))
{
!
node
->
Op
()
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()))
{
...
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
...
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
}
// namespace
}
// namespace
void
GraphVizPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
GraphVizPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphV
izPath
);
const
std
::
string
&
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphv
izPath
);
VLOG
(
3
)
<<
"draw IR graph viz to "
<<
graph_viz_path
;
VLOG
(
3
)
<<
"draw IR graph viz to "
<<
graph_viz_path
;
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
...
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
...
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraph
V
izPath
);
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraph
v
izPath
);
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
浏览文件 @
4140fe11
...
@@ -12,12 +12,12 @@
...
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
...
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
std
::
unique_ptr
<
std
::
ostream
>
fout
(
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
std
::
string
>
(
kGraphvizPath
)));
new
std
::
ofstream
(
Get
<
std
::
string
>
(
kGraphvizPath
)));
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
if
(
Has
(
"graph_printer"
))
{
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
}
else
{
GraphvizSSAGraphPrinter
printer
;
printer
.
Print
(
*
graph
,
*
fout
);
}
}
}
};
};
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
4140fe11
...
@@ -24,6 +24,7 @@ namespace framework {
...
@@ -24,6 +24,7 @@ namespace framework {
namespace
ir
{
namespace
ir
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
CheckPrevPass
();
PADDLE_ENFORCE
(
graph
,
"graph passed to Pass::Apply() cannot be empty."
);
PADDLE_ENFORCE
(
graph
,
"graph passed to Pass::Apply() cannot be empty."
);
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
...
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
...
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE
(
VarDescIsConsistency
(
*
graph
),
PADDLE_ENFORCE
(
VarDescIsConsistency
(
*
graph
),
"The VarDescs of persistable variable are not consistency."
);
"The VarDescs of persistable variable are not consistency."
);
applied_
=
true
;
applied_
=
true
;
if
(
!
graph
->
Has
(
kPassRecorder
))
{
graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
}
graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
4140fe11
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -31,6 +32,9 @@ namespace ir {
...
@@ -31,6 +32,9 @@ namespace ir {
template
<
typename
PassType
>
template
<
typename
PassType
>
struct
PassRegistrar
;
struct
PassRegistrar
;
typedef
std
::
unordered_set
<
std
::
string
>
PassRecorder
;
constexpr
char
kPassRecorder
[]
=
"pass_recorder"
;
class
Pass
{
class
Pass
{
public:
public:
Pass
()
=
default
;
Pass
()
=
default
;
...
@@ -104,6 +108,10 @@ class Pass {
...
@@ -104,6 +108,10 @@ class Pass {
LOG
(
FATAL
)
<<
"Calling virtual Pass not implemented."
;
LOG
(
FATAL
)
<<
"Calling virtual Pass not implemented."
;
}
}
// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
virtual
void
CheckPrevPass
()
const
{}
private:
private:
template
<
typename
PassType
>
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
friend
struct
PassRegistrar
;
...
...
paddle/fluid/framework/ir/pass_builder.cc
浏览文件 @
4140fe11
...
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
...
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include <memory>
#include <utility>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
std
::
shared_ptr
<
Pass
>
PassBuilder
::
AppendPass
(
const
std
::
string
&
pass_type
)
{
std
::
shared_ptr
<
Pass
>
PassBuilder
::
AppendPass
(
const
std
::
string
&
pass_type
)
{
VLOG
(
3
)
<<
"Append "
<<
pass_type
;
auto
pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
pass_type
);
auto
pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
pass_type
);
passes_
.
emplace_back
(
pass
.
release
());
passes_
.
emplace_back
(
pass
.
release
());
return
passes_
.
back
();
return
passes_
.
back
();
...
...
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
浏览文件 @
4140fe11
...
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
...
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
VLOG
(
3
)
<<
"Use synchronous batch norm"
;
VLOG
(
3
)
<<
"Use synchronous batch norm"
;
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
IsOp
())
{
if
(
n
->
IsOp
()
&&
n
->
Op
()
)
{
auto
*
op
=
n
->
Op
();
auto
*
op
=
n
->
Op
();
if
(
op
->
Type
()
==
"batch_norm"
)
{
if
(
op
->
Type
()
==
"batch_norm"
)
{
op
->
SetType
(
"sync_batch_norm"
);
op
->
SetType
(
"sync_batch_norm"
);
...
...
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
浏览文件 @
4140fe11
...
@@ -32,6 +32,7 @@ feed_dict = {
...
@@ -32,6 +32,7 @@ feed_dict = {
class
InplaceTestBase
(
unittest
.
TestCase
):
class
InplaceTestBase
(
unittest
.
TestCase
):
def
initParameter
(
self
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
False
def
setUp
(
self
):
def
setUp
(
self
):
self
.
initParameter
()
self
.
initParameter
()
...
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
self
.
device_count
=
fluid
.
core
.
get_cuda_device_count
()
self
.
device_count
=
fluid
.
core
.
get_cuda_device_count
()
else
:
else
:
self
.
device_count
=
4
self
.
device_count
=
4
assert
batch_size
%
self
.
device_count
==
0
assert
batch_size
%
self
.
device_count
==
0
def
build_program_and_scope
(
self
):
def
build_program_and_scope
(
self
):
...
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
fuse_all_optimizer_ops
=
self
.
fuse_all_optimizer_ops
compiled_prog
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
compiled_prog
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
build_strategy
=
build_strategy
,
...
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
fuse_all_optimizer_ops
=
self
.
fuse_all_optimizer_ops
compiled_program
=
fluid
.
CompiledProgram
(
compiled_program
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
loss_name
=
loss
.
name
,
...
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
class
CPUInplaceTest
(
InplaceTestBase
):
class
CPUInplaceTest
(
InplaceTestBase
):
def
initParameter
(
self
):
def
initParameter
(
self
):
self
.
use_cuda
=
False
self
.
use_cuda
=
False
self
.
fuse_all_optimizer_ops
=
False
class
CUDAInplaceTestWithFuseOptimizationOps
(
InplaceTestBase
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
True
class
CPUInplaceTestWithFuseOptimizationOps
(
InplaceTestBase
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录