Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8f3b2523
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8f3b2523
编写于
1月 21, 2019
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
squash commits. test=develop
上级
266e625d
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
1228 addition
and
167 deletion
+1228
-167
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+5
-4
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+15
-5
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+2
-0
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+375
-0
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+74
-0
paddle/fluid/framework/details/memory_early_delete_pass.cc
paddle/fluid/framework/details/memory_early_delete_pass.cc
+1
-1
paddle/fluid/framework/details/memory_optimize_helper.cc
paddle/fluid/framework/details/memory_optimize_helper.cc
+43
-9
paddle/fluid/framework/details/memory_optimize_helper.h
paddle/fluid/framework/details/memory_optimize_helper.h
+44
-2
paddle/fluid/framework/details/memory_optimize_helper_test.cc
...le/fluid/framework/details/memory_optimize_helper_test.cc
+3
-3
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+57
-111
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+3
-9
paddle/fluid/framework/details/memory_optimize_pass_test.cc
paddle/fluid/framework/details/memory_optimize_pass_test.cc
+1
-1
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+18
-3
paddle/fluid/framework/inplace_op_inference.h
paddle/fluid/framework/inplace_op_inference.h
+135
-0
paddle/fluid/framework/inplace_op_inference_test.cc
paddle/fluid/framework/inplace_op_inference_test.cc
+287
-0
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+1
-0
paddle/fluid/framework/op_info.h
paddle/fluid/framework/op_info.h
+1
-0
paddle/fluid/framework/type_defs.h
paddle/fluid/framework/type_defs.h
+3
-0
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+8
-6
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+37
-2
paddle/fluid/operators/elementwise/elementwise_add_op.cc
paddle/fluid/operators/elementwise/elementwise_add_op.cc
+1
-0
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+16
-1
paddle/fluid/operators/flatten_op.cc
paddle/fluid/operators/flatten_op.cc
+36
-4
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+36
-4
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+2
-1
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+15
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-1
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+2
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
8f3b2523
...
...
@@ -200,6 +200,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc
)
cc_test
(
inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info
)
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
8f3b2523
...
...
@@ -50,7 +50,8 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass
)
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass DEPS memory_optimize_pass op_info
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
...
...
@@ -65,12 +66,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass
inplace_op_pass
)
if
(
WITH_GPU
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
endif
()
cc_test
(
memory_
reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types
.cc DEPS framework_proto graph
)
cc_test
(
analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types
.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_test
(
memory_
optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper
.cc DEPS framework_proto graph
)
cc_test
(
memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper
.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
8f3b2523
...
...
@@ -17,7 +17,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
...
...
@@ -42,6 +42,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
if
(
strategy_
.
enable_inplace_
)
{
AppendPass
(
"inplace_pass"
);
}
if
(
strategy_
.
enable_sequential_execution_
)
{
AppendPass
(
"sequential_execution_pass"
);
}
...
...
@@ -87,7 +90,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if
(
strategy
.
memory_optimize_
)
{
auto
analysis_var_pass
=
AppendPass
(
"analysis_var
_pass"
);
auto
memory_optimize_pass
=
AppendPass
(
"memory_optimize
_pass"
);
}
AppendMultiDevPass
(
strategy
);
...
...
@@ -185,8 +188,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
}
else
if
(
pass
->
Type
()
==
"analysis_var_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"memory_optimize_pass"
)
{
const
std
::
vector
<
OpDesc
*>
*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
...
...
@@ -213,6 +215,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
}
else
if
(
pass
->
Type
()
==
"inplace_pass"
)
{
if
(
graph
->
Has
(
kAllOpDescs
))
{
graph
->
Erase
(
kAllOpDescs
);
}
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
}
else
if
(
pass
->
Type
()
==
"fuse_relu_depthwise_conv_pass"
)
{
if
(
!
use_cuda
)
{
LOG
(
WARNING
)
<<
"fuse_relu_depthwise_conv_pass is only supported on "
...
...
@@ -238,8 +247,9 @@ USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS
(
dist_multi_devices_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
analysis_var
_pass
);
USE_PASS
(
memory_optimize
_pass
);
USE_PASS
(
sequential_execution_pass
);
USE_PASS
(
all_reduce_deps_pass
);
USE_PASS
(
modify_op_lock_and_record_event_pass
);
USE_PASS
(
inplace_pass
);
USE_PASS
(
lock_free_optimize_pass
);
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
8f3b2523
...
...
@@ -80,6 +80,8 @@ struct BuildStrategy {
bool
memory_early_delete_
{
false
};
bool
enable_inplace_
{
false
};
bool
enable_sequential_execution_
{
false
};
bool
fuse_broadcast_op_
{
false
};
...
...
paddle/fluid/framework/details/inplace_op_pass.cc
0 → 100644
浏览文件 @
8f3b2523
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include <algorithm>
#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space.
// By our design, one operator only can read its input(const Variable),
// write its output(non-const Variable). If one operator is inplaced, means
// user have chance to write the space before reading happens.
// Especially when some optimize code writing style is applied.
//
//
// /* wrong case in operator */
// /*In this case, a larger allocation is allocated, input content is lost*/
// const Tensor* in = ctx.Input<Tensor>("In")
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0; // input content is overwritten.
// For backward compatibility: if enable_inplace_whitelist is turned on,
// only the ops in the whitelist will use the inplace strategy.
// If not, every op registered with an InplaceClass will be inplaced.
// Command-line flag that restricts the inplace pass to the op types listed in
// kInplacedOpWhiteList. Defaults to true.
// The help text is built from adjacent string literals, so each fragment ends
// with an explicit separator; otherwise the sentences run into each other.
DEFINE_bool(
    enable_inplace_whitelist, true,
    "If this option turns on, only these op in whitelist can be inplaced. "
    "If it turns off, all of the running op can be candidate of inplaced op. "
    "Such as scale, elementwise_add. "
    "By default, it's turned on.");
// clang-format off
// Op types that may be inplaced while FLAGS_enable_inplace_whitelist is on.
// The InplacePass constructor copies these names into its whitelist set.
const std::string kInplacedOpWhiteList[] = {  // NOLINT
    // activation ops
    "sigmoid", "exp", "relu", "tanh", "sqrt", "ceil", "floor", "reciprocal",
    "relu6", "soft_relu", "hard_sigmoid",
    // normalization ops
    "batch_norm", "batch_norm_grad",
    // reduction / arithmetic / shape ops
    "sum", "sum_grad", "scale", "reshape",
    "elementwise_add", "elementwise_add_grad",
};
// clang-format on
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Follows `var` one step forward along an inplace chain.
//
// `var` must be a non-control variable node (enforced). When `var` has exactly
// one consumer and that consumer is an op node, returns the first output of
// that op that is a non-control variable with a VarDesc and carries the same
// name as `var`. Returns nullptr in every other case.
static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) {
  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());

  // Only a variable consumed by exactly one op can continue the chain.
  if (var->outputs.size() != 1 || !var->outputs[0]->IsOp()) return nullptr;

  ir::Node* consumer = var->outputs[0];
  for (auto* candidate : consumer->outputs) {
    const bool is_data_var = candidate->IsVar() && !candidate->IsCtrlVar() &&
                             candidate->Var() != nullptr;
    if (is_data_var && candidate->Name() == var->Name()) return candidate;
  }
  return nullptr;
}
// Follows `var` one step backward along an inplace chain.
//
// `var` must be a non-control variable node (enforced). When `var` has exactly
// one producer and that producer is an op node, returns the first input of
// that op that is a non-control variable with a VarDesc and carries the same
// name as `var`. Returns nullptr in every other case.
static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) {
  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());

  // Only a variable written by exactly one op can continue the chain.
  if (var->inputs.size() != 1 || !var->inputs[0]->IsOp()) return nullptr;

  ir::Node* producer = var->inputs[0];
  for (auto* candidate : producer->inputs) {
    const bool is_data_var = candidate->IsVar() && !candidate->IsCtrlVar() &&
                             candidate->Var() != nullptr;
    if (is_data_var && candidate->Name() == var->Name()) return candidate;
  }
  return nullptr;
}
template
<
typename
Container
>
static
inline
bool
ConnectByCtrlVar
(
const
Container
&
group1
,
const
Container
&
group2
)
{
bool
connected
=
false
;
std
::
unordered_set
<
ir
::
Node
*>
outputs
;
for
(
auto
*
op
:
group1
)
{
for
(
auto
*
var
:
op
->
outputs
)
{
if
(
var
->
IsCtrlVar
())
outputs
.
emplace
(
var
);
}
}
for
(
auto
*
op
:
group2
)
{
for
(
auto
*
var
:
op
->
inputs
)
{
if
(
outputs
.
count
(
var
))
connected
=
true
;
}
}
return
connected
;
}
// Builds the pass. When FLAGS_enable_inplace_whitelist is set, the whitelist
// is filled with every name in kInplacedOpWhiteList; otherwise it stays empty.
InplacePass::InplacePass() : Pass() {
  if (!FLAGS_enable_inplace_whitelist) return;
  whitelist_.insert(std::begin(kInplacedOpWhiteList),
                    std::end(kInplacedOpWhiteList));
}
void
InplacePass
::
InitSSAGraphNodes
()
const
{
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ir
::
Node
*>>
all_vars
;
for
(
auto
*
op
:
view_
.
AllOps
())
{
for
(
auto
*
node
:
op
->
inputs
)
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
continue
;
if
(
all_vars
[
node
->
Name
()].
count
(
node
)
==
0
)
{
all_vars
[
node
->
Name
()].
emplace
(
node
);
var_nodes_
[
node
->
Name
()].
emplace_back
(
node
);
}
}
for
(
auto
*
node
:
op
->
outputs
)
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
continue
;
if
(
all_vars
[
node
->
Name
()].
count
(
node
)
==
0
)
{
all_vars
[
node
->
Name
()].
emplace
(
node
);
var_nodes_
[
node
->
Name
()].
emplace_back
(
node
);
}
}
}
}
// Entry point of the pass. Rebuilds the op view and the name -> nodes table
// for `graph`, tries to inplace every eligible op in view order, then calls
// graph->ResolveHazard with the updated table and returns the graph.
// With the whitelist flag on, ops whose name is not whitelisted are skipped.
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  ir::Graph* raw_graph = graph.get();

  var_nodes_.clear();
  view_.Build(raw_graph);
  InitSSAGraphNodes();

  for (auto* op_node : view_.AllOps()) {
    const bool allowed = !FLAGS_enable_inplace_whitelist ||
                         whitelist_.count(op_node->Name()) != 0;
    if (allowed) {
      TryInplaceOpInputOutput(op_node, raw_graph);
    }
  }

  graph->ResolveHazard(var_nodes_);
  return graph;
}
// Rewrites the op descriptions so that `var` is replaced by `cache_var`.
//
// For every op in the view from position `idx` to the end, inputs and outputs
// named `var` are renamed to `cache_var`, `var` is removed from the op's block
// if the block still declares it, and the desc is flushed.
//
// var:       name of the variable being replaced.
// cache_var: name of the variable whose storage is reused.
// idx:       index in view_.AllOps() of the first op to rewrite.
void InplacePass::InplaceModifyDesc(const std::string& var,
                                    const std::string& cache_var,
                                    const size_t& idx) const {
  // AllOps() returns the op list by value; fetch it once rather than copying
  // the whole vector twice per iteration.
  const auto& all_ops = view_.AllOps();
  for (size_t i = idx; i < all_ops.size(); ++i) {
    auto* op = all_ops[i];
    PADDLE_ENFORCE(op->IsOp() && op->Op());
    auto* op_desc = op->Op();
    op_desc->RenameInput(var, cache_var);
    op_desc->RenameOutput(var, cache_var);
    if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var);
    op_desc->Flush();
  }
}
// Rewires the graph nodes so that `var` is replaced by `cache_var`.
//
// For every op in the view from position `idx` to the end:
//  * each input node named `var` hands its consumers over to the newest
//    recorded node of `cache_var`;
//  * each output node named `var` is replaced by a freshly created node of
//    `cache_var`, which is recorded as the newest version.
// Afterwards all nodes recorded for `var` are removed from the graph and the
// record for `var` is emptied.
//
// var:       name of the variable being replaced; at least one node with a
//            VarDesc must be recorded for it (enforced).
// cache_var: name of the variable whose storage is reused.
// idx:       index in view_.AllOps() of the first op to rewire.
// graph:     graph that owns the nodes.
void InplacePass::InplaceModifyVar(const std::string& var,
                                   const std::string& cache_var,
                                   const size_t& idx, ir::Graph* graph) const {
  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                 var_nodes_[var].at(0)->Var() != nullptr);
  // New nodes are created from a copy of var's desc carrying cache_var's name.
  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
  var_desc->SetName(cache_var);

  // AllOps() returns the op list by value; fetch it once rather than copying
  // the whole vector twice per iteration.
  const auto& all_ops = view_.AllOps();
  for (size_t i = idx; i < all_ops.size(); ++i) {
    auto* op = all_ops[i];

    // redirect the input to the latest version of cache_var
    for (auto* node : op->inputs) {
      if (node->Name() == var) {
        // back() on an empty vector is undefined behaviour; fail loudly.
        PADDLE_ENFORCE(!var_nodes_[cache_var].empty(),
                       "No node is recorded for cache var %s", cache_var);
        ir::Node* cache_node = var_nodes_[cache_var].back();
        // swap node to cache_node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }
      }
    }

    // if we need to rename the output,
    // always create a newer version of cache_var
    for (auto* node : op->outputs) {
      if (node->Name() == var) {
        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
        var_nodes_[cache_var].emplace_back(cache_node);
        // swap node to cache node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        cache_node->inputs.emplace_back(op);
        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }
      }
    }
  }

  // release node of unused var in graph
  for (auto* node : var_nodes_[var]) {
    graph->RemoveNode(node);
  }
  var_nodes_.at(var).clear();
}
// Tries to make the outputs of `op` reuse the storage of its inputs.
//
// The op is left untouched unless all three conditions hold:
//  1. an infer_inplace_ functor is registered for the op type;
//  2. the input node has at most one pending (consumer) op;
//  3. input and output are not connected through a control dependency var
//     (see GraphView::OutConnectInputByCtrlVar).
// For every (input name, output name) pair the functor returns and that
// passes 2 and 3, both the op descs and the graph nodes are rewritten so the
// output name is replaced by the input name.
//
// op:    op node to process; its OpDesc and block must be non-null (enforced).
// graph: graph that owns the nodes.
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
                                          ir::Graph* graph) const {
  PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
                 "op_desc is nullptr");
  // 3 pre-requirments need to meet if the op want to inplaced.

  // 1. infer_inplace_ is registered.
  auto* op_desc = op->Op();
  auto& infer_inplace =
      OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_;
  // Ops without a registered functor are skipped. An enforce on the same
  // condition used to follow this return and could never fail, so it is gone.
  if (!static_cast<bool>(infer_inplace)) return;

  auto* block = op_desc->Block();
  auto in_to_outs = infer_inplace(*op_desc, block);

  // Position of `op` in the view; the rewrite touches this op and later ones.
  const auto& all_ops = view_.AllOps();
  auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
  size_t idx = std::distance(all_ops.begin(), cursor);

  for (auto& pair : in_to_outs) {
    auto& in_var_name = pair.first;
    auto& out_var_name = pair.second;
    auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
    auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);

    // 2. there is no external pending op on the input node
    if (view_.PendingOpsOnVar(in_node).size() > 1) {
      VLOG(3) << string::Sprintf(
          "!!! %s input has external dependency, can not inplaced, %s => %s "
          "skiped",
          op->Name(), out_var_name, in_var_name);
      continue;
    }

    // 3. if output reuse input inplaced, the dependency group is not changed.
    // For detail, check
    // the function description in "OutConnectInputByCtrlVar"
    if (view_.OutConnectInputByCtrlVar(in_node, out_node)) {
      VLOG(3) << string::Sprintf(
          "!!! %s input output connect by ctrl var, cannot inplaced, %s => %s "
          "skiped",
          op->Name(), out_var_name, in_var_name);
      continue;
    }

    VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                               out_var_name, in_var_name);
    InplaceModifyDesc(out_var_name, in_var_name, idx);
    InplaceModifyVar(out_var_name, in_var_name, idx, graph);
  }
}
// Returns the first node in `nodes` that is a non-control variable with a
// VarDesc and whose name equals `name`. Fails a PADDLE_ENFORCE when no such
// node exists.
//
// name:  variable name to look up.
// nodes: candidate nodes; callers in this file pass op->inputs or op->outputs.
ir::Node* GraphView::GetNodeByName(const std::string& name,
                                   const std::vector<ir::Node*>& nodes) const {
  // nodes should be op->inputs/outputs
  // nodes of the same op are expected to have different names.
  std::unordered_set<std::string> nodes_in_op;
  // NOTE(review): this duplicate check looks ineffective as written.
  //  * std::all_of stops at the first element for which the predicate is
  //    false, and the predicate is always false for the first element (the
  //    set is still empty), so for a non-empty `nodes` has_dup_node is false.
  //  * for an empty `nodes` std::all_of yields true, so the enforce below
  //    fires with "nodes has same name!".
  //  * the predicate only records names of nodes that are NOT plain data
  //    variables, which is the opposite of the filter used further down.
  // Presumably std::any_of over data variables was intended; confirm that
  // duplicated names can never occur before changing it.
  bool has_dup_node =
      std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) {
        if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) {
          if (nodes_in_op.count(node->Name())) return true;
          nodes_in_op.emplace(node->Name());
        }
        return false;
      });
  PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!");
  ir::Node* node = nullptr;
  for (auto* it : nodes) {
    // Skip op nodes, control vars and vars without a desc.
    if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue;
    if (it->Name() == name) {
      node = it;
      break;
    }
  }
  PADDLE_ENFORCE(node != nullptr,
                 string::Sprintf("Not found var %s in nodes!", name));
  return node;
}
// Returns a copy of `node->outputs`, i.e. the nodes that consume `node`.
std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
  return node->outputs;
}
// Rebuilds the view for graph `g`: stores the result of
// SortOpLikeDescOrder(*g) as the op list returned by AllOps().
void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
// Returns the op list stored by Build().
// Note: the vector is returned by value, so every call copies it; callers
// that loop over it should fetch it once.
const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
// Decides whether `out_var` must NOT reuse `in_var` because the two are
// linked through a control dependency.
//
// Suppose v_a0 and v_a1 are variables that were each inplaced already, and we
// are now deciding whether v_a1 may reuse v_a0:
//
//   v_a0
//    +
//    |
//    v
//   v_a0
//    +
//    |
//    v
//   v_a1
//    +
//    |
//    v
//   v_a1
//
// The chain of same-named nodes is collected forward from `out_var` and
// backward from `in_var`. If a control dep var connects the input chain to
// the output chain, there is a control path and the result is true.
bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
  std::unordered_set<ir::Node*> out_var_set, in_var_set;

  // Every later version of the output (same name, reached via inplaced ops).
  for (ir::Node* cur = out_var; cur != nullptr;
       cur = GetNextInplacedOpOutput(cur)) {
    out_var_set.emplace(cur);
  }

  // Every earlier version of the input (same name, reached via inplaced ops).
  for (ir::Node* cur = in_var; cur != nullptr;
       cur = GetPrevInplacedOpInput(cur)) {
    in_var_set.emplace(cur);
  }

  // Is there a control dep var linking the input chain to the output chain?
  return ConnectByCtrlVar(in_var_set, out_var_set);
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
// Registers InplacePass under the name "inplace_pass".
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass);
paddle/fluid/framework/details/inplace_op_pass.h
0 → 100644
浏览文件 @
8f3b2523
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Helper used by InplacePass to query the ops and variable nodes of a graph.
class GraphView {
 public:
  GraphView() = default;

  // Recomputes the stored op list from graph `g`.
  void Build(ir::Graph* g);

  // Returns the op list stored by Build(). Returned by value (a copy).
  const std::vector<ir::Node*> AllOps();

  // Finds the variable node called `name` among `nodes`; enforces that one
  // exists.
  ir::Node* GetNodeByName(const std::string& name,
                          const std::vector<ir::Node*>& nodes) const;

  // Returns the nodes that consume `var` (a copy of its outputs).
  std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);

  // True when `in_var` and `out_var` are linked through a control dependency
  // var.
  bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);

 private:
  // Op nodes of the graph passed to the last Build() call.
  std::vector<ir::Node*> ops_;
};
// Graph pass that lets an op's output reuse the storage of one of its inputs
// when the op has registered an inplace inference functor.
//
// ApplyImpl is const, so the working state is kept in mutable members.
class InplacePass : public ir::Pass {
 public:
  InplacePass();

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  // Rebuilds var_nodes_ from the ops currently held by view_.
  void InitSSAGraphNodes() const;

 private:
  // Rewires graph nodes so that `var` is replaced by `cache_var`, for the ops
  // from position `idx` of the view onward. The parameter names match the
  // definitions: the caller passes the output variable name first and the
  // reused input variable name second.
  void InplaceModifyVar(const std::string& var, const std::string& cache_var,
                        const size_t& idx, ir::Graph* graph) const;

  // Renames `var` to `cache_var` in the op descs from position `idx` of the
  // view onward.
  void InplaceModifyDesc(const std::string& var, const std::string& cache_var,
                         const size_t& idx) const;

  // Tries to inplace every (input, output) pair reported for `op`.
  void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;

  // Variable name -> its nodes in the graph, oldest first.
  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;

  // Op names allowed to be inplaced when the whitelist flag is on.
  mutable std::unordered_set<std::string> whitelist_;

  // Op view of the graph currently being processed.
  mutable GraphView view_;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_early_delete_pass.cc
浏览文件 @
8f3b2523
...
...
@@ -16,7 +16,7 @@
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
...
...
paddle/fluid/framework/details/memory_
reuse_types
.cc
→
paddle/fluid/framework/details/memory_
optimize_helper
.cc
浏览文件 @
8f3b2523
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include <iostream>
#include <sstream>
#include <string>
...
...
@@ -83,7 +83,7 @@ struct NodeComparator {
}
};
void
OrderedNode
PairPool
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
void
OrderedNode
List
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
...
...
@@ -119,11 +119,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
mark_table_
[
var
->
Name
()]
=
it
;
}
int
OrderedNode
PairPool
::
GetIndex
(
ir
::
Node
*
var
)
{
int
OrderedNode
List
::
GetIndex
(
ir
::
Node
*
var
)
{
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
}
ir
::
Node
*
OrderedNode
PairPool
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
OrderedNode
List
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
found_node
=
nullptr
;
NodeComparator
compare_node
;
...
...
@@ -136,13 +136,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
return
found_node
;
}
void
OrderedNodePairPool
::
Erase
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
->
Name
()));
nodes_
.
erase
(
mark_table_
[
var
->
Name
()]);
mark_table_
.
erase
(
var
->
Name
());
void
OrderedNodeList
::
Erase
(
ir
::
Node
*
var
)
{
Erase
(
var
->
Name
());
}
void
OrderedNodeList
::
Erase
(
const
std
::
string
&
var
)
{
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
));
nodes_
.
erase
(
mark_table_
[
var
]);
mark_table_
.
erase
(
var
);
}
std
::
string
OrderedNode
PairPool
::
ToString
()
const
{
std
::
string
OrderedNode
List
::
ToString
()
const
{
std
::
stringstream
ss
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
ss
<<
DebugString
(
it
->
first
)
<<
" "
;
...
...
@@ -150,6 +152,38 @@ std::string OrderedNodePairPool::ToString() const {
return
ss
.
str
();
}
// Checks whether the memory behind variable node `node` may be reused.
//
// Returns false when:
//  * `node` is null, is not a variable, or is a control var;
//  * the node has no VarDesc;
//  * the var is persistable, is not a LOD_TENSOR, or has an empty shape;
//  * the name starts and ends with '@' (e.g. @EMPTY@, @LR_DECAY_REUSE_ID@);
//  * the first producer op carrying a "force_cpu" attribute has it set to
//    true.
// Returns true otherwise.
bool NodeCanReused(ir::Node* node) {
  if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false;
  auto* desc = node->Var();
  // A variable node may carry no VarDesc; such a node cannot be analysed, so
  // treat it as not reusable instead of dereferencing a null pointer.
  if (desc == nullptr) return false;
  auto type = desc->GetType();
  if (desc->Persistable() || type != proto::VarType::LOD_TENSOR ||
      desc->GetShape().empty()) {
    return false;
  }
  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
  std::string name = node->Name();
  if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
    return false;
  for (auto* op : node->inputs) {
    if (op->Op()->HasAttr("force_cpu")) {
      // op output force generated in cpu, can not be reused.
      return framework::AttrReader(op->Op()->GetAttrMap())
                 .Get<bool>("force_cpu") == 0;
    }
  }
  return true;
}
// Tells whether op `desc` owns a sub-block, i.e. whether any of its
// attributes holds a BlockDesc* or a std::vector<BlockDesc*>.
bool OpHasSubBlock(OpDesc* desc) {
  for (auto& kv : desc->GetAttrMap()) {
    const auto& held = kv.second.type();
    if (held == typeid(BlockDesc*)) return true;               // NOLINT
    if (held == typeid(std::vector<BlockDesc*>)) return true;  // NOLINT
  }
  return false;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_
reuse_types
.h
→
paddle/fluid/framework/details/memory_
optimize_helper
.h
浏览文件 @
8f3b2523
...
...
@@ -43,7 +43,7 @@ using GraphNodePool = std::vector<
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
class
OrderedNode
PairPool
{
class
OrderedNode
List
{
public:
using
NodePair
=
std
::
pair
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
;
using
Iter
=
typename
std
::
list
<
NodePair
>::
iterator
;
...
...
@@ -53,8 +53,12 @@ class OrderedNodePairPool {
void
Erase
(
ir
::
Node
*
var
);
void
Erase
(
const
std
::
string
&
var
);
bool
Has
(
ir
::
Node
*
var
)
{
return
mark_table_
.
count
(
var
->
Name
());
}
bool
Has
(
const
std
::
string
&
var
)
{
return
mark_table_
.
count
(
var
);
}
ir
::
Node
*
NodeMatch
(
ir
::
Node
*
var
)
const
;
// map store non-const iterator, can not promise const
int
GetIndex
(
ir
::
Node
*
var
);
...
...
@@ -67,6 +71,11 @@ class OrderedNodePairPool {
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
size_t
size
()
const
{
return
nodes_
.
size
();
}
void
Clear
()
{
mark_table_
.
clear
();
nodes_
.
clear
();
}
private:
// for searching.
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
...
...
@@ -74,14 +83,47 @@ class OrderedNodePairPool {
std
::
list
<
NodePair
>
nodes_
;
};
// valid a tensor can be reuse or not
bool
NodeCanReused
(
ir
::
Node
*
node
);
// check op has subblock or not
bool
OpHasSubBlock
(
OpDesc
*
desc
);
// node memory size in bytes
size_t
NodeSizeInBytes
(
ir
::
Node
*
n
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
// std::string DebugString(VarDesc* var);
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
// Generic visitor: invokes `callback` on every element of `nodes`, in
// iteration order, without any filtering. Elements must be pointers.
template <typename Container, typename Callback>
class FilterVariableImpl {
 public:
  void operator()(const Container& nodes, Callback callback) {
    for (auto it = nodes.begin(); it != nodes.end(); ++it) {
      auto* current = *it;
      callback(current);
    }
  }
};
// filter var node for op->inputs/outputs
// Specialisation for node vectors such as op->inputs / op->outputs: only
// non-control variable nodes are forwarded to `callback`.
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
 public:
  void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
    for (auto* candidate : nodes) {
      if (!candidate->IsVar() || candidate->IsCtrlVar()) continue;
      callback(candidate);
    }
  }
};
template
<
typename
Container
,
typename
Callback
>
void
FilterVariables
(
const
Container
&
nodes
,
Callback
callback
)
{
FilterVariableImpl
<
Container
,
Callback
>
()(
nodes
,
callback
);
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_
reuse_types
_test.cc
→
paddle/fluid/framework/details/memory_
optimize_helper
_test.cc
浏览文件 @
8f3b2523
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include <algorithm>
#include <iostream>
#include <memory>
...
...
@@ -27,8 +27,8 @@ namespace paddle {
namespace
framework
{
namespace
details
{
TEST
(
OrderedNode
PairPool
,
Normal
)
{
OrderedNode
PairPool
pool
;
TEST
(
OrderedNode
List
,
Normal
)
{
OrderedNode
List
pool
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
// clang-format off
...
...
paddle/fluid/framework/details/
analysis_var
_pass.cc
→
paddle/fluid/framework/details/
memory_optimize
_pass.cc
浏览文件 @
8f3b2523
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
analysis_var
_pass.h"
#include "paddle/fluid/framework/details/
memory_optimize
_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
...
...
@@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
op1
->
Outputs
()
==
op2
->
Outputs
();
}
template
<
typename
Container
,
typename
Callback
>
class
FilterVariableImpl
{
public:
void
operator
()(
const
Container
&
nodes
,
Callback
callback
)
{
for
(
auto
*
node
:
nodes
)
{
callback
(
node
);
}
}
};
// filter var node for op->inputs/outputs
template
<
typename
Callback
>
class
FilterVariableImpl
<
std
::
vector
<
ir
::
Node
*>
,
Callback
>
{
public:
void
operator
()(
const
std
::
vector
<
ir
::
Node
*>&
nodes
,
Callback
callback
)
{
for
(
auto
*
var
:
nodes
)
{
if
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
())
{
callback
(
var
);
}
}
}
};
template
<
typename
Container
,
typename
Callback
>
void
FilterVariables
(
const
Container
&
nodes
,
Callback
callback
)
{
FilterVariableImpl
<
Container
,
Callback
>
()(
nodes
,
callback
);
}
std
::
unique_ptr
<
ir
::
Graph
>
AnalysisVarPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
auto
subblock_vars
=
GetSubBlockVars
(
nodes
);
...
...
@@ -103,48 +75,53 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
}
for
(
auto
&
var
:
op
->
outputs
)
{
if
(
NodeCanReused
(
var
)
&&
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
)
{
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
<<
op
->
Name
();
VLOG
(
3
)
<<
pool_
.
ToString
();
VLOG
(
3
)
<<
"matched in pool : "
<<
((
cache
==
nullptr
)
?
"False"
:
"True"
);
}
if
(
cache
!=
nullptr
)
{
if
(
var
->
Name
()
==
cache
->
Name
())
{
VLOG
(
3
)
<<
"The same cache variable is cascade reused."
<<
var
->
Name
()
<<
" is re-filled to the pool after"
<<
"the reused op is finished. Current op can not "
<<
"replace it again. Skip this candidate."
;
continue
;
}
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
skip_set_
.
count
(
var
->
Name
()))
continue
;
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
<<
op
->
Name
();
VLOG
(
3
)
<<
pool_
.
ToString
();
VLOG
(
3
)
<<
"matched in pool : "
<<
((
cache
==
nullptr
)
?
"False"
:
"True"
);
}
int
node_idx_in_pool
=
pool_
.
GetIndex
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
// update CFG Graph on the fly.
// reused var maybe re-fill into the pool
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
// NOTE(dzhwinter): we need to both update the ProgramDesc
// and IR Graph. because op_desc/var_desc is used in CreateOp,
// CreateVar when running happens. But IR Graph
// define the dependence relationship between nodes.
RenameVarInGraphDesc
(
var
->
Name
(),
cache
->
Name
(),
idx
);
RenameVarInGraphNode
(
var
->
Name
(),
cache
->
Name
(),
idx
,
graph
.
get
());
pool_
.
Erase
(
cache
);
if
(
cache
==
nullptr
)
continue
;
if
(
var
->
Name
()
==
cache
->
Name
())
{
VLOG
(
3
)
<<
"The same cache variable is cascade reused."
<<
var
->
Name
()
<<
" is re-filled to the pool after"
<<
"the reused op is finished. Current op can not "
<<
"replace it again. Skip this candidate."
;
continue
;
int
node_idx_in_pool
=
pool_
.
GetIndex
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s, cache idx %d, pool size %d"
,
std
::
to_string
(
reuse_id
++
),
DebugString
(
var
),
DebugString
(
cache
),
node_idx_in_pool
,
static_cast
<
int
>
(
pool_
.
size
()));
// update CFG Graph on the fly.
// reused var maybe re-fill into the pool
cfg_
->
RenameVarInCFGGraph
(
var
->
Name
(),
cache
->
Name
(),
idx
);
// NOTE(dzhwinter): we need to both update the ProgramDesc
// and IR Graph. because op_desc/var_desc is used in CreateOp,
// CreateVar when running happens. But IR Graph
// define the dependence relationship between nodes.
RenameVarInGraphDesc
(
var
->
Name
(),
cache
->
Name
(),
idx
);
RenameVarInGraphNode
(
var
->
Name
(),
cache
->
Name
(),
idx
,
graph
.
get
());
pool_
.
Erase
(
cache
);
}
// fill the pool
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
if
(
cfg_
->
LiveOut
(
op
).
count
(
var
)
==
0
)
{
unlived_vars
.
emplace
(
var
);
}
}
}
// fill the pool
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
if
(
cfg_
->
LiveOut
(
op
).
count
(
var
)
==
0
)
{
for
(
auto
var
:
unlived_vars
)
{
ir
::
Node
*
var_node
=
cfg_
->
GetNodeFromVarName
(
var
,
op
);
if
(
var_node
==
nullptr
)
continue
;
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
pool_
.
Insert
(
var_node
,
op
);
}
...
...
@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
return
graph
;
}
void
AnalysisVar
Pass
::
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
{
void
MemoryOptimize
Pass
::
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
{
// conditional block, while op and their grad op
auto
*
sub_block_desc
=
AttrReader
(
op_desc
->
GetAttrMap
()).
Get
<
BlockDesc
*>
(
"sub_block"
);
...
...
@@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
}
}
std
::
unordered_set
<
std
::
string
>
AnalysisVar
Pass
::
GetSubBlockVars
(
std
::
unordered_set
<
std
::
string
>
MemoryOptimize
Pass
::
GetSubBlockVars
(
const
std
::
unordered_set
<
ir
::
Node
*>&
nodes
)
const
{
std
::
unordered_set
<
std
::
string
>
vars
;
for
(
auto
&
op
:
nodes
)
{
...
...
@@ -263,9 +240,9 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
return
vars
;
}
void
AnalysisVar
Pass
::
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
)
const
{
void
MemoryOptimize
Pass
::
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
)
const
{
for
(
size_t
i
=
idx
;
i
<
cfg_
->
Ops
().
size
();
++
i
)
{
auto
*
op
=
cfg_
->
Ops
()[
i
];
PADDLE_ENFORCE
(
op
->
IsOp
()
&&
op
->
Op
());
...
...
@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
}
}
void
AnalysisVar
Pass
::
InitSSAGraphNodes
()
const
{
void
MemoryOptimize
Pass
::
InitSSAGraphNodes
()
const
{
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ir
::
Node
*>>
all_vars
;
if
(
var_nodes_
.
empty
())
{
for
(
auto
*
op
:
cfg_
->
Ops
())
{
...
...
@@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const {
}
}
void
AnalysisVarPass
::
RenameVarInGraphNode
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
,
ir
::
Graph
*
graph
)
const
{
void
MemoryOptimizePass
::
RenameVarInGraphNode
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
size_t
idx
,
ir
::
Graph
*
graph
)
const
{
// if replace happens, we need to create a newer version cache_var
// but use the same dims/data_type with var.
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
...
...
@@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var,
var_nodes_
.
at
(
var
).
clear
();
}
bool
AnalysisVarPass
::
NodeCanReused
(
ir
::
Node
*
node
)
const
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
auto
*
desc
=
node
->
Var
();
auto
type
=
desc
->
GetType
();
if
(
desc
->
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
desc
->
GetShape
().
empty
())
{
return
false
;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std
::
string
name
=
node
->
Name
();
if
(
!
name
.
empty
()
&&
name
[
0
]
==
'@'
&&
name
[
name
.
size
()
-
1
]
==
'@'
)
return
false
;
if
(
skip_set_
.
count
(
name
))
return
false
;
for
(
auto
*
op
:
node
->
inputs
)
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
// op output force generated in cpu, can not be reused.
return
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
}
}
return
true
;
}
bool
AnalysisVarPass
::
OpHasSubBlock
(
OpDesc
*
desc
)
const
{
const
AttributeMap
&
attrs
=
desc
->
GetAttrMap
();
for
(
auto
&
attr
:
attrs
)
{
if
(
attr
.
second
.
type
()
==
typeid
(
BlockDesc
*
)
||
// NOLINT
attr
.
second
.
type
()
==
typeid
(
std
::
vector
<
BlockDesc
*>
))
// NOLINT
return
true
;
}
return
false
;
}
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
...
...
@@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name,
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
analysis_var_pass
,
paddle
::
framework
::
details
::
AnalysisVarPass
)
REGISTER_PASS
(
memory_optimize_pass
,
paddle
::
framework
::
details
::
MemoryOptimizePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
paddle/fluid/framework/details/
analysis_var
_pass.h
→
paddle/fluid/framework/details/
memory_optimize
_pass.h
浏览文件 @
8f3b2523
...
...
@@ -25,7 +25,7 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
...
...
@@ -35,12 +35,10 @@ namespace details {
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
// sort op in bfs order
std
::
vector
<
ir
::
Node
*>
BFSSortGraphOps
(
const
ir
::
Graph
&
graph
);
class
ControlFlowGraph
;
class
AnalysisVar
Pass
:
public
ir
::
Pass
{
class
MemoryOptimize
Pass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
...
...
@@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass {
ir
::
Graph
*
graph
)
const
;
void
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
;
// valid a tensor can be reuse or not
bool
NodeCanReused
(
ir
::
Node
*
node
)
const
;
// scan subblock and collect the output/input variables.
std
::
unordered_set
<
std
::
string
>
GetSubBlockVars
(
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
// check op has subblock or not
bool
OpHasSubBlock
(
OpDesc
*
desc
)
const
;
private:
// Reuse Node Pool, Owned.
mutable
OrderedNode
PairPool
pool_
;
mutable
OrderedNode
List
pool_
;
// controlflow Graph
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
// skip set
...
...
paddle/fluid/framework/details/
analysis_var
_pass_test.cc
→
paddle/fluid/framework/details/
memory_optimize
_pass_test.cc
浏览文件 @
8f3b2523
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/
analysis_var
_pass.h"
#include "paddle/fluid/framework/details/
memory_optimize
_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
...
...
paddle/fluid/framework/details/op_registry.h
浏览文件 @
8f3b2523
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <tuple>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
...
...
@@ -32,7 +33,8 @@ enum OpInfoFillType {
kOpProtoAndCheckerMaker
=
1
,
kGradOpDescMaker
=
2
,
kVarTypeInference
=
3
,
kShapeInference
=
4
kShapeInference
=
4
,
kInplaceOpInference
=
5
};
template
<
typename
T
>
...
...
@@ -48,8 +50,11 @@ struct OpInfoFillTypeID {
?
kVarTypeInference
:
(
std
::
is_base_of
<
InferShapeBase
,
T
>::
value
?
kShapeInference
:
static_cast
<
OpInfoFillType
>
(
-
1
)))));
:
(
std
::
is_base_of
<
InplaceOpInference
,
T
>::
value
?
kInplaceOpInference
:
static_cast
<
OpInfoFillType
>
(
-
1
))))));
}
};
...
...
@@ -139,6 +144,16 @@ struct OpInfoFiller<T, kShapeInference> {
}
};
// Filler selected when a REGISTER_OPERATOR argument derives from
// InplaceOpInference: stores a stateless functor in OpInfo::infer_inplace_
// that builds a fresh T on every call and forwards (op_desc, block) to it.
template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> {
  void operator()(const char* op_type, OpInfo* info) const {
    info->infer_inplace_ = [](const OpDesc& op_desc, BlockDesc* block) {
      return T()(op_desc, block);
    };
  }
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/inplace_op_inference.h
0 → 100644
浏览文件 @
8f3b2523
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
namespace
paddle
{
namespace
framework
{
/*
  Inplace inference: creates In->Out pairs for an inplaced operator.
  When a pair of corresponding names is specified, for example X->Out,
  Out will reuse X's memory in place. Legality validation of both
  variables is done by the derived base class.
*/
class InplaceOpInference {
 public:
  virtual ~InplaceOpInference() = default;

  // Returns the input -> output name pairs that may share memory for the
  // given op description within `block`.
  virtual std::unordered_map<std::string, std::string> operator()(
      const OpDesc& op_desc, BlockDesc* block) const = 0;
};
class
InplaceInToOut
:
public
InplaceOpInference
{
public:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ret
;
auto
in_out_var_names_pair
=
this
->
Apply
(
op_desc
,
block
);
for
(
auto
&
pair
:
in_out_var_names_pair
)
{
PADDLE_ENFORCE
(
!
op_desc
.
Input
(
pair
.
first
).
empty
(),
string
::
Sprintf
(
"op %s do not have input of %s!"
,
op_desc
.
Type
(),
pair
.
first
));
PADDLE_ENFORCE
(
!
op_desc
.
Output
(
pair
.
second
).
empty
(),
string
::
Sprintf
(
"op %s do not have output of %s!"
,
op_desc
.
Type
(),
pair
.
second
));
auto
&
in_name
=
op_desc
.
Input
(
pair
.
first
).
at
(
0
);
auto
&
out_name
=
op_desc
.
Output
(
pair
.
second
).
at
(
0
);
auto
in
=
block
->
FindRecursiveOrCreateVar
(
in_name
);
auto
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
if
(
TryInplaceInputOutput
(
in
,
out
))
ret
.
insert
({
in_name
,
out_name
});
}
return
ret
;
}
protected:
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
=
0
;
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
auto
var_can_reused
=
[
&
](
const
VarDesc
&
node
)
->
bool
{
auto
type
=
node
.
GetType
();
if
(
node
.
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
node
.
GetShape
().
empty
())
{
return
false
;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std
::
string
name
=
node
.
Name
();
if
(
!
name
.
empty
()
&&
name
[
0
]
==
'@'
&&
name
[
name
.
size
()
-
1
]
==
'@'
)
return
false
;
return
true
;
};
auto
var_size_in_bytes
=
[
&
](
const
VarDesc
&
node
)
->
size_t
{
auto
shape
=
node
.
GetShape
();
int
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
size_t
type_size
=
SizeOfType
(
node
.
GetDataType
());
return
type_size
*
std
::
abs
(
size
);
};
return
in
.
Name
()
!=
out
.
Name
()
&&
var_can_reused
(
in
)
&&
var_can_reused
(
out
)
&&
var_size_in_bytes
(
out
)
<=
var_size_in_bytes
(
in
);
}
};
/*
Inplace In and Out for operator only have an Input and an Output.
For example, activation op.
*/
class
SingleOpInplaceInToOut
:
public
InplaceInToOut
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
PADDLE_ENFORCE
(
!
op_desc
.
InputNames
().
empty
(),
"Op inputs must not be empty"
);
PADDLE_ENFORCE
(
!
op_desc
.
OutputNames
().
empty
(),
"Op outputs must not be empty"
);
auto
x_name
=
op_desc
.
InputNames
().
at
(
0
);
auto
out_name
=
op_desc
.
OutputNames
().
at
(
0
);
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
x_name
,
out_name
}};
}
};
/*
Gradient op. Inplace output use it's Input.
For example, Input@Grad->Input reuse strategy.
*/
class
GradOpInplaceInToOut
:
public
InplaceInToOut
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ret
;
std
::
unordered_set
<
std
::
string
>
output_names
(
op_desc
.
OutputNames
().
begin
(),
op_desc
.
OutputNames
().
end
());
for
(
auto
&
input_name
:
op_desc
.
InputNames
())
{
if
(
output_names
.
count
(
GradVarName
(
input_name
)))
{
ret
.
insert
({
input_name
,
GradVarName
(
input_name
)});
}
}
return
ret
;
}
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/inplace_op_inference_test.cc
0 → 100644
浏览文件 @
8f3b2523
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iterator>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace
paddle
{
namespace
framework
{
class
NOP
:
public
OperatorBase
{
public:
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
SingleOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
// Grad-op description maker for the test op `single_op`; produces a
// `single_op_grad` description.
class SingleGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Own the description from the first moment so it is released if any of
    // the setters below throws.
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("single_op_grad");
    op->SetInput("Out", OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    return op;
  }
};
// Shape inference for `single_op`: "Out" takes the dimensions of "X".
class SingleOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    // NOTE(review): the results of these two queries are discarded; they are
    // kept only to preserve the original call sequence.
    ctx->HasInput("X");
    ctx->HasOutput("Out");
    const auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
  }
};
// Shape inference for `single_op_grad`: the gradient of "X" takes the
// dimensions of the input named "Out".
class SingleGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    const auto x_grad = framework::GradVarName("X");
    // NOTE(review): the results of these two queries are discarded, and the
    // input checked here (Out@GRAD) differs from the one read below ("Out").
    ctx->HasInput(framework::GradVarName("Out"));
    ctx->HasOutput(x_grad);
    const auto out_dims = ctx->GetInputDim("Out");
    ctx->SetOutputDim(x_grad, out_dims);
  }
};
class
MultiOutOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddInput
(
"Y"
,
""
).
AsDuplicable
();
AddInput
(
"Z"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddOutput
(
"YOut"
,
""
);
AddOutput
(
"ZOut"
,
""
);
AddOutput
(
"NotReuseOut"
,
""
);
AddComment
(
""
);
}
};
// Shape inference for `multi_out_op`: each output shares the dimensions of
// its matching input (X->Out, Y->YOut, Z->ZOut).
class MultiOutShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    static const char* const kPairs[][2] = {
        {"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"}};
    for (const auto& in_out : kPairs) {
      ctx->ShareDim(in_out[0], in_out[1]);
    }
  }
};
// Grad-op description maker for the test op `multi_out_op`; produces a
// `multi_out_grad` description.
class MultiGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    // Own the description from the first moment so it is released if any of
    // the setters below throws.
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("multi_out_grad");

    op->SetInput("X", Input("X"));
    op->SetOutput(framework::GradVarName("Y"), OutputGrad("YOut"));
    op->SetOutput(framework::GradVarName("X"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("Z"), OutputGrad("ZOut"));
    return op;
  }
};
// Shape inference for `multi_out_grad`: each input gradient takes the
// dimensions of the matching output gradient
// (YOut@GRAD -> Y@GRAD, Out@GRAD -> X@GRAD, ZOut@GRAD -> Z@GRAD).
class MultiOutGradShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    using framework::GradVarName;

    const auto y_dims = ctx->GetInputDim(GradVarName("YOut"));
    ctx->SetOutputDim(GradVarName("Y"), y_dims);

    const auto x_dims = ctx->GetInputDim(GradVarName("Out"));
    ctx->SetOutputDim(GradVarName("X"), x_dims);

    const auto z_dims = ctx->GetInputDim(GradVarName("ZOut"));
    ctx->SetOutputDim(GradVarName("Z"), z_dims);
  }
};
// Inplace candidates for `multi_out_op`: X->Out, Y->YOut, Z->ZOut.
// "NotReuseOut" is deliberately absent from the candidate list.
class MultiOutInplaceInToOut : public framework::InplaceInToOut {
 public:
  using framework::InplaceInToOut::InplaceInToOut;

 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> candidates = {
        {"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"},
    };
    return candidates;
  }
};
// Inplace candidates for `multi_out_grad`: each output gradient may be
// reused by the matching input gradient
// (YOut@GRAD -> Y@GRAD, Out@GRAD -> X@GRAD, ZOut@GRAD -> Z@GRAD).
class MultiOutGradInplaceInToOut : public framework::InplaceInToOut {
 public:
  using framework::InplaceInToOut::InplaceInToOut;

 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    using framework::GradVarName;
    std::unordered_map<std::string, std::string> candidates = {
        {GradVarName("YOut"), GradVarName("Y")},
        {GradVarName("Out"), GradVarName("X")},
        {GradVarName("ZOut"), GradVarName("Z")},
    };
    return candidates;
  }
};
}
// namespace framework
}
// namespace paddle
// Short alias used by the registrations below.
namespace f = paddle::framework;

// Forward op with a single input/output pair; inplace candidates come from
// SingleOpInplaceInToOut.
REGISTER_OPERATOR(single_op, f::NOP, f::SingleOpMaker, f::SingleGradOpMaker,
                  f::SingleOpInplaceInToOut, f::SingleOpShapeInference);

// Gradient counterpart of single_op.
REGISTER_OPERATOR(single_op_grad, f::NOP, f::SingleOpInplaceInToOut,
                  f::SingleGradOpShapeInference);

// Forward op with several input/output pairs.
REGISTER_OPERATOR(multi_out_op, f::NOP, f::MultiOutOpMaker,
                  f::MultiGradOpMaker, f::MultiOutInplaceInToOut,
                  f::MultiOutShapeInference);

// Gradient counterpart of multi_out_op.
REGISTER_OPERATOR(multi_out_grad, f::NOP, f::MultiOutGradInplaceInToOut,
                  f::MultiOutGradShapeInference);
namespace
paddle
{
namespace
framework
{
// single_op: the first argument of "X" (test2_a, 32x64) is large enough to
// host "Out" (test2_out, 32x16), so exactly one pair is expected.
TEST(InferInplace, SingleOpInplaceInToOut) {
  ProgramDesc prog;
  auto* op = prog.MutableBlock(0)->AppendOp();
  op->SetType("single_op");
  op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
  op->SetOutput("Out", {"test2_out"});

  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64});
  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_out");
  prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16});

  auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
  auto in_to_outs = infer_inplace(*op, op->Block());
  // Fatal assertion: begin() must not be dereferenced on an empty map.
  ASSERT_EQ(in_to_outs.size(), 1ul);
  auto it = in_to_outs.begin();
  EXPECT_EQ(it->first, "test2_a");
  EXPECT_EQ(it->second, "test2_out");
}
// single_op_grad: the gradient input (test2_out, 32x16) can host the first
// gradient output (test2_a, 32x16), so exactly one pair is expected.
TEST(InferInplace, SingleGradOpInplaceInToOut) {
  ProgramDesc prog;
  auto* op = prog.MutableBlock(0)->AppendOp();
  op->SetType("single_op_grad");
  op->SetInput(GradVarName("Out"), {"test2_out"});
  op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"});

  prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16});
  prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("test2_out");
  prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16});

  auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
  auto in_to_outs = infer_inplace(*op, op->Block());
  // Fatal assertion: begin() must not be dereferenced on an empty map.
  ASSERT_EQ(in_to_outs.size(), 1ul);
  auto it = in_to_outs.begin();
  EXPECT_EQ(it->first, "test2_out");
  EXPECT_EQ(it->second, "test2_a");
}
// multi_out_op: every candidate pair has equally sized tensors, so all three
// (a0->o0, b0->y0, c0->z0) are expected to be reported.
TEST(InferInplace, MultiOutInplaceInToOut) {
  ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  auto* op = block->AppendOp();
  op->SetType("multi_out_op");
  op->SetInput("X", {"a0", "a1"});
  op->SetInput("Y", {"b0"});
  op->SetInput("Z", {"c0", "c1"});
  op->SetOutput("Out", {"o0"});
  op->SetOutput("YOut", {"y0"});
  op->SetOutput("ZOut", {"z0"});

  // Inputs get an explicit tensor type; outputs are only created.
  const std::string typed_vars[] = {"a0", "b0", "c0", "c1"};
  for (const auto& name : typed_vars) {
    block->Var(name)->SetType(proto::VarType::LOD_TENSOR);
  }
  const std::string created_vars[] = {"o0", "y0", "z0"};
  for (const auto& name : created_vars) {
    block->Var(name);
  }
  const std::string shaped_vars[] = {"a0", "b0", "c0", "o0", "y0", "z0"};
  for (const auto& name : shaped_vars) {
    block->Var(name)->SetShape({32, 16});
  }

  auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
  auto in_to_outs = infer_inplace(*op, op->Block());
  EXPECT_EQ(in_to_outs.size(), 3ul);

  std::unordered_map<std::string, std::string> expects = {
      {"a0", "o0"}, {"b0", "y0"}, {"c0", "z0"},
  };
  EXPECT_TRUE(expects == in_to_outs);
}
// multi_out_grad: every candidate pair has equally sized tensors, so all
// three (o0->a0, y0->b0, z0->c0) are expected to be reported.
TEST(InferInplace, MultiGradInplaceInToOut) {
  ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  auto* op = block->AppendOp();
  op->SetType("multi_out_grad");
  op->SetInput(GradVarName("Out"), {"o0"});
  op->SetInput(GradVarName("YOut"), {"y0"});
  op->SetInput(GradVarName("ZOut"), {"z0"});
  op->SetOutput(GradVarName("X"), {"a0", "a1"});
  op->SetOutput(GradVarName("Y"), {"b0"});
  op->SetOutput(GradVarName("Z"), {"c0", "c1"});

  // Gradient outputs get an explicit tensor type; inputs are only created.
  const std::string typed_vars[] = {"a0", "b0", "c0", "c1"};
  for (const auto& name : typed_vars) {
    block->Var(name)->SetType(proto::VarType::LOD_TENSOR);
  }
  const std::string created_vars[] = {"o0", "y0", "z0"};
  for (const auto& name : created_vars) {
    block->Var(name);
  }
  const std::string shaped_vars[] = {"a0", "b0", "c0", "o0", "y0", "z0"};
  for (const auto& name : shaped_vars) {
    block->Var(name)->SetShape({32, 16});
  }

  auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
  auto in_to_outs = infer_inplace(*op, op->Block());
  EXPECT_EQ(in_to_outs.size(), 3ul);

  std::unordered_map<std::string, std::string> expects = {
      {"o0", "a0"}, {"y0", "b0"}, {"z0", "c0"},
  };
  EXPECT_TRUE(expects == in_to_outs);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/node.h
浏览文件 @
8f3b2523
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
...
...
paddle/fluid/framework/op_info.h
浏览文件 @
8f3b2523
...
...
@@ -38,6 +38,7 @@ struct OpInfo {
OpAttrChecker
*
checker_
{
nullptr
};
InferVarTypeFN
infer_var_type_
;
InferShapeFN
infer_shape_
;
InferInplaceOpFN
infer_inplace_
;
bool
HasOpProtoAndChecker
()
const
{
return
proto_
!=
nullptr
&&
checker_
!=
nullptr
;
...
...
paddle/fluid/framework/type_defs.h
浏览文件 @
8f3b2523
...
...
@@ -57,5 +57,8 @@ using InferVarTypeFN =
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
// Map from an input name to the output name that may reuse its memory.
// Despite the name, this holds any number of pairs.
using InplacePair = std::unordered_map<std::string, std::string>;

// Callable stored per op type that computes the inplace pairs for one op
// description inside a block.
using InferInplaceOpFN =
    std::function<InplacePair(const OpDesc&, BlockDesc*)>;
}
// namespace framework
}
// namespace paddle
paddle/fluid/operators/activation_op.cc
浏览文件 @
8f3b2523
...
...
@@ -547,12 +547,14 @@ namespace ops = paddle::operators;
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker, \
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
8f3b2523
...
...
@@ -602,13 +602,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
}
};
// Inplace candidates for batch_norm: the running statistics and the
// activation may be updated in place
// (Mean->MeanOut, Variance->VarianceOut, X->Y).
class BatchNormInplaceInToOut : public framework::InplaceInToOut {
 public:
  using InplaceInToOut::InplaceInToOut;

 protected:
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    return std::unordered_map<std::string, std::string>{
        {"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"},
    };
  }
};
// Inplace candidates for batch_norm_grad:
// Y@GRAD -> X@GRAD, SavedMean -> Scale@GRAD, SavedVariance -> Bias@GRAD.
class BatchNormGradInplaceInToOut : public framework::InplaceInToOut {
 public:
  using InplaceInToOut::InplaceInToOut;

 protected:
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    using framework::GradVarName;
    // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
    return std::unordered_map<std::string, std::string>{
        {GradVarName("Y"), GradVarName("X")},
        {"SavedMean", GradVarName("Scale")},
        {"SavedVariance", GradVarName("Bias")},
    };
  }
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
batch_norm
,
ops
::
BatchNormOp
,
ops
::
BatchNormOpMaker
,
ops
::
BatchNormOpInferVarType
,
ops
::
BatchNormGradMaker
);
REGISTER_OPERATOR
(
batch_norm_grad
,
ops
::
BatchNormGradOp
);
ops
::
BatchNormOpInferVarType
,
ops
::
BatchNormGradMaker
,
ops
::
BatchNormInplaceInToOut
);
REGISTER_OPERATOR
(
batch_norm_grad
,
ops
::
BatchNormGradOp
,
ops
::
BatchNormGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cc
浏览文件 @
8f3b2523
...
...
@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER
(
elementwise_add
,
Add
);
REGISTER_ELEMWISE_EXPLICIT_OP
(
elementwise_add
,
"Add"
,
"Out = X + Y"
,
"Out"
,
"X"
);
REGISTER_OP_CPU_KERNEL
(
elementwise_add
,
ops
::
ElementwiseAddKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
8f3b2523
...
...
@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
}
};
// In-place policy shared by the element-wise operators registered through
// the explicit-op registration macro in this header.
//
// Apply() reports a single (input slot -> output slot) pair: "X" -> "Out".
class ElementwiseOpInplace : public framework::InplaceInToOut {
 public:
  // Re-export the base-class constructors unchanged.
  using framework::InplaceInToOut::InplaceInToOut;

 protected:
  // Returns the static slot pairing; neither `op_desc` nor `block` is read.
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> slot_pairs = {
        {"X", "Out"},
    };
    return slot_pairs;
  }
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker); \
op_type##GradMaker, \
::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad)
paddle/fluid/operators/flatten_op.cc
浏览文件 @
8f3b2523
...
...
@@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase {
}
};
// In-place policy for the forward flatten operators (it is the policy passed
// when registering both `flatten` and `flatten2` in this file).
//
// Apply() reports a single (input slot -> output slot) pair: "X" -> "Out".
class FlattenOpInplaceInToOut : public framework::InplaceInToOut {
 public:
  // Re-export the base-class constructors unchanged.
  using InplaceInToOut::InplaceInToOut;

 protected:
  // Returns the static slot pairing; neither `op_desc` nor `block` is read.
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    return {
        {"X", "Out"},
    };
  }
};
class
FlattenGradInplaceinToOut
:
public
framework
::
InplaceInToOut
{
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)},
};
return
inplace_in_to_out
;
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -275,10 +304,13 @@ USE_OP(reshape);
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
flatten
,
ops
::
FlattenOp
,
ops
::
FlattenOpMaker
,
ops
::
FlattenOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
,
ops
::
FlattenOpInplaceInToOut
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
,
ops
::
FlattenGradInplaceinToOut
);
REGISTER_OPERATOR
(
flatten2
,
ops
::
Flatten2Op
,
ops
::
Flatten2OpMaker
,
ops
::
Flatten2OpInferShape
,
ops
::
Flatten2GradOpMaker
);
ops
::
Flatten2OpInferShape
,
ops
::
Flatten2GradOpMaker
,
ops
::
FlattenOpInplaceInToOut
);
REGISTER_OPERATOR
(
flatten2_grad
,
ops
::
Flatten2GradOp
,
ops
::
Flatten2GradInferShape
);
ops
::
Flatten2GradInferShape
,
ops
::
FlattenGradInplaceinToOut
);
paddle/fluid/operators/reshape_op.cc
浏览文件 @
8f3b2523
...
...
@@ -327,13 +327,44 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
}
};
// In-place policy for the forward reshape operators (it is the policy passed
// when registering both `reshape` and `reshape2` in this file).
//
// Apply() reports a single (input slot -> output slot) pair: "X" -> "Out".
class ReshapeOpInplaceInToOut : public framework::InplaceInToOut {
 public:
  // Re-export the base-class constructors unchanged.
  using InplaceInToOut::InplaceInToOut;

 protected:
  // Returns the static slot pairing; neither `op_desc` nor `block` is read.
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    return {
        {"X", "Out"},
    };
  }
};
class
ReshapeGradInplaceInToOut
:
public
framework
::
InplaceInToOut
{
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)},
};
return
inplace_in_to_out
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
reshape
,
ops
::
ReshapeOp
,
ops
::
ReshapeOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
reshape_grad
,
ops
::
ReshapeGradOp
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
,
ops
::
ReshapeOpInplaceInToOut
);
REGISTER_OPERATOR
(
reshape_grad
,
ops
::
ReshapeGradOp
,
ops
::
ReshapeGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
...
...
@@ -343,8 +374,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops
::
ReshapeGradKernel
);
REGISTER_OPERATOR
(
reshape2
,
ops
::
Reshape2Op
,
ops
::
Reshape2OpMaker
,
ops
::
Reshape2GradMaker
);
REGISTER_OPERATOR
(
reshape2_grad
,
ops
::
Reshape2GradOp
);
ops
::
Reshape2GradMaker
,
ops
::
ReshapeOpInplaceInToOut
);
REGISTER_OPERATOR
(
reshape2_grad
,
ops
::
Reshape2GradOp
,
ops
::
ReshapeGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape2
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
8f3b2523
...
...
@@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
}
};
// Alias for framework::SingleOpInplaceInToOut; the `scale` operator
// registration below passes it as ops::ScaleOpInplace.
using
ScaleOpInplace
=
framework
::
SingleOpInplaceInToOut
;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
,
ops
::
ScaleGradMaker
,
ops
::
ScaleOpVarTypeInference
);
ops
::
ScaleOpVarTypeInference
,
ops
::
ScaleOpInplace
);
REGISTER_OP_CPU_KERNEL
(
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
8f3b2523
...
...
@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
// In-place policy for the softmax operator.
//
// Apply() reports a single (input slot -> output slot) pair: "X" -> "Out".
class SoftmaxInplaceInToOut : public framework::InplaceInToOut {
 public:
  // Re-export the base-class constructors unchanged.
  using framework::InplaceInToOut::InplaceInToOut;

 protected:
  // Returns the static slot pairing; neither `op_desc` nor `block` is read.
  std::unordered_map<std::string, std::string> Apply(
      const framework::OpDesc& op_desc,
      framework::BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> slot_pairs = {
        {"X", "Out"},
    };
    return slot_pairs;
  }
};
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
8f3b2523
...
...
@@ -1049,6 +1049,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_early_delete_
=
b
;
})
.
def_property
(
"enable_inplace"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_inplace_
=
b
;
})
.
def
(
"_finalize_strategy_and_create_passes"
,
[](
BuildStrategy
&
self
)
->
std
::
shared_ptr
<
ir
::
PassBuilder
>
{
return
self
.
CreatePassesFromStrategy
(
true
);
...
...
python/paddle/fluid/__init__.py
浏览文件 @
8f3b2523
...
...
@@ -158,7 +158,8 @@ def __bootstrap__():
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'sync_nccl_allreduce'
,
'limit_of_tmp_allocation'
,
'times_excess_than_required_tmp_allocation'
'times_excess_than_required_tmp_allocation'
,
'enable_inplace_whitelist'
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
8f3b2523
...
...
@@ -41,6 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase):
use_parallel_executor
=
True
,
use_reduce
=
False
,
use_ir_memory_optimize
=
False
,
enable_inplace
=
True
,
fuse_elewise_add_act_ops
=
False
,
fuse_relu_depthwise_conv
=
False
,
optimizer
=
fluid
.
optimizer
.
Adam
,
...
...
@@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy
.
fuse_elewise_add_act_ops
=
fuse_elewise_add_act_ops
build_strategy
.
fuse_relu_depthwise_conv
=
fuse_relu_depthwise_conv
build_strategy
.
memory_optimize
=
use_ir_memory_optimize
build_strategy
.
enable_inplace
=
enable_inplace
build_strategy
.
enable_sequential_execution
=
enable_sequential_execution
if
use_cuda
and
core
.
is_compiled_with_cuda
():
build_strategy
.
remove_unnecessary_lock
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录