Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4140fe11
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4140fe11
编写于
7月 27, 2019
作者:
C
chengduo
提交者:
GitHub
7月 27, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Open fuse optimization ops (#18741)
* open fuse optimization ops test=develop
上级
582cc297
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
202 addition
and
198 deletion
+202
-198
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+153
-176
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+1
-1
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
...framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
...ework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
+1
-3
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
.../framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
+1
-3
paddle/fluid/framework/ir/graph_printer.h
paddle/fluid/framework/ir/graph_printer.h
+1
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+3
-4
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
...ulti_devices_graph_pass/multi_devices_graph_print_pass.cc
+7
-2
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+5
-0
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+8
-0
paddle/fluid/framework/ir/pass_builder.cc
paddle/fluid/framework/ir/pass_builder.cc
+3
-0
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
+1
-1
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
...d/tests/unittests/test_buffer_shared_memory_reuse_pass.py
+16
-1
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
4140fe11
...
@@ -21,12 +21,12 @@ limitations under the License. */
...
@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool
(
use_mkldnn
);
DECLARE_bool
(
use_mkldnn
);
...
@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
public:
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
// Add a graph viz pass to record a graph.
ResolveOptionConfliction
();
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
VLOG
(
1
)
<<
"Add graph_viz_pass"
;
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_original_graph"
);
VLOG
(
1
)
<<
"Add record_skip_memory_opt_vars_pass"
;
// Note(zcd): record_skip_memory_opt_vars_pass should
// be the first pass.
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
AppendPass
(
"record_skip_memory_opt_vars_pass"
);
AppendPassWithCheck
(
strategy_
.
enable_sequential_execution_
,
"sequential_execution_pass"
);
AppendPassWithCheck
(
strategy_
.
sync_batch_norm_
,
"sync_batch_norm_pass"
);
AppendOpFusePasses
();
AppendPrintGraphPass
(
"graph_viz_pass"
,
"_fused_graph"
);
// TODO(dev-paddle): memory optimize pass should be placed last.
AppendMemoryOptimizePasses
();
AppendMultiDevPass
();
AppendMultiGraphOptPasses
();
AppendPassToSetMkldnnAttr
(
"mkldnn_placement_pass"
);
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
AppendPassWithCheck
(
strategy_
.
cache_runtime_context_
,
"runtime_context_cache_pass"
);
AppendPassWithCheck
(
strategy_
.
remove_unnecessary_lock_
,
"modify_op_lock_and_record_event_pass"
);
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass
(
"multi_devices_check_pass"
);
#ifdef PADDLE_WITH_MKLDNN
SetCollectiveContext
();
if
(
FLAGS_use_mkldnn
)
{
}
VLOG
(
1
)
<<
"Add mkldnn_placement_pass"
;
AppendPass
(
"mkldnn_placement_pass"
);
}
else
if
(
!
strategy_
.
mkldnn_enabled_op_types_
.
empty
())
{
LOG
(
WARNING
)
<<
"mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false."
;
}
#else
PADDLE_ENFORCE
(
!
FLAGS_use_mkldnn
,
"Please compile with MKLDNN first to use MKLDNN"
);
#endif
if
(
strategy_
.
enable_sequential_execution_
)
{
void
ResolveOptionConfliction
()
{
VLOG
(
1
)
<<
"Add sequential_execution_pass"
;
// Specifies the restrictions between different pass.
AppendPass
(
"sequential_execution_pass"
);
if
(
strategy_
.
enable_parallel_graph_
)
{
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
<<
"Currently, fuse_all_optimizer_ops doesn't works under "
"parallel_graph."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
}
if
(
strategy_
.
is_distribution_
)
{
// Add op fusion.
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
if
(
strategy
.
sync_batch_norm_
)
{
<<
"Currently, fuse_all_optimizer_ops only works under "
AppendPass
(
"sync_batch_norm_pass"
);
"Non-distributed mode."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
}
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
// Add op fusion.
VLOG_IF
(
3
,
strategy_
.
fuse_all_optimizer_ops_
)
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
<<
"Currently, fuse_all_optimizer_ops only works under AllReduce "
VLOG
(
1
)
<<
"Add fuse_relu_depthwise_conv_pass"
;
"mode."
;
AppendPass
(
"fuse_relu_depthwise_conv_pass"
);
strategy_
.
fuse_all_optimizer_ops_
=
false
;
VLOG_IF
(
3
,
strategy_
.
fuse_all_reduce_ops_
)
<<
"fuse_all_optimizer_ops only work in Reducer mode."
;
strategy_
.
fuse_all_reduce_ops_
=
false
;
}
}
}
// TODO(zjl): refactor MemoryOptimizePass to fit
void
AppendMultiGraphOptPasses
()
{
// new strategy, which does not need to set
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// var.persistable = True
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
AppendPassWithCheck
(
strategy_
.
fuse_all_reduce_ops_
,
if
(
strategy_
.
enable_inplace_
)
{
"fuse_all_reduce_op_pass"
);
VLOG
(
5
)
<<
"Add inplace_pass"
;
AppendPrintGraphPass
(
"multi_devices_print_pass"
,
"_multi_devices_graph"
);
AppendPass
(
"inplace_pass"
);
}
}
if
(
strategy_
.
fuse_elewise_add_act_ops_
)
{
// experimental shows that the program will be faster if append
VLOG
(
1
)
<<
"Add fuse_elewise_add_act_pass"
;
// all_reduce_deps_pass here.
AppendPass
(
"fuse_elewise_add_act_pass"
);
bool
append_all_reduce_deps_pass
=
}
!
strategy_
.
enable_parallel_graph_
&&
(
SeqOnlyAllReduceOps
(
strategy_
)
||
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
);
AppendPassWithCheck
(
append_all_reduce_deps_pass
,
"all_reduce_deps_pass"
);
bool
append_backward_optimizer_op_deps_pass
=
strategy_
.
num_trainers_
>
1
&&
!
strategy_
.
async_mode_
&&
!
strategy_
.
is_distribution_
&&
strategy_
.
enable_backward_optimizer_op_deps_
;
AppendPassWithCheck
(
append_backward_optimizer_op_deps_pass
,
"backward_optimizer_op_deps_pass"
);
}
void
AppendOpFusePasses
()
{
AppendPassWithCheck
(
strategy_
.
fuse_relu_depthwise_conv_
,
"fuse_relu_depthwise_conv_pass"
);
AppendPassWithCheck
(
strategy_
.
fuse_elewise_add_act_ops_
,
"fuse_elewise_add_act_pass"
);
// for single card training, fuse_all_reduce_ops is unnecessary.
// for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
AppendPassWithCheck
(
strategy_
.
fuse_all_reduce_ops_
,
VLOG
(
1
)
<<
"Add coalesce_grad_tensor_pass"
;
"coalesce_grad_tensor_pass"
);
AppendPass
(
"coalesce_grad_tensor_pass"
);
}
// Fuse all the optimization operators.
// Fuse all the optimization operators.
if
(
strategy_
.
is_distribution_
)
{
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
VLOG
(
3
)
<<
"Currently, fuse_all_optimizer_ops only works under "
// if the number is zero, fuse_all_reduce_ops will do nothing.
"Non-distributed mode."
;
// Currently, only one type of optimization algorithm can be fused.
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
||
strategy_
.
is_distribution_
)
{
VLOG
(
3
)
<<
"Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode."
;
strategy_
.
fuse_all_optimizer_ops_
=
false
;
}
if
(
strategy_
.
fuse_all_optimizer_ops_
)
{
if
(
strategy_
.
fuse_all_optimizer_ops_
)
{
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG
(
1
)
<<
"Add fuse_adam_op_pass"
;
AppendPass
(
"fuse_adam_op_pass"
);
AppendPass
(
"fuse_adam_op_pass"
);
VLOG
(
1
)
<<
"Add fuse_sgd_op_pass"
;
AppendPass
(
"fuse_sgd_op_pass"
);
AppendPass
(
"fuse_sgd_op_pass"
);
VLOG
(
1
)
<<
"Add fuse_momentum_op_pass"
;
AppendPass
(
"fuse_momentum_op_pass"
);
AppendPass
(
"fuse_momentum_op_pass"
);
}
}
}
// Add a graph viz pass to record a graph.
void
AppendMemoryOptimizePasses
()
{
// Append Memory Optimize Pass
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
// TODO(zjl): refactor MemoryOptimizePass to fit
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
// new strategy, which does not need to set
const
std
::
string
graph_path
=
string
::
Sprintf
(
// var.persistable = True
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_fused_graph"
);
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
AppendPassWithCheck
(
strategy_
.
enable_inplace_
,
"inplace_pass"
);
}
CollectiveContext
*
context
=
CollectiveContext
::
GetInstance
();
context
->
endpoints_
=
strategy_
.
trainers_endpoints_
;
context
->
trainer_id_
=
strategy_
.
trainer_id_
;
PADDLE_ENFORCE
(
strategy_
.
trainer_id_
>=
0
,
"trainer_id_ >= 0"
);
if
(
strategy_
.
trainer_id_
>
0
&&
strategy_
.
trainers_endpoints_
.
size
()
>
0
)
{
PADDLE_ENFORCE
((
unsigned
)(
strategy_
.
trainer_id_
)
<
strategy_
.
trainers_endpoints_
.
size
(),
"trainer_id_ < endpoints_ size"
);
}
}
VLOG
(
1
)
<<
"CollectiveContext:"
<<
context
->
String
();
// NOTE(dzh): memory optimize should be a runtime pass.
// NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle, OpHandle is
// However, after multi_devices_pass, VarHandle, OpHandle is
// the de-fact IR, any reuse on Graph is meaningless.
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
// , so fetchlist should be set persistable before call the Run interface.
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
if
(
strategy_
.
use_legacy_memory_optimize_strategy_
)
{
if
(
strategy_
.
memory_optimize_
)
{
AppendPassWithCheck
(
strategy_
.
memory_optimize_
,
"memory_optimize_pass"
);
VLOG
(
5
)
<<
"Add memory_optimize_pass"
;
AppendPass
(
"memory_optimize_pass"
);
}
}
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if
(
strategy_
.
cache_runtime_context_
)
{
VLOG
(
1
)
<<
"Add runtime_context_cache_pass"
;
AppendPass
(
"runtime_context_cache_pass"
);
}
AppendMultiDevPass
(
strategy_
);
if
(
strategy_
.
fuse_all_reduce_ops_
)
{
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG
(
1
)
<<
"Add fuse_all_reduce_op_pass"
;
AppendPass
(
"fuse_all_reduce_op_pass"
);
}
// Add a graph print pass to record a graph with device info.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
VLOG
(
1
)
<<
"Add multi_devices_print_pass"
;
auto
multi_devices_print_pass
=
AppendPass
(
"multi_devices_print_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
"_multi_devices_graph"
);
multi_devices_print_pass
->
Set
<
std
::
string
>
(
ir
::
kGraphvizPath
,
new
std
::
string
(
graph_path
));
multi_devices_print_pass
->
Set
<
ir
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
ir
::
GraphvizSSAGraphPrinter
);
}
// experimental shows that the program will be faster if append
// all_reduce_deps_pass here.
if
(
!
strategy_
.
enable_parallel_graph_
&&
(
SeqOnlyAllReduceOps
(
strategy_
)
||
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
))
{
VLOG
(
1
)
<<
"Add all_reduce_deps_pass"
;
AppendPass
(
"all_reduce_deps_pass"
);
}
if
(
strategy_
.
num_trainers_
>
1
&&
!
strategy_
.
async_mode_
&&
!
strategy_
.
is_distribution_
&&
strategy_
.
enable_backward_optimizer_op_deps_
)
{
VLOG
(
1
)
<<
"Add backward_op_deps_pass"
;
AppendPass
(
"backward_optimizer_op_deps_pass"
);
}
}
}
if
(
strategy_
.
remove_unnecessary_lock_
)
{
void
SetCollectiveContext
()
const
{
VLOG
(
1
)
<<
"Add modify_op_lock_and_record_event_pass"
;
CollectiveContext
*
context
=
CollectiveContext
::
GetInstance
();
AppendPass
(
"modify_op_lock_and_record_event_pass"
);
context
->
endpoints_
=
strategy_
.
trainers_endpoints_
;
context
->
trainer_id_
=
strategy_
.
trainer_id_
;
PADDLE_ENFORCE_GE
(
strategy_
.
trainer_id_
,
0
,
"trainer_id_ >= 0"
);
if
(
strategy_
.
trainer_id_
>
0
&&
strategy_
.
trainers_endpoints_
.
size
()
>
0
)
{
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
strategy_
.
trainer_id_
),
strategy_
.
trainers_endpoints_
.
size
(),
"trainer_id_ < endpoints_ size"
);
}
}
VLOG
(
1
)
<<
"CollectiveContext:"
<<
context
->
String
();
// Verify that the graph is correct for multi-device executor.
VLOG
(
1
)
<<
"Add multi_devices_check_pass"
;
AppendPass
(
"multi_devices_check_pass"
);
}
}
// Convert graph to run on multi-devices.
// Convert graph to run on multi-devices.
void
AppendMultiDevPass
(
const
BuildStrategy
&
strategy
)
{
void
AppendMultiDevPass
()
{
ir
::
Pass
*
multi_devices_pass
=
nullptr
;
ir
::
Pass
*
multi_devices_pass
=
nullptr
;
if
(
strategy_
.
async_mode_
)
{
if
(
strategy_
.
async_mode_
)
{
VLOG
(
1
)
<<
"Add async_multi_devices_pass"
;
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
}
else
if
(
strategy_
.
is_distribution_
)
{
}
else
if
(
strategy_
.
is_distribution_
)
{
VLOG
(
1
)
<<
"Add dist_multi_devices_pass, multi device parameter server mode"
;
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
}
else
{
}
else
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
switch
(
strategy_
.
reduce_
)
{
VLOG
(
1
)
<<
"Add all_reduce_mode_multi_devices_pass"
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
multi_devices_pass
=
multi_devices_pass
=
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
AppendPass
(
"all_reduce_mode_multi_devices_pass"
).
get
();
}
else
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
break
;
VLOG
(
1
)
<<
"Add reduce_mode_multi_devices_pass"
;
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
multi_devices_pass
=
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
multi_devices_pass
=
}
else
{
AppendPass
(
"reduce_mode_multi_devices_pass"
).
get
();
PADDLE_THROW
(
"Unknown reduce strategy."
);
break
;
default:
PADDLE_THROW
(
"Unknown reduce strategy."
);
}
}
}
}
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
&
strategy_
);
}
}
void
AppendPrintGraphPass
(
const
std
::
string
&
pass_name
,
const
std
::
string
&
debug_file_suffix
)
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
AppendPass
(
pass_name
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy_
.
debug_graphviz_path_
.
c_str
(),
debug_file_suffix
);
viz_pass
->
Set
<
std
::
string
>
(
ir
::
kGraphvizPath
,
new
std
::
string
(
graph_path
));
}
}
void
AppendPassWithCheck
(
bool
append_pass
,
const
std
::
string
&
pass_name
)
{
if
(
append_pass
)
{
AppendPass
(
pass_name
);
}
}
void
AppendPassToSetMkldnnAttr
(
const
std
::
string
&
pass_name
)
{
#ifdef PADDLE_WITH_MKLDNN
if
(
FLAGS_use_mkldnn
)
{
AppendPass
(
pass_name
);
}
else
if
(
!
strategy_
.
mkldnn_enabled_op_types_
.
empty
())
{
LOG
(
WARNING
)
<<
"mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false."
;
}
#else
PADDLE_ENFORCE
(
!
FLAGS_use_mkldnn
,
"Please compile with MKLDNN first to use MKLDNN"
);
#endif
}
private:
private:
BuildStrategy
strategy_
;
BuildStrategy
strategy_
;
};
};
...
@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
#endif
#endif
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
||
}
else
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Type
()
==
"fuse_adam_op_pass"
||
pass
->
Type
()
==
"fuse_sgd_op_pass"
||
pass
->
Type
()
==
"fuse_momentum_op_pass"
||
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
kLocalScopes
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
&
local_scopes
);
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
use_hierarchical_allreduce_
));
new
bool
(
use_hierarchical_allreduce_
));
#endif
#endif
}
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"coalesce_grad_tensor_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
4140fe11
...
@@ -88,7 +88,7 @@ struct BuildStrategy {
...
@@ -88,7 +88,7 @@ struct BuildStrategy {
bool
fuse_elewise_add_act_ops_
{
false
};
bool
fuse_elewise_add_act_ops_
{
false
};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
// should not be sparse types
bool
fuse_all_optimizer_ops_
{
fals
e
};
bool
fuse_all_optimizer_ops_
{
tru
e
};
bool
fuse_all_reduce_ops_
{
false
};
bool
fuse_all_reduce_ops_
{
false
};
// fuse_relu_depthwise_conv can fuse the `relu ->
// fuse_relu_depthwise_conv can fuse the `relu ->
// depthwise_conv`
// depthwise_conv`
...
...
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
浏览文件 @
4140fe11
...
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
coalesce_grad_tensor_pass
,
REGISTER_PASS
(
coalesce_grad_tensor_pass
,
paddle
::
framework
::
ir
::
CoalesceGradTensorPass
)
paddle
::
framework
::
ir
::
CoalesceGradTensorPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
浏览文件 @
4140fe11
...
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
...
@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_adam_op_pass
,
paddle
::
framework
::
ir
::
FuseAdamOpPass
)
REGISTER_PASS
(
fuse_adam_op_pass
,
paddle
::
framework
::
ir
::
FuseAdamOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
浏览文件 @
4140fe11
...
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
...
@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_momentum_op_pass
,
paddle
::
framework
::
ir
::
FuseMomentumOpPass
)
REGISTER_PASS
(
fuse_momentum_op_pass
,
paddle
::
framework
::
ir
::
FuseMomentumOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
浏览文件 @
4140fe11
...
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
...
@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
fuse_sgd_op_pass
,
paddle
::
framework
::
ir
::
FuseSgdOpPass
)
REGISTER_PASS
(
fuse_sgd_op_pass
,
paddle
::
framework
::
ir
::
FuseSgdOpPass
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
);
paddle/fluid/framework/ir/
multi_devices_graph_pass/multi_devices_graph_print_pass
.h
→
paddle/fluid/framework/ir/
graph_printer
.h
浏览文件 @
4140fe11
...
@@ -26,7 +26,7 @@ namespace paddle {
...
@@ -26,7 +26,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
constexpr
char
kGraphvizPath
[]
=
"
debug_graph
viz_path"
;
constexpr
char
kGraphvizPath
[]
=
"
graph_
viz_path"
;
class
SSAGraphPrinter
{
class
SSAGraphPrinter
{
public:
public:
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
4140fe11
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
...
@@ -25,8 +26,6 @@ namespace framework {
...
@@ -25,8 +26,6 @@ namespace framework {
namespace
ir
{
namespace
ir
{
using
inference
::
analysis
::
Dot
;
using
inference
::
analysis
::
Dot
;
namespace
{
namespace
{
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
string
FormatName
(
const
Node
*
node
)
{
std
::
string
FormatName
(
const
Node
*
node
)
{
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
||
if
(
!
node
->
IsOp
()
||
!
node
->
Op
()
||
!
node
->
Op
()
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()))
{
!
node
->
Op
()
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpNamescopeAttrName
()))
{
...
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
...
@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
}
// namespace
}
// namespace
void
GraphVizPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
GraphVizPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphV
izPath
);
const
std
::
string
&
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphv
izPath
);
VLOG
(
3
)
<<
"draw IR graph viz to "
<<
graph_viz_path
;
VLOG
(
3
)
<<
"draw IR graph viz to "
<<
graph_viz_path
;
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
...
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
...
@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraph
V
izPath
);
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraph
v
izPath
);
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
浏览文件 @
4140fe11
...
@@ -12,12 +12,12 @@
...
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
...
@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
std
::
unique_ptr
<
std
::
ostream
>
fout
(
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
std
::
string
>
(
kGraphvizPath
)));
new
std
::
ofstream
(
Get
<
std
::
string
>
(
kGraphvizPath
)));
PADDLE_ENFORCE
(
fout
->
good
());
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
if
(
Has
(
"graph_printer"
))
{
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
}
else
{
GraphvizSSAGraphPrinter
printer
;
printer
.
Print
(
*
graph
,
*
fout
);
}
}
}
};
};
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
4140fe11
...
@@ -24,6 +24,7 @@ namespace framework {
...
@@ -24,6 +24,7 @@ namespace framework {
namespace
ir
{
namespace
ir
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
CheckPrevPass
();
PADDLE_ENFORCE
(
graph
,
"graph passed to Pass::Apply() cannot be empty."
);
PADDLE_ENFORCE
(
graph
,
"graph passed to Pass::Apply() cannot be empty."
);
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
...
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
...
@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE
(
VarDescIsConsistency
(
*
graph
),
PADDLE_ENFORCE
(
VarDescIsConsistency
(
*
graph
),
"The VarDescs of persistable variable are not consistency."
);
"The VarDescs of persistable variable are not consistency."
);
applied_
=
true
;
applied_
=
true
;
if
(
!
graph
->
Has
(
kPassRecorder
))
{
graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
}
graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
return
graph
;
return
graph
;
}
}
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
4140fe11
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -31,6 +32,9 @@ namespace ir {
...
@@ -31,6 +32,9 @@ namespace ir {
template
<
typename
PassType
>
template
<
typename
PassType
>
struct
PassRegistrar
;
struct
PassRegistrar
;
typedef
std
::
unordered_set
<
std
::
string
>
PassRecorder
;
constexpr
char
kPassRecorder
[]
=
"pass_recorder"
;
class
Pass
{
class
Pass
{
public:
public:
Pass
()
=
default
;
Pass
()
=
default
;
...
@@ -104,6 +108,10 @@ class Pass {
...
@@ -104,6 +108,10 @@ class Pass {
LOG
(
FATAL
)
<<
"Calling virtual Pass not implemented."
;
LOG
(
FATAL
)
<<
"Calling virtual Pass not implemented."
;
}
}
// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
virtual
void
CheckPrevPass
()
const
{}
private:
private:
template
<
typename
PassType
>
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
friend
struct
PassRegistrar
;
...
...
paddle/fluid/framework/ir/pass_builder.cc
浏览文件 @
4140fe11
...
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
...
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include <memory>
#include <utility>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
std
::
shared_ptr
<
Pass
>
PassBuilder
::
AppendPass
(
const
std
::
string
&
pass_type
)
{
std
::
shared_ptr
<
Pass
>
PassBuilder
::
AppendPass
(
const
std
::
string
&
pass_type
)
{
VLOG
(
3
)
<<
"Append "
<<
pass_type
;
auto
pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
pass_type
);
auto
pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
pass_type
);
passes_
.
emplace_back
(
pass
.
release
());
passes_
.
emplace_back
(
pass
.
release
());
return
passes_
.
back
();
return
passes_
.
back
();
...
...
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
浏览文件 @
4140fe11
...
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
...
@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
{
VLOG
(
3
)
<<
"Use synchronous batch norm"
;
VLOG
(
3
)
<<
"Use synchronous batch norm"
;
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
IsOp
())
{
if
(
n
->
IsOp
()
&&
n
->
Op
()
)
{
auto
*
op
=
n
->
Op
();
auto
*
op
=
n
->
Op
();
if
(
op
->
Type
()
==
"batch_norm"
)
{
if
(
op
->
Type
()
==
"batch_norm"
)
{
op
->
SetType
(
"sync_batch_norm"
);
op
->
SetType
(
"sync_batch_norm"
);
...
...
python/paddle/fluid/tests/unittests/test_buffer_shared_memory_reuse_pass.py
浏览文件 @
4140fe11
...
@@ -32,6 +32,7 @@ feed_dict = {
...
@@ -32,6 +32,7 @@ feed_dict = {
class
InplaceTestBase
(
unittest
.
TestCase
):
class
InplaceTestBase
(
unittest
.
TestCase
):
def
initParameter
(
self
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
False
def
setUp
(
self
):
def
setUp
(
self
):
self
.
initParameter
()
self
.
initParameter
()
...
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
self
.
device_count
=
fluid
.
core
.
get_cuda_device_count
()
self
.
device_count
=
fluid
.
core
.
get_cuda_device_count
()
else
:
else
:
self
.
device_count
=
4
self
.
device_count
=
4
assert
batch_size
%
self
.
device_count
==
0
assert
batch_size
%
self
.
device_count
==
0
def
build_program_and_scope
(
self
):
def
build_program_and_scope
(
self
):
...
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
fuse_all_optimizer_ops
=
self
.
fuse_all_optimizer_ops
compiled_prog
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
compiled_prog
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
build_strategy
=
build_strategy
,
...
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
memory_optimize
=
memory_optimize
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
fuse_all_optimizer_ops
=
self
.
fuse_all_optimizer_ops
compiled_program
=
fluid
.
CompiledProgram
(
compiled_program
=
fluid
.
CompiledProgram
(
prog
).
with_data_parallel
(
prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
loss_name
=
loss
.
name
,
...
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
...
@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
class
CPUInplaceTest
(
InplaceTestBase
):
class
CPUInplaceTest
(
InplaceTestBase
):
def
initParameter
(
self
):
def
initParameter
(
self
):
self
.
use_cuda
=
False
self
.
use_cuda
=
False
self
.
fuse_all_optimizer_ops
=
False
class
CUDAInplaceTestWithFuseOptimizationOps
(
InplaceTestBase
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
True
class
CPUInplaceTestWithFuseOptimizationOps
(
InplaceTestBase
):
def
initParameter
(
self
):
self
.
use_cuda
=
True
self
.
fuse_all_optimizer_ops
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录