Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
74bc55c2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
74bc55c2
编写于
2月 01, 2019
作者:
X
Xin Pan
提交者:
GitHub
2月 01, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14975 from dzhwinter/ir_inplace_pass
Ir inplace pass
上级
546eefae
9f001c65
变更
42
显示空白变更内容
内联
并排
Showing
42 changed file
with
1647 addition
and
304 deletion
+1647
-304
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+2
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+6
-4
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+31
-9
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+5
-0
paddle/fluid/framework/details/graph_test_base.h
paddle/fluid/framework/details/graph_test_base.h
+80
-0
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+431
-0
paddle/fluid/framework/details/inplace_op_pass.h
paddle/fluid/framework/details/inplace_op_pass.h
+93
-0
paddle/fluid/framework/details/memory_early_delete_pass.cc
paddle/fluid/framework/details/memory_early_delete_pass.cc
+1
-1
paddle/fluid/framework/details/memory_optimize_helper.cc
paddle/fluid/framework/details/memory_optimize_helper.cc
+59
-16
paddle/fluid/framework/details/memory_optimize_helper.h
paddle/fluid/framework/details/memory_optimize_helper.h
+50
-2
paddle/fluid/framework/details/memory_optimize_helper_test.cc
...le/fluid/framework/details/memory_optimize_helper_test.cc
+3
-3
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+57
-111
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+3
-9
paddle/fluid/framework/details/memory_optimize_pass_test.cc
paddle/fluid/framework/details/memory_optimize_pass_test.cc
+2
-55
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+18
-3
paddle/fluid/framework/inplace_op_inference.h
paddle/fluid/framework/inplace_op_inference.h
+115
-0
paddle/fluid/framework/inplace_op_inference_test.cc
paddle/fluid/framework/inplace_op_inference_test.cc
+287
-0
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+25
-6
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+5
-0
paddle/fluid/framework/ir/graph_helper_test.cc
paddle/fluid/framework/ir/graph_helper_test.cc
+11
-0
paddle/fluid/framework/op_info.h
paddle/fluid/framework/op_info.h
+1
-0
paddle/fluid/framework/type_defs.h
paddle/fluid/framework/type_defs.h
+3
-0
paddle/fluid/inference/utils/benchmark_tester.cc
paddle/fluid/inference/utils/benchmark_tester.cc
+2
-2
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+8
-6
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+37
-2
paddle/fluid/operators/elementwise/elementwise_add_op.cc
paddle/fluid/operators/elementwise/elementwise_add_op.cc
+1
-0
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+16
-1
paddle/fluid/operators/flatten_op.cc
paddle/fluid/operators/flatten_op.cc
+36
-4
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+36
-4
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+2
-1
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+15
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-1
python/paddle/fluid/compiler.py
python/paddle/fluid/compiler.py
+5
-0
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+14
-1
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+8
-0
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+3
-0
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+58
-55
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+27
-0
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
+76
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
...fluid/tests/unittests/test_parallel_executor_seresnext.py
+7
-7
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
...paddle/fluid/transpiler/memory_optimization_transpiler.py
+2
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
74bc55c2
...
@@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version)
...
@@ -128,7 +128,7 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
memory_optimize_helper
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
py_proto_compile
(
framework_py_proto SRCS framework.proto data_feed.proto
)
py_proto_compile
(
framework_py_proto SRCS framework.proto data_feed.proto
)
...
@@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
...
@@ -192,6 +192,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_test
(
var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc
)
proto_desc
)
cc_test
(
inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper
)
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_library
(
selected_rows SRCS selected_rows.cc DEPS tensor
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
cc_test
(
selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
74bc55c2
...
@@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
...
@@ -50,7 +50,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass
)
cc_library
(
memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper
)
cc_library
(
memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass
)
cc_library
(
inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
cc_library
(
memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
...
@@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
...
@@ -65,12 +67,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass
)
set
(
SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass
inplace_op_pass
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
list
(
APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass
)
endif
()
endif
()
cc_test
(
memory_
reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types
.cc DEPS framework_proto graph
)
cc_test
(
memory_
optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper
.cc DEPS framework_proto graph
)
cc_test
(
analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types
.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_test
(
memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper
.cc DEPS framework_proto graph graph_helper op_registry pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS
${
SSA_GRAPH_EXECUTOR_DEPS
}
)
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
74bc55c2
...
@@ -17,7 +17,7 @@ limitations under the License. */
...
@@ -17,7 +17,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <memory>
#include <memory>
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
...
@@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -47,6 +47,22 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass
(
"sequential_execution_pass"
);
AppendPass
(
"sequential_execution_pass"
);
}
}
// Add op fusion.
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
AppendPass
(
"fuse_relu_depthwise_conv_pass"
);
}
// NOTE(dzhwinter): A note for automatical inplace.
// 1. modify program desc passes should put
// before inplace pass.
// 2. manually configured inplace should put
// before inplace_pass
// Add automatically inplace.
if
(
strategy_
.
enable_inplace_
)
{
AppendPass
(
"inplace_pass"
);
}
// Add a graph viz pass to record a graph.
// Add a graph viz pass to record a graph.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
...
@@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -55,10 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
}
}
// Add op fusion.
if
(
strategy
.
fuse_relu_depthwise_conv_
)
{
AppendPass
(
"fuse_relu_depthwise_conv_pass"
);
}
if
(
strategy
.
fuse_elewise_add_act_ops_
)
{
if
(
strategy
.
fuse_elewise_add_act_ops_
)
{
auto
fuse_elewise_add_act_pass
=
AppendPass
(
"fuse_elewise_add_act_pass"
);
auto
fuse_elewise_add_act_pass
=
AppendPass
(
"fuse_elewise_add_act_pass"
);
// Add a graph viz pass to record a graph.
// Add a graph viz pass to record a graph.
...
@@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -88,7 +100,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
// , so fetchlist should be set persistable before call the Run interface.
if
(
strategy
.
memory_optimize_
)
{
if
(
strategy
.
memory_optimize_
)
{
auto
analysis_var_pass
=
AppendPass
(
"analysis_var
_pass"
);
auto
memory_optimize_pass
=
AppendPass
(
"memory_optimize
_pass"
);
}
}
AppendMultiDevPass
(
strategy
);
AppendMultiDevPass
(
strategy
);
...
@@ -186,8 +198,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -186,8 +198,10 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
#endif
}
else
if
(
pass
->
Type
()
==
"memory_optimize_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"analysis_var_pass"
)
{
if
(
graph
->
Has
(
kAllOpDescs
))
{
graph
->
Erase
(
kAllOpDescs
);
}
const
std
::
vector
<
OpDesc
*>
*
all_op_descs
=
const
std
::
vector
<
OpDesc
*>
*
all_op_descs
=
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
());
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
...
@@ -214,6 +228,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -214,6 +228,13 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
pass
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
kAllOpDescs
,
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
}
else
if
(
pass
->
Type
()
==
"inplace_pass"
)
{
if
(
graph
->
Has
(
kAllOpDescs
))
{
graph
->
Erase
(
kAllOpDescs
);
}
graph
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
}
else
if
(
pass
->
Type
()
==
"fuse_relu_depthwise_conv_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"fuse_relu_depthwise_conv_pass"
)
{
if
(
!
use_cuda
)
{
if
(
!
use_cuda
)
{
LOG
(
WARNING
)
<<
"fuse_relu_depthwise_conv_pass is only supported on "
LOG
(
WARNING
)
<<
"fuse_relu_depthwise_conv_pass is only supported on "
...
@@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass);
...
@@ -239,9 +260,10 @@ USE_PASS(allreduce_mode_multi_devices_pass);
USE_PASS
(
dist_multi_devices_pass
);
USE_PASS
(
dist_multi_devices_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
analysis_var
_pass
);
USE_PASS
(
memory_optimize
_pass
);
USE_PASS
(
sequential_execution_pass
);
USE_PASS
(
sequential_execution_pass
);
USE_PASS
(
all_reduce_deps_pass
);
USE_PASS
(
all_reduce_deps_pass
);
USE_PASS
(
modify_op_lock_and_record_event_pass
);
USE_PASS
(
modify_op_lock_and_record_event_pass
);
USE_PASS
(
inplace_pass
);
USE_PASS
(
lock_free_optimize_pass
);
USE_PASS
(
lock_free_optimize_pass
);
USE_PASS
(
graph_to_program_pass
);
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
74bc55c2
...
@@ -80,6 +80,11 @@ struct BuildStrategy {
...
@@ -80,6 +80,11 @@ struct BuildStrategy {
bool
memory_early_delete_
{
false
};
bool
memory_early_delete_
{
false
};
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
// memory_early_delete_ true by default
bool
enable_inplace_
{
false
};
bool
enable_sequential_execution_
{
false
};
bool
enable_sequential_execution_
{
false
};
bool
fuse_broadcast_op_
{
false
};
bool
fuse_broadcast_op_
{
false
};
...
...
paddle/fluid/framework/details/graph_test_base.h
0 → 100644
浏览文件 @
74bc55c2
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
class
DummyOp
:
public
OperatorBase
{
public:
DummyOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
SumOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
AssignOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
SplitOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
);
AddOutput
(
"Out"
,
""
).
AsDuplicable
();
AddComment
(
""
);
}
};
class
DummyVarTypeInference
:
public
VarTypeInference
{
public:
void
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
type
=
block
->
Var
(
inputs
.
front
())
->
GetType
();
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
block
->
Var
(
out_var_name
)
->
SetType
(
type
);
}
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/inplace_op_pass.cc
0 → 100644
浏览文件 @
74bc55c2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/inplace_op_pass.h"
#include <algorithm>
#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space.
// By our design, one operator only can read its input(const Variable),
// write its output(non-const Variable). If one operator is inplaced, means
// user have chance to write the space before reading happens.
// Especially when some optimize code writing style is applied.
//
//
// /* wrong case in operator */
// /*In this case, a larger allocation is allocated, input content is lost*/
// const Tensor* in = ctx.Input<Tensor>("In")
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0; // input contect is overwrited.
// NOTE(dzhwinter):
// Only for backward compacity and stable. if enable_inplace_whitelist is turn
// on.
// only the ops in whitelist will be use inplace strategy.
// if not, all the op will be inplaced if it registered with InplaceClass
DEFINE_bool
(
enable_inplace_whitelist
,
false
,
"If this option turns on, only these op in whitelist can be inplaced."
"If it turns off, all of the running op can be candidate of inplaced op."
"Such as scale, elementwise_add"
"By default, it's turned on"
);
DECLARE_string
(
memory_optimize_debug
);
// clang-format off
const
std
::
string
kInplacedOpWhiteList
[]
=
{
// NOLINT
"sigmoid"
,
"exp"
,
"relu"
,
"tanh"
,
"sqrt"
,
"ceil"
,
"floor"
,
"reciprocal"
,
"relu6"
,
"soft_relu"
,
"hard_sigmoid"
,
"batch_norm"
,
"batch_norm_grad"
,
"sum"
,
"sum_grad"
,
"scale"
,
"reshape"
,
"elementwise_add"
,
"elementwise_add_grad"
,
};
// clang-format on
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
inline
ir
::
Node
*
GetNextCascadeInplacedVar
(
ir
::
Node
*
var
)
{
// if next op is inplaced, then return the output var
// otherwise return nullptr
PADDLE_ENFORCE
(
var
&&
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
ir
::
Node
*
inplaced_var
=
nullptr
;
for
(
auto
*
next_op
:
var
->
outputs
)
{
for
(
auto
*
output
:
next_op
->
outputs
)
{
if
(
output
->
IsVar
()
&&
!
output
->
IsCtrlVar
()
&&
output
->
Name
()
==
var
->
Name
())
{
inplaced_var
=
output
;
}
}
}
return
inplaced_var
;
}
static
inline
ir
::
Node
*
GetPrevCascadeInplacedVar
(
ir
::
Node
*
var
)
{
PADDLE_ENFORCE
(
var
&&
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
if
(
var
->
inputs
.
empty
())
return
nullptr
;
auto
*
prev_op
=
var
->
inputs
.
at
(
0
);
auto
input_it
=
std
::
find_if
(
prev_op
->
inputs
.
begin
(),
prev_op
->
inputs
.
end
(),
[
&
](
ir
::
Node
*
node
)
{
if
(
node
->
IsVar
()
&&
!
node
->
IsCtrlVar
()
&&
node
->
Name
()
==
var
->
Name
())
{
return
true
;
}
else
{
return
false
;
}
});
return
input_it
==
prev_op
->
inputs
.
end
()
?
nullptr
:
*
input_it
;
}
InplacePass
::
InplacePass
()
:
Pass
()
{
if
(
FLAGS_enable_inplace_whitelist
)
{
for
(
auto
&
s
:
kInplacedOpWhiteList
)
{
whitelist_
.
emplace
(
s
);
}
}
}
void
InplacePass
::
InitSSAGraphNodes
()
const
{
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ir
::
Node
*>>
all_vars
;
for
(
auto
*
op
:
view_
.
AllOps
())
{
for
(
auto
*
node
:
op
->
inputs
)
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
continue
;
if
(
all_vars
[
node
->
Name
()].
count
(
node
)
==
0
)
{
all_vars
[
node
->
Name
()].
emplace
(
node
);
var_nodes_
[
node
->
Name
()].
emplace_back
(
node
);
}
}
for
(
auto
*
node
:
op
->
outputs
)
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
continue
;
if
(
all_vars
[
node
->
Name
()].
count
(
node
)
==
0
)
{
all_vars
[
node
->
Name
()].
emplace
(
node
);
var_nodes_
[
node
->
Name
()].
emplace_back
(
node
);
}
}
}
}
std
::
unique_ptr
<
ir
::
Graph
>
InplacePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
var_nodes_
.
clear
();
view_
.
Build
(
graph
.
get
());
InitSSAGraphNodes
();
for
(
auto
*
op
:
view_
.
AllOps
())
{
if
(
FLAGS_enable_inplace_whitelist
&&
!
whitelist_
.
count
(
op
->
Name
()))
continue
;
TryInplaceOpInputOutput
(
op
,
graph
.
get
());
}
graph
->
ResolveHazard
(
var_nodes_
);
return
graph
;
}
void
InplacePass
::
InplaceModifyDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
)
const
{
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
ir
::
Node
*
op
=
view_
.
AllOps
()[
i
];
PADDLE_ENFORCE
(
op
->
IsOp
()
&&
op
->
Op
());
auto
*
op_desc
=
op
->
Op
();
op_desc
->
RenameInput
(
var
,
cache_var
);
op_desc
->
RenameOutput
(
var
,
cache_var
);
if
(
op_desc
->
Block
()
->
HasVar
(
var
))
op_desc
->
Block
()
->
RemoveVar
(
var
);
op_desc
->
Flush
();
}
}
const
SSANodePair
InplacePass
::
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
var_nodes_
[
var
].
at
(
0
)
->
Var
()
!=
nullptr
);
std
::
unique_ptr
<
VarDesc
>
var_desc
(
new
VarDesc
(
*
var_nodes_
[
var
].
at
(
0
)
->
Var
()));
var_desc
->
SetName
(
cache_var
);
SSANodePair
swap_nodes
;
for
(
size_t
i
=
idx
;
i
<
view_
.
AllOps
().
size
();
++
i
)
{
auto
*
op
=
view_
.
AllOps
()[
i
];
// redirect the input to the latest version of cache_var
for
(
auto
*
node
:
op
->
inputs
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
// swap node to cache_node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
PADDLE_ENFORCE
(
node
->
inputs
.
size
()
==
1
&&
node
->
inputs
[
0
]
->
IsOp
());
auto
*
prev_op
=
node
->
inputs
[
0
];
std
::
replace
(
prev_op
->
outputs
.
begin
(),
prev_op
->
outputs
.
end
(),
node
,
cache_node
);
cache_node
->
inputs
.
emplace_back
(
prev_op
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
cache_node
);
}
swap_nodes
.
emplace_back
(
std
::
make_pair
(
node
,
cache_node
));
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for
(
auto
*
node
:
op
->
outputs
)
{
if
(
node
->
Name
()
==
var
)
{
ir
::
Node
*
cache_node
=
graph
->
CreateVarNode
(
var_desc
.
get
());
// swap node to cache node
cache_node
->
outputs
.
insert
(
cache_node
->
outputs
.
end
(),
node
->
outputs
.
begin
(),
node
->
outputs
.
end
());
cache_node
->
inputs
.
emplace_back
(
op
);
std
::
replace
(
op
->
outputs
.
begin
(),
op
->
outputs
.
end
(),
node
,
cache_node
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
node
,
cache_node
);
}
swap_nodes
.
emplace_back
(
std
::
make_pair
(
node
,
cache_node
));
}
}
}
return
swap_nodes
;
}
void
InplacePass
::
CommitModify
(
const
SSANodePair
&
swap_nodes
,
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
swap_nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
const
std
::
string
var
=
node
->
Name
(),
cache_var
=
cache_node
->
Name
();
var_nodes_
[
cache_var
].
emplace_back
(
cache_node
);
graph
->
RemoveNode
(
node
);
auto
&
nodes
=
var_nodes_
.
at
(
var
);
// release unused var in graph. Because python side memory optimize
// may reused the var in same name, so we only clear the var node
// after current inplaced index.
nodes
.
erase
(
std
::
remove
(
nodes
.
begin
(),
nodes
.
end
(),
node
),
nodes
.
end
());
}
}
void
InplacePass
::
WithdrawModify
(
const
SSANodePair
&
nodes
,
ir
::
Graph
*
graph
)
const
{
for
(
auto
&
pair
:
nodes
)
{
auto
*
node
=
pair
.
first
,
*
cache_node
=
pair
.
second
;
const
std
::
string
var
=
node
->
Name
(),
cache_var
=
cache_node
->
Name
();
auto
*
prev_op
=
node
->
inputs
[
0
];
std
::
replace
(
prev_op
->
outputs
.
begin
(),
prev_op
->
outputs
.
end
(),
cache_node
,
node
);
for
(
auto
*
next_op
:
node
->
outputs
)
{
std
::
replace
(
next_op
->
inputs
.
begin
(),
next_op
->
inputs
.
end
(),
cache_node
,
node
);
}
graph
->
RemoveNode
(
cache_node
);
}
}
void
InplacePass
::
TryInplaceOpInputOutput
(
ir
::
Node
*
op
,
ir
::
Graph
*
graph
)
const
{
VLOG
(
4
)
<<
"Try to inplace op "
<<
op
->
Name
();
PADDLE_ENFORCE
(
op
->
Op
()
!=
nullptr
&&
op
->
Op
()
->
Block
()
!=
nullptr
,
"op_desc is nullptr"
);
// some pre-requirments need to meet if the op want to inplaced.
auto
*
op_desc
=
op
->
Op
();
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op_desc
->
Type
()).
infer_inplace_
;
// 1. infer_inplace_ is registered.
if
(
!
static_cast
<
bool
>
(
infer_inplace
))
return
;
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
infer_inplace
),
"%s's infer_inplace has not been registered"
,
op_desc
->
Type
());
auto
*
block
=
op_desc
->
Block
();
auto
in_to_outs
=
infer_inplace
(
*
op_desc
,
block
);
auto
&
all_ops
=
view_
.
AllOps
();
auto
cursor
=
std
::
find
(
all_ops
.
begin
(),
all_ops
.
end
(),
op
);
size_t
idx
=
std
::
distance
(
all_ops
.
begin
(),
cursor
);
for
(
auto
&
pair
:
in_to_outs
)
{
auto
&
in_var_name
=
pair
.
first
;
auto
&
out_var_name
=
pair
.
second
;
auto
*
in_node
=
view_
.
GetNodeByName
(
in_var_name
,
op
->
inputs
);
auto
*
out_node
=
view_
.
GetNodeByName
(
out_var_name
,
op
->
outputs
);
// 2. there is no external pending op on the input node
if
(
view_
.
PendingOpsOnVar
(
in_node
).
size
()
>
1
)
{
VLOG
(
4
)
<<
string
::
Sprintf
(
"Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory."
,
out_var_name
,
in_var_name
,
op
->
Name
());
continue
;
}
// 3. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future.
if
(
view_
.
InSkipSet
(
out_node
->
Name
()))
{
VLOG
(
4
)
<<
string
::
Sprintf
(
"Skiped %s => %s reused previous memory block in python memory "
"optmize,"
"it inplace may generate a circle"
,
out_var_name
,
in_var_name
,
op
->
Name
());
continue
;
}
// Debug Interface. Which would be skipped by the pass.
if
(
out_node
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"Skiped var by force. FLAGS_memory_optimize_debug="
<<
out_node
->
Name
();
continue
;
}
// NOTE(dzhwinter):
// two stage commit of inplaced process. if after inplace happens generate a
// circle,
// then withdraw the changes. Otherwise, safely add the node.
auto
swap_nodes
=
TryInplaceModifyVar
(
out_var_name
,
in_var_name
,
idx
,
graph
);
if
(
!
ir
::
HasCircle
(
*
graph
))
{
VLOG
(
3
)
<<
string
::
Sprintf
(
"!!! %s, %s => %s inplaced"
,
op
->
Name
(),
out_var_name
,
in_var_name
);
InplaceModifyDesc
(
out_var_name
,
in_var_name
,
idx
);
CommitModify
(
swap_nodes
,
graph
);
}
else
{
VLOG
(
3
)
<<
string
::
Sprintf
(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s"
,
out_var_name
,
in_var_name
,
op
->
Name
());
WithdrawModify
(
swap_nodes
,
graph
);
}
}
}
ir
::
Node
*
GraphView
::
GetNodeByName
(
const
std
::
string
&
name
,
const
std
::
vector
<
ir
::
Node
*>&
nodes
)
const
{
// nodes should be op->inputs/outputs
// node in same node do have different name.
std
::
unordered_set
<
std
::
string
>
nodes_in_op
;
bool
has_dup_node
=
std
::
all_of
(
nodes
.
begin
(),
nodes
.
end
(),
[
&
nodes_in_op
](
ir
::
Node
*
node
)
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
()
||
node
->
Var
()
==
nullptr
)
{
if
(
nodes_in_op
.
count
(
node
->
Name
()))
return
true
;
nodes_in_op
.
emplace
(
node
->
Name
());
}
return
false
;
});
PADDLE_ENFORCE
(
has_dup_node
==
false
,
"nodes has same name!"
);
ir
::
Node
*
node
=
nullptr
;
for
(
auto
*
it
:
nodes
)
{
if
(
!
it
->
IsVar
()
||
it
->
IsCtrlVar
()
||
it
->
Var
()
==
nullptr
)
continue
;
if
(
it
->
Name
()
==
name
)
{
node
=
it
;
break
;
}
}
PADDLE_ENFORCE
(
node
!=
nullptr
,
string
::
Sprintf
(
"Not found var %s in nodes!"
,
name
));
return
node
;
}
std
::
vector
<
ir
::
Node
*>
GraphView
::
PendingOpsOnVar
(
ir
::
Node
*
node
)
{
// get the pending ops depends on same var node.
// because node also maybe a inplaced variable, so need to backtrack all the
// previous inplaced vars.
std
::
vector
<
ir
::
Node
*>
pending_ops
;
ir
::
Node
*
p
=
node
;
while
(
p
!=
nullptr
)
{
pending_ops
.
insert
(
pending_ops
.
end
(),
p
->
outputs
.
begin
(),
p
->
outputs
.
end
());
p
=
GetPrevCascadeInplacedVar
(
p
);
}
return
pending_ops
;
}
void
GraphView
::
Build
(
ir
::
Graph
*
g
)
{
// track the var nodes in correct order.
// Because we insert some new created node. Which may have data race between
// nodes.
// resolve data harzards depends on the var nodes in right order.
ops_
=
SortOpLikeDescOrder
(
*
g
);
// 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph.
std
::
unordered_set
<
std
::
string
>
all_vars
;
for
(
auto
&
node
:
g
->
Nodes
())
{
if
(
node
->
IsVar
())
continue
;
for
(
auto
&
out
:
node
->
outputs
)
{
if
(
out
->
IsCtrlVar
()
||
out
->
Var
()
==
nullptr
)
continue
;
if
(
all_vars
.
count
(
out
->
Name
()))
{
dup_nodes_
.
emplace
(
out
->
Name
());
}
else
{
all_vars
.
emplace
(
out
->
Name
());
}
}
}
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
for
(
auto
&
node
:
g
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
if
(
node
->
Name
()
==
"send"
)
{
for
(
auto
&
in
:
node
->
inputs
)
{
dup_nodes_
.
emplace
(
in
->
Name
());
}
}
if
(
node
->
Name
()
==
"recv"
)
{
for
(
auto
&
out
:
node
->
outputs
)
{
dup_nodes_
.
emplace
(
out
->
Name
());
}
}
}
}
const
std
::
vector
<
ir
::
Node
*>&
GraphView
::
AllOps
()
{
return
ops_
;
}
bool
GraphView
::
InSkipSet
(
const
std
::
string
&
var
)
const
{
return
dup_nodes_
.
count
(
var
);
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
inplace_pass
,
paddle
::
framework
::
details
::
InplacePass
);
paddle/fluid/framework/details/inplace_op_pass.h
0 → 100644
浏览文件 @
74bc55c2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may abtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
GraphView
{
public:
GraphView
()
=
default
;
void
Build
(
ir
::
Graph
*
g
);
const
std
::
vector
<
ir
::
Node
*>&
AllOps
();
ir
::
Node
*
GetNodeByName
(
const
std
::
string
&
name
,
const
std
::
vector
<
ir
::
Node
*>&
nodes
)
const
;
std
::
vector
<
ir
::
Node
*>
PendingOpsOnVar
(
ir
::
Node
*
var
);
// Will Deperated in the future.
// NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may
// have the same variable name. enable inplace on such node
// will generate a circle in ssa graph.
// 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool
InSkipSet
(
const
std
::
string
&
var
)
const
;
private:
std
::
vector
<
ir
::
Node
*>
ops_
;
std
::
unordered_set
<
std
::
string
>
dup_nodes_
;
// mem opt affect nodes
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list_
;
};
typedef
std
::
vector
<
std
::
pair
<
ir
::
Node
*
,
ir
::
Node
*>>
SSANodePair
;
class
InplacePass
:
public
ir
::
Pass
{
public:
InplacePass
();
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
void
InitSSAGraphNodes
()
const
;
private:
const
SSANodePair
TryInplaceModifyVar
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
size_t
&
idx
,
ir
::
Graph
*
graph
)
const
;
void
CommitModify
(
const
SSANodePair
&
,
ir
::
Graph
*
graph
)
const
;
void
WithdrawModify
(
const
SSANodePair
&
nodes
,
ir
::
Graph
*
graph
)
const
;
void
InplaceModifyDesc
(
const
std
::
string
&
in_var
,
const
std
::
string
&
out_var
,
const
size_t
&
idx
)
const
;
void
TryInplaceOpInputOutput
(
ir
::
Node
*
op
,
ir
::
Graph
*
graph
)
const
;
mutable
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>>
var_nodes_
;
mutable
std
::
unordered_set
<
std
::
string
>
whitelist_
;
mutable
GraphView
view_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/memory_early_delete_pass.cc
浏览文件 @
74bc55c2
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include <queue>
#include <queue>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
...
...
paddle/fluid/framework/details/memory_
reuse_types
.cc
→
paddle/fluid/framework/details/memory_
optimize_helper
.cc
浏览文件 @
74bc55c2
...
@@ -12,8 +12,10 @@
...
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/memory_reuse_types.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <functional>
#include <iostream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
...
@@ -21,15 +23,17 @@ namespace paddle {
...
@@ -21,15 +23,17 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
size_t
NodeSizeInBytes
(
const
VarDesc
&
node
)
{
auto
shape
=
node
.
GetShape
();
int
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
size_t
type_size
=
SizeOfType
(
node
.
GetDataType
());
return
type_size
*
std
::
abs
(
size
);
}
size_t
NodeSizeInBytes
(
ir
::
Node
*
n
)
{
size_t
NodeSizeInBytes
(
ir
::
Node
*
n
)
{
auto
*
desc
=
FindVarDescInBlock
(
n
);
auto
*
desc
=
FindVarDescInBlock
(
n
);
auto
shape
=
desc
->
GetShape
();
return
NodeSizeInBytes
(
*
desc
);
size_t
type_size
=
SizeOfType
(
desc
->
GetDataType
());
int
size
=
1
;
for
(
auto
&
s
:
shape
)
{
size
*=
s
;
}
return
type_size
*
std
::
abs
(
size
);
}
}
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
std
::
string
DebugStringImpl
(
VarDesc
*
var
)
{
...
@@ -83,7 +87,7 @@ struct NodeComparator {
...
@@ -83,7 +87,7 @@ struct NodeComparator {
}
}
};
};
void
OrderedNode
PairPool
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
void
OrderedNode
List
::
Insert
(
ir
::
Node
*
var
,
ir
::
Node
*
op
)
{
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
());
PADDLE_ENFORCE
(
op
->
IsOp
());
PADDLE_ENFORCE
(
op
->
IsOp
());
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
if
(
mark_table_
.
count
(
var
->
Name
())
!=
0
)
{
...
@@ -119,11 +123,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
...
@@ -119,11 +123,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) {
mark_table_
[
var
->
Name
()]
=
it
;
mark_table_
[
var
->
Name
()]
=
it
;
}
}
int
OrderedNode
PairPool
::
GetIndex
(
ir
::
Node
*
var
)
{
int
OrderedNode
List
::
GetIndex
(
ir
::
Node
*
var
)
{
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
return
std
::
distance
(
nodes_
.
begin
(),
mark_table_
[
var
->
Name
()]);
}
}
ir
::
Node
*
OrderedNode
PairPool
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
OrderedNode
List
::
NodeMatch
(
ir
::
Node
*
var
)
const
{
ir
::
Node
*
found_node
=
nullptr
;
ir
::
Node
*
found_node
=
nullptr
;
NodeComparator
compare_node
;
NodeComparator
compare_node
;
...
@@ -136,13 +140,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
...
@@ -136,13 +140,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const {
return
found_node
;
return
found_node
;
}
}
void
OrderedNodePairPool
::
Erase
(
ir
::
Node
*
var
)
{
void
OrderedNodeList
::
Erase
(
ir
::
Node
*
var
)
{
Erase
(
var
->
Name
());
}
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
->
Name
()));
nodes_
.
erase
(
mark_table_
[
var
->
Name
()]);
void
OrderedNodeList
::
Erase
(
const
std
::
string
&
var
)
{
mark_table_
.
erase
(
var
->
Name
());
PADDLE_ENFORCE
(
mark_table_
.
count
(
var
));
nodes_
.
erase
(
mark_table_
[
var
]);
mark_table_
.
erase
(
var
);
}
}
std
::
string
OrderedNode
PairPool
::
ToString
()
const
{
std
::
string
OrderedNode
List
::
ToString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
for
(
auto
it
=
nodes_
.
begin
();
it
!=
nodes_
.
end
();
++
it
)
{
ss
<<
DebugString
(
it
->
first
)
<<
" "
;
ss
<<
DebugString
(
it
->
first
)
<<
" "
;
...
@@ -150,6 +156,43 @@ std::string OrderedNodePairPool::ToString() const {
...
@@ -150,6 +156,43 @@ std::string OrderedNodePairPool::ToString() const {
return
ss
.
str
();
return
ss
.
str
();
}
}
bool
NodeCanReused
(
ir
::
Node
*
node
)
{
if
(
node
==
nullptr
||
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
// auto* desc = node->Var();
bool
flag
=
NodeCanReused
(
*
node
->
Var
());
for
(
auto
*
op
:
node
->
inputs
)
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
// op output force generated in cpu, can not be reused.
flag
&=
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
}
}
return
flag
;
}
bool
NodeCanReused
(
const
VarDesc
&
node
)
{
auto
type
=
node
.
GetType
();
if
(
node
.
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
node
.
GetShape
().
empty
())
{
return
false
;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std
::
string
name
=
node
.
Name
();
if
(
!
name
.
empty
()
&&
name
[
0
]
==
'@'
&&
name
[
name
.
size
()
-
1
]
==
'@'
)
return
false
;
return
true
;
}
bool
OpHasSubBlock
(
OpDesc
*
desc
)
{
const
AttributeMap
&
attrs
=
desc
->
GetAttrMap
();
for
(
auto
&
attr
:
attrs
)
{
if
(
attr
.
second
.
type
()
==
typeid
(
BlockDesc
*
)
||
// NOLINT
attr
.
second
.
type
()
==
typeid
(
std
::
vector
<
BlockDesc
*>
))
// NOLINT
return
true
;
}
return
false
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_
reuse_types
.h
→
paddle/fluid/framework/details/memory_
optimize_helper
.h
浏览文件 @
74bc55c2
...
@@ -43,7 +43,7 @@ using GraphNodePool = std::vector<
...
@@ -43,7 +43,7 @@ using GraphNodePool = std::vector<
// For example,
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
// O(1) insert, delete
class
OrderedNode
PairPool
{
class
OrderedNode
List
{
public:
public:
using
NodePair
=
std
::
pair
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
;
using
NodePair
=
std
::
pair
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
;
using
Iter
=
typename
std
::
list
<
NodePair
>::
iterator
;
using
Iter
=
typename
std
::
list
<
NodePair
>::
iterator
;
...
@@ -53,8 +53,12 @@ class OrderedNodePairPool {
...
@@ -53,8 +53,12 @@ class OrderedNodePairPool {
void
Erase
(
ir
::
Node
*
var
);
void
Erase
(
ir
::
Node
*
var
);
void
Erase
(
const
std
::
string
&
var
);
bool
Has
(
ir
::
Node
*
var
)
{
return
mark_table_
.
count
(
var
->
Name
());
}
bool
Has
(
ir
::
Node
*
var
)
{
return
mark_table_
.
count
(
var
->
Name
());
}
bool
Has
(
const
std
::
string
&
var
)
{
return
mark_table_
.
count
(
var
);
}
ir
::
Node
*
NodeMatch
(
ir
::
Node
*
var
)
const
;
ir
::
Node
*
NodeMatch
(
ir
::
Node
*
var
)
const
;
// map store non-const iterator, can not promise const
// map store non-const iterator, can not promise const
int
GetIndex
(
ir
::
Node
*
var
);
int
GetIndex
(
ir
::
Node
*
var
);
...
@@ -67,6 +71,11 @@ class OrderedNodePairPool {
...
@@ -67,6 +71,11 @@ class OrderedNodePairPool {
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
ConstIter
end
()
const
{
return
nodes_
.
end
();
}
size_t
size
()
const
{
return
nodes_
.
size
();
}
size_t
size
()
const
{
return
nodes_
.
size
();
}
void
Clear
()
{
mark_table_
.
clear
();
nodes_
.
clear
();
}
private:
private:
// for searching.
// for searching.
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
std
::
unordered_map
<
std
::
string
,
Iter
>
mark_table_
;
...
@@ -74,14 +83,53 @@ class OrderedNodePairPool {
...
@@ -74,14 +83,53 @@ class OrderedNodePairPool {
std
::
list
<
NodePair
>
nodes_
;
std
::
list
<
NodePair
>
nodes_
;
};
};
// valid a tensor can be reuse or not
bool
NodeCanReused
(
ir
::
Node
*
node
);
// valid a tensor can be reuse or not.
bool
NodeCanReused
(
const
VarDesc
&
node
);
// check op has subblock or not
bool
OpHasSubBlock
(
OpDesc
*
desc
);
// node memory size in bytes
// node memory size in bytes
size_t
NodeSizeInBytes
(
ir
::
Node
*
n
);
size_t
NodeSizeInBytes
(
ir
::
Node
*
n
);
// node memory size in bytes
size_t
NodeSizeInBytes
(
const
VarDesc
&
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
std
::
string
DebugString
(
ir
::
Node
*
var
);
// std::string DebugString(VarDesc* var);
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
VarDesc
*
FindVarDescInBlock
(
ir
::
Node
*
n
);
template
<
typename
Container
,
typename
Callback
>
class
FilterVariableImpl
{
public:
void
operator
()(
const
Container
&
nodes
,
Callback
callback
)
{
for
(
auto
*
node
:
nodes
)
{
callback
(
node
);
}
}
};
// filter var node for op->inputs/outputs
template
<
typename
Callback
>
class
FilterVariableImpl
<
std
::
vector
<
ir
::
Node
*>
,
Callback
>
{
public:
void
operator
()(
const
std
::
vector
<
ir
::
Node
*>&
nodes
,
Callback
callback
)
{
for
(
auto
*
var
:
nodes
)
{
if
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
())
{
callback
(
var
);
}
}
}
};
template
<
typename
Container
,
typename
Callback
>
void
FilterVariables
(
const
Container
&
nodes
,
Callback
callback
)
{
FilterVariableImpl
<
Container
,
Callback
>
()(
nodes
,
callback
);
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/memory_
reuse_types
_test.cc
→
paddle/fluid/framework/details/memory_
optimize_helper
_test.cc
浏览文件 @
74bc55c2
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
#include <memory>
#include <memory>
...
@@ -27,8 +27,8 @@ namespace paddle {
...
@@ -27,8 +27,8 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
TEST
(
OrderedNode
PairPool
,
Normal
)
{
TEST
(
OrderedNode
List
,
Normal
)
{
OrderedNode
PairPool
pool
;
OrderedNode
List
pool
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
// clang-format off
// clang-format off
...
...
paddle/fluid/framework/details/
analysis_var
_pass.cc
→
paddle/fluid/framework/details/
memory_optimize
_pass.cc
浏览文件 @
74bc55c2
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/
analysis_var
_pass.h"
#include "paddle/fluid/framework/details/
memory_optimize
_pass.h"
#include <algorithm>
#include <algorithm>
#include <atomic>
#include <atomic>
#include <deque>
#include <deque>
...
@@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
...
@@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
op1
->
Outputs
()
==
op2
->
Outputs
();
op1
->
Outputs
()
==
op2
->
Outputs
();
}
}
template
<
typename
Container
,
typename
Callback
>
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
class
FilterVariableImpl
{
public:
void
operator
()(
const
Container
&
nodes
,
Callback
callback
)
{
for
(
auto
*
node
:
nodes
)
{
callback
(
node
);
}
}
};
// filter var node for op->inputs/outputs
template
<
typename
Callback
>
class
FilterVariableImpl
<
std
::
vector
<
ir
::
Node
*>
,
Callback
>
{
public:
void
operator
()(
const
std
::
vector
<
ir
::
Node
*>&
nodes
,
Callback
callback
)
{
for
(
auto
*
var
:
nodes
)
{
if
(
var
->
IsVar
()
&&
!
var
->
IsCtrlVar
())
{
callback
(
var
);
}
}
}
};
template
<
typename
Container
,
typename
Callback
>
void
FilterVariables
(
const
Container
&
nodes
,
Callback
callback
)
{
FilterVariableImpl
<
Container
,
Callback
>
()(
nodes
,
callback
);
}
std
::
unique_ptr
<
ir
::
Graph
>
AnalysisVarPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
auto
nodes
=
graph
->
Nodes
();
auto
subblock_vars
=
GetSubBlockVars
(
nodes
);
auto
subblock_vars
=
GetSubBlockVars
(
nodes
);
...
@@ -103,8 +75,11 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
...
@@ -103,8 +75,11 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
}
}
for
(
auto
&
var
:
op
->
outputs
)
{
for
(
auto
&
var
:
op
->
outputs
)
{
if
(
NodeCanReused
(
var
)
&&
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
)
{
if
(
!
NodeCanReused
(
var
)
||
cfg_
->
Use
(
op
).
count
(
var
->
Name
())
==
0
||
skip_set_
.
count
(
var
->
Name
()))
continue
;
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
ir
::
Node
*
cache
=
pool_
.
NodeMatch
(
var
);
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
if
(
var
->
Name
()
==
FLAGS_memory_optimize_debug
)
{
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
VLOG
(
3
)
<<
"start match var "
<<
DebugString
(
var
)
<<
" of op "
<<
op
->
Name
();
<<
op
->
Name
();
...
@@ -112,14 +87,14 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
...
@@ -112,14 +87,14 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
VLOG
(
3
)
<<
"matched in pool : "
VLOG
(
3
)
<<
"matched in pool : "
<<
((
cache
==
nullptr
)
?
"False"
:
"True"
);
<<
((
cache
==
nullptr
)
?
"False"
:
"True"
);
}
}
if
(
cache
!=
nullptr
)
{
if
(
cache
==
nullptr
)
continue
;
if
(
var
->
Name
()
==
cache
->
Name
())
{
if
(
var
->
Name
()
==
cache
->
Name
())
{
VLOG
(
3
)
<<
"The same cache variable is cascade reused."
VLOG
(
3
)
<<
"The same cache variable is cascade reused."
<<
var
->
Name
()
<<
var
->
Name
()
<<
" is re-filled to the pool after"
<<
" is re-filled to the pool after"
<<
"the reused op is finished. Current op can not "
<<
"the reused op is finished. Current op can not "
<<
"replace it again. Skip this candidate."
;
<<
"replace it again. Skip this candidate."
;
continue
;
continue
;
}
int
node_idx_in_pool
=
pool_
.
GetIndex
(
cache
);
int
node_idx_in_pool
=
pool_
.
GetIndex
(
cache
);
VLOG
(
3
)
<<
string
::
Sprintf
(
VLOG
(
3
)
<<
string
::
Sprintf
(
...
@@ -138,13 +113,15 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
...
@@ -138,13 +113,15 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
pool_
.
Erase
(
cache
);
pool_
.
Erase
(
cache
);
}
}
}
}
// fill the pool
// fill the pool
std
::
unordered_set
<
std
::
string
>
unlived_vars
;
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
for
(
auto
var
:
cfg_
->
LiveIn
(
op
))
{
if
(
cfg_
->
LiveOut
(
op
).
count
(
var
)
==
0
)
{
if
(
cfg_
->
LiveOut
(
op
).
count
(
var
)
==
0
)
{
unlived_vars
.
emplace
(
var
);
}
}
for
(
auto
var
:
unlived_vars
)
{
ir
::
Node
*
var_node
=
cfg_
->
GetNodeFromVarName
(
var
,
op
);
ir
::
Node
*
var_node
=
cfg_
->
GetNodeFromVarName
(
var
,
op
);
if
(
var_node
==
nullptr
)
continue
;
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
if
(
NodeCanReused
(
var_node
)
&&
!
pool_
.
Has
(
var_node
))
{
pool_
.
Insert
(
var_node
,
op
);
pool_
.
Insert
(
var_node
,
op
);
}
}
...
@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
...
@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
return
graph
;
return
graph
;
}
}
void
AnalysisVar
Pass
::
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
{
void
MemoryOptimize
Pass
::
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
{
// conditional block, while op and their grad op
// conditional block, while op and their grad op
auto
*
sub_block_desc
=
auto
*
sub_block_desc
=
AttrReader
(
op_desc
->
GetAttrMap
()).
Get
<
BlockDesc
*>
(
"sub_block"
);
AttrReader
(
op_desc
->
GetAttrMap
()).
Get
<
BlockDesc
*>
(
"sub_block"
);
...
@@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
...
@@ -247,7 +224,7 @@ void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
}
}
}
}
std
::
unordered_set
<
std
::
string
>
AnalysisVar
Pass
::
GetSubBlockVars
(
std
::
unordered_set
<
std
::
string
>
MemoryOptimize
Pass
::
GetSubBlockVars
(
const
std
::
unordered_set
<
ir
::
Node
*>&
nodes
)
const
{
const
std
::
unordered_set
<
ir
::
Node
*>&
nodes
)
const
{
std
::
unordered_set
<
std
::
string
>
vars
;
std
::
unordered_set
<
std
::
string
>
vars
;
for
(
auto
&
op
:
nodes
)
{
for
(
auto
&
op
:
nodes
)
{
...
@@ -263,7 +240,7 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
...
@@ -263,7 +240,7 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
return
vars
;
return
vars
;
}
}
void
AnalysisVar
Pass
::
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
void
MemoryOptimize
Pass
::
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
std
::
string
&
cache_var
,
size_t
idx
)
const
{
size_t
idx
)
const
{
for
(
size_t
i
=
idx
;
i
<
cfg_
->
Ops
().
size
();
++
i
)
{
for
(
size_t
i
=
idx
;
i
<
cfg_
->
Ops
().
size
();
++
i
)
{
...
@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
...
@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
}
}
}
}
void
AnalysisVar
Pass
::
InitSSAGraphNodes
()
const
{
void
MemoryOptimize
Pass
::
InitSSAGraphNodes
()
const
{
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ir
::
Node
*>>
all_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ir
::
Node
*>>
all_vars
;
if
(
var_nodes_
.
empty
())
{
if
(
var_nodes_
.
empty
())
{
for
(
auto
*
op
:
cfg_
->
Ops
())
{
for
(
auto
*
op
:
cfg_
->
Ops
())
{
...
@@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const {
...
@@ -297,9 +274,10 @@ void AnalysisVarPass::InitSSAGraphNodes() const {
}
}
}
}
void
AnalysisVar
Pass
::
RenameVarInGraphNode
(
const
std
::
string
&
var
,
void
MemoryOptimize
Pass
::
RenameVarInGraphNode
(
const
std
::
string
&
var
,
const
std
::
string
&
cache_var
,
const
std
::
string
&
cache_var
,
size_t
idx
,
ir
::
Graph
*
graph
)
const
{
size_t
idx
,
ir
::
Graph
*
graph
)
const
{
// if replace happens, we need to create a newer version cache_var
// if replace happens, we need to create a newer version cache_var
// but use the same dims/data_type with var.
// but use the same dims/data_type with var.
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
PADDLE_ENFORCE
(
var_nodes_
[
var
].
size
()
>=
1
&&
...
@@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var,
...
@@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var,
var_nodes_
.
at
(
var
).
clear
();
var_nodes_
.
at
(
var
).
clear
();
}
}
bool
AnalysisVarPass
::
NodeCanReused
(
ir
::
Node
*
node
)
const
{
if
(
!
node
->
IsVar
()
||
node
->
IsCtrlVar
())
return
false
;
auto
*
desc
=
node
->
Var
();
auto
type
=
desc
->
GetType
();
if
(
desc
->
Persistable
()
||
type
!=
proto
::
VarType
::
LOD_TENSOR
||
desc
->
GetShape
().
empty
())
{
return
false
;
}
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
std
::
string
name
=
node
->
Name
();
if
(
!
name
.
empty
()
&&
name
[
0
]
==
'@'
&&
name
[
name
.
size
()
-
1
]
==
'@'
)
return
false
;
if
(
skip_set_
.
count
(
name
))
return
false
;
for
(
auto
*
op
:
node
->
inputs
)
{
if
(
op
->
Op
()
->
HasAttr
(
"force_cpu"
))
{
// op output force generated in cpu, can not be reused.
return
framework
::
AttrReader
(
op
->
Op
()
->
GetAttrMap
())
.
Get
<
bool
>
(
"force_cpu"
)
==
0
;
}
}
return
true
;
}
bool
AnalysisVarPass
::
OpHasSubBlock
(
OpDesc
*
desc
)
const
{
const
AttributeMap
&
attrs
=
desc
->
GetAttrMap
();
for
(
auto
&
attr
:
attrs
)
{
if
(
attr
.
second
.
type
()
==
typeid
(
BlockDesc
*
)
||
// NOLINT
attr
.
second
.
type
()
==
typeid
(
std
::
vector
<
BlockDesc
*>
))
// NOLINT
return
true
;
}
return
false
;
}
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
)
{
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
PADDLE_ENFORCE
(
graph
.
Has
(
kAllOpDescs
),
"Graph has no attribute of kAllOpDescs."
);
"Graph has no attribute of kAllOpDescs."
);
...
@@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name,
...
@@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name,
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
analysis_var_pass
,
paddle
::
framework
::
details
::
AnalysisVarPass
)
REGISTER_PASS
(
memory_optimize_pass
,
paddle
::
framework
::
details
::
MemoryOptimizePass
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphNodePool
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
paddle/fluid/framework/details/
analysis_var
_pass.h
→
paddle/fluid/framework/details/
memory_optimize
_pass.h
浏览文件 @
74bc55c2
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_
reuse_types
.h"
#include "paddle/fluid/framework/details/memory_
optimize_helper
.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -35,12 +35,10 @@ namespace details {
...
@@ -35,12 +35,10 @@ namespace details {
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
std
::
vector
<
ir
::
Node
*>
SortOpLikeDescOrder
(
const
ir
::
Graph
&
graph
);
// sort op in bfs order
std
::
vector
<
ir
::
Node
*>
BFSSortGraphOps
(
const
ir
::
Graph
&
graph
);
class
ControlFlowGraph
;
class
ControlFlowGraph
;
class
AnalysisVar
Pass
:
public
ir
::
Pass
{
class
MemoryOptimize
Pass
:
public
ir
::
Pass
{
protected:
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
...
@@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass {
...
@@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass {
ir
::
Graph
*
graph
)
const
;
ir
::
Graph
*
graph
)
const
;
void
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
;
void
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
;
// valid a tensor can be reuse or not
bool
NodeCanReused
(
ir
::
Node
*
node
)
const
;
// scan subblock and collect the output/input variables.
// scan subblock and collect the output/input variables.
std
::
unordered_set
<
std
::
string
>
GetSubBlockVars
(
std
::
unordered_set
<
std
::
string
>
GetSubBlockVars
(
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
// check op has subblock or not
bool
OpHasSubBlock
(
OpDesc
*
desc
)
const
;
private:
private:
// Reuse Node Pool, Owned.
// Reuse Node Pool, Owned.
mutable
OrderedNode
PairPool
pool_
;
mutable
OrderedNode
List
pool_
;
// controlflow Graph
// controlflow Graph
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
mutable
std
::
unique_ptr
<
ControlFlowGraph
>
cfg_
;
// skip set
// skip set
...
...
paddle/fluid/framework/details/
analysis_var
_pass_test.cc
→
paddle/fluid/framework/details/
memory_optimize
_pass_test.cc
浏览文件 @
74bc55c2
...
@@ -12,63 +12,19 @@
...
@@ -12,63 +12,19 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/
analysis_var
_pass.h"
#include "paddle/fluid/framework/details/
memory_optimize
_pass.h"
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
#include <iterator>
#include <iterator>
#include "glog/logging.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
class
DummyOp
:
public
OperatorBase
{
public:
DummyOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
SumOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
AssignOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
DummyVarTypeInference
:
public
VarTypeInference
{
public:
void
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
type
=
block
->
Var
(
inputs
.
front
())
->
GetType
();
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
block
->
Var
(
out_var_name
)
->
SetType
(
type
);
}
};
}
// namespace framework
}
// namespace paddle
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
REGISTER_OPERATOR
(
sum
,
paddle
::
framework
::
DummyOp
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
SumOpMaker
,
paddle
::
framework
::
DummyVarTypeInference
);
paddle
::
framework
::
DummyVarTypeInference
);
...
@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
...
@@ -141,15 +97,6 @@ inline static ProgramDesc FillProgramDesc() {
return
prog
;
return
prog
;
}
}
template
<
typename
Container
>
inline
static
std
::
string
DebugString
(
const
Container
&
c
)
{
std
::
stringstream
ss
;
for
(
auto
&
item
:
c
)
{
ss
<<
item
<<
" "
;
}
return
ss
.
str
();
}
TEST
(
CFGGraph
,
IRGraph
)
{
TEST
(
CFGGraph
,
IRGraph
)
{
// prepare ir graph
// prepare ir graph
auto
prog
=
FillProgramDesc
();
auto
prog
=
FillProgramDesc
();
...
...
paddle/fluid/framework/details/op_registry.h
浏览文件 @
74bc55c2
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <tuple>
#include <tuple>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
...
@@ -32,7 +33,8 @@ enum OpInfoFillType {
...
@@ -32,7 +33,8 @@ enum OpInfoFillType {
kOpProtoAndCheckerMaker
=
1
,
kOpProtoAndCheckerMaker
=
1
,
kGradOpDescMaker
=
2
,
kGradOpDescMaker
=
2
,
kVarTypeInference
=
3
,
kVarTypeInference
=
3
,
kShapeInference
=
4
kShapeInference
=
4
,
kInplaceOpInference
=
5
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -48,8 +50,11 @@ struct OpInfoFillTypeID {
...
@@ -48,8 +50,11 @@ struct OpInfoFillTypeID {
?
kVarTypeInference
?
kVarTypeInference
:
(
std
::
is_base_of
<
InferShapeBase
,
T
>::
value
:
(
std
::
is_base_of
<
InferShapeBase
,
T
>::
value
?
kShapeInference
?
kShapeInference
:
(
std
::
is_base_of
<
InplaceOpInference
,
T
>::
value
?
kInplaceOpInference
:
static_cast
<
OpInfoFillType
>
(
:
static_cast
<
OpInfoFillType
>
(
-
1
)))));
-
1
)
)))));
}
}
};
};
...
@@ -139,6 +144,16 @@ struct OpInfoFiller<T, kShapeInference> {
...
@@ -139,6 +144,16 @@ struct OpInfoFiller<T, kShapeInference> {
}
}
};
};
template
<
typename
T
>
struct
OpInfoFiller
<
T
,
kInplaceOpInference
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
infer_inplace_
=
[](
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
{
T
infer
;
return
infer
(
op_desc
,
block
);
};
}
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/inplace_op_inference.h
0 → 100644
浏览文件 @
74bc55c2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
namespace
paddle
{
namespace
framework
{
/*
Inplace Inference for create In->Out pairs for inplaced operator.
If we specify a pair of corresponding names. For example, X->Out.
then Out will inplaced use X's memory. The base class will do
legality validation for both variables.
*/
class
InplaceOpInference
{
public:
virtual
~
InplaceOpInference
()
{}
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
=
0
;
};
class
InplaceInToOut
:
public
InplaceOpInference
{
public:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ret
;
auto
in_out_var_names_pair
=
this
->
Apply
(
op_desc
,
block
);
for
(
auto
&
pair
:
in_out_var_names_pair
)
{
PADDLE_ENFORCE
(
!
op_desc
.
Input
(
pair
.
first
).
empty
(),
string
::
Sprintf
(
"op %s do not have input of %s!"
,
op_desc
.
Type
(),
pair
.
first
));
PADDLE_ENFORCE
(
!
op_desc
.
Output
(
pair
.
second
).
empty
(),
string
::
Sprintf
(
"op %s do not have output of %s!"
,
op_desc
.
Type
(),
pair
.
second
));
auto
&
in_name
=
op_desc
.
Input
(
pair
.
first
).
at
(
0
);
auto
&
out_name
=
op_desc
.
Output
(
pair
.
second
).
at
(
0
);
auto
in
=
block
->
FindRecursiveOrCreateVar
(
in_name
);
auto
out
=
block
->
FindRecursiveOrCreateVar
(
out_name
);
if
(
TryInplaceInputOutput
(
in
,
out
))
ret
.
insert
({
in_name
,
out_name
});
}
return
ret
;
}
protected:
virtual
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
=
0
;
bool
TryInplaceInputOutput
(
const
VarDesc
&
in
,
const
VarDesc
&
out
)
const
{
return
in
.
Name
()
!=
out
.
Name
()
&&
details
::
NodeCanReused
(
in
)
&&
details
::
NodeCanReused
(
out
)
&&
details
::
NodeSizeInBytes
(
out
)
<=
details
::
NodeSizeInBytes
(
in
);
}
};
/*
Inplace In and Out for operator only have an Input and an Output.
For example, activation op.
*/
class
SingleOpInplaceInToOut
:
public
InplaceInToOut
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
PADDLE_ENFORCE
(
!
op_desc
.
InputNames
().
empty
(),
"Op inputs must not be empty"
);
PADDLE_ENFORCE
(
!
op_desc
.
OutputNames
().
empty
(),
"Op outputs must not be empty"
);
auto
x_name
=
op_desc
.
InputNames
().
at
(
0
);
auto
out_name
=
op_desc
.
OutputNames
().
at
(
0
);
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
x_name
,
out_name
}};
}
};
/*
Gradient op. Inplace output use it's Input.
For example, Input@Grad->Input reuse strategy.
*/
class
GradOpInplaceInToOut
:
public
InplaceInToOut
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ret
;
std
::
unordered_set
<
std
::
string
>
output_names
(
op_desc
.
OutputNames
().
begin
(),
op_desc
.
OutputNames
().
end
());
for
(
auto
&
input_name
:
op_desc
.
InputNames
())
{
if
(
output_names
.
count
(
GradVarName
(
input_name
)))
{
ret
.
insert
({
input_name
,
GradVarName
(
input_name
)});
}
}
return
ret
;
}
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/inplace_op_inference_test.cc
0 → 100644
浏览文件 @
74bc55c2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iterator>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace
paddle
{
namespace
framework
{
class
NOP
:
public
OperatorBase
{
public:
NOP
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{}
};
class
SingleOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
SingleGradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"single_op_grad"
);
op
->
SetInput
(
"Out"
,
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
return
std
::
unique_ptr
<
OpDesc
>
(
op
);
}
};
class
SingleOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
HasInput
(
"X"
);
ctx
->
HasOutput
(
"Out"
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
}
};
class
SingleGradOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
));
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"Out"
));
}
};
class
MultiOutOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
""
).
AsDuplicable
();
AddInput
(
"Y"
,
""
).
AsDuplicable
();
AddInput
(
"Z"
,
""
).
AsDuplicable
();
AddOutput
(
"Out"
,
""
);
AddOutput
(
"YOut"
,
""
);
AddOutput
(
"ZOut"
,
""
);
AddOutput
(
"NotReuseOut"
,
""
);
AddComment
(
""
);
}
};
class
MultiOutShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
ShareDim
(
"X"
,
"Out"
);
ctx
->
ShareDim
(
"Y"
,
"YOut"
);
ctx
->
ShareDim
(
"Z"
,
"ZOut"
);
}
};
class
MultiGradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"multi_out_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
OutputGrad
(
"YOut"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Z"
),
OutputGrad
(
"ZOut"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
class
MultiOutGradShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Y"
),
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"YOut"
)));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
)));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Z"
),
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"ZOut"
)));
}
};
class
MultiOutInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
framework
::
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
{
"X"
,
"Out"
},
{
"Y"
,
"YOut"
},
{
"Z"
,
"ZOut"
},
};
}
};
class
MultiOutGradInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
framework
::
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
{
framework
::
GradVarName
(
"YOut"
),
framework
::
GradVarName
(
"Y"
)},
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)},
{
framework
::
GradVarName
(
"ZOut"
),
framework
::
GradVarName
(
"Z"
)},
};
}
};
}
// namespace framework
}
// namespace paddle
namespace
f
=
paddle
::
framework
;
REGISTER_OPERATOR
(
single_op
,
f
::
NOP
,
f
::
SingleOpMaker
,
f
::
SingleGradOpMaker
,
f
::
SingleOpInplaceInToOut
,
f
::
SingleOpShapeInference
);
REGISTER_OPERATOR
(
single_op_grad
,
f
::
NOP
,
f
::
SingleOpInplaceInToOut
,
f
::
SingleGradOpShapeInference
);
REGISTER_OPERATOR
(
multi_out_op
,
f
::
NOP
,
f
::
MultiOutOpMaker
,
f
::
MultiGradOpMaker
,
f
::
MultiOutInplaceInToOut
,
f
::
MultiOutShapeInference
);
REGISTER_OPERATOR
(
multi_out_grad
,
f
::
NOP
,
f
::
MultiOutGradInplaceInToOut
,
f
::
MultiOutGradShapeInference
);
namespace
paddle
{
namespace
framework
{
TEST
(
InferInplace
,
SingleOpInplaceInToOut
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"single_op"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
op
->
SetOutput
(
"Out"
,
{
"test2_out"
});
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a"
)
->
SetShape
({
32
,
64
});
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
SetShape
({
32
,
16
});
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op
->
Type
()).
infer_inplace_
;
auto
in_to_outs
=
infer_inplace
(
*
op
,
op
->
Block
());
EXPECT_EQ
(
in_to_outs
.
size
(),
1ul
);
auto
it
=
in_to_outs
.
begin
();
EXPECT_EQ
(
it
->
first
,
"test2_a"
);
EXPECT_EQ
(
it
->
second
,
"test2_out"
);
}
TEST
(
InferInplace
,
SingleGradOpInplaceInToOut
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"single_op_grad"
);
op
->
SetInput
(
GradVarName
(
"Out"
),
{
"test2_out"
});
op
->
SetOutput
(
GradVarName
(
"X"
),
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_a"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test2_out"
)
->
SetShape
({
32
,
16
});
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op
->
Type
()).
infer_inplace_
;
auto
in_to_outs
=
infer_inplace
(
*
op
,
op
->
Block
());
EXPECT_EQ
(
in_to_outs
.
size
(),
1ul
);
auto
it
=
in_to_outs
.
begin
();
EXPECT_EQ
(
it
->
first
,
"test2_out"
);
EXPECT_EQ
(
it
->
second
,
"test2_a"
);
}
TEST
(
InferInplace
,
MultiOutInplaceInToOut
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"multi_out_op"
);
op
->
SetInput
(
"X"
,
{
"a0"
,
"a1"
});
op
->
SetInput
(
"Y"
,
{
"b0"
});
op
->
SetInput
(
"Z"
,
{
"c0"
,
"c1"
});
op
->
SetOutput
(
"Out"
,
{
"o0"
});
op
->
SetOutput
(
"YOut"
,
{
"y0"
});
op
->
SetOutput
(
"ZOut"
,
{
"z0"
});
prog
.
MutableBlock
(
0
)
->
Var
(
"a0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"o0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"y0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"a0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"b0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"c0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"o0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"y0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
16
});
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op
->
Type
()).
infer_inplace_
;
auto
in_to_outs
=
infer_inplace
(
*
op
,
op
->
Block
());
EXPECT_EQ
(
in_to_outs
.
size
(),
3ul
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
expects
=
{
{
"a0"
,
"o0"
},
{
"b0"
,
"y0"
},
{
"c0"
,
"z0"
},
};
EXPECT_TRUE
(
expects
==
in_to_outs
);
}
TEST
(
InferInplace
,
MultiGradInplaceInToOut
)
{
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"multi_out_grad"
);
op
->
SetInput
(
GradVarName
(
"Out"
),
{
"o0"
});
op
->
SetInput
(
GradVarName
(
"YOut"
),
{
"y0"
});
op
->
SetInput
(
GradVarName
(
"ZOut"
),
{
"z0"
});
op
->
SetOutput
(
GradVarName
(
"X"
),
{
"a0"
,
"a1"
});
op
->
SetOutput
(
GradVarName
(
"Y"
),
{
"b0"
});
op
->
SetOutput
(
GradVarName
(
"Z"
),
{
"c0"
,
"c1"
});
prog
.
MutableBlock
(
0
)
->
Var
(
"a0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"b0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c0"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"c1"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
0
)
->
Var
(
"o0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"y0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
);
prog
.
MutableBlock
(
0
)
->
Var
(
"a0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"b0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"c0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"o0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"y0"
)
->
SetShape
({
32
,
16
});
prog
.
MutableBlock
(
0
)
->
Var
(
"z0"
)
->
SetShape
({
32
,
16
});
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op
->
Type
()).
infer_inplace_
;
auto
in_to_outs
=
infer_inplace
(
*
op
,
op
->
Block
());
EXPECT_EQ
(
in_to_outs
.
size
(),
3ul
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
expects
=
{
{
"o0"
,
"a0"
},
{
"y0"
,
"b0"
},
{
"z0"
,
"c0"
},
};
EXPECT_TRUE
(
expects
==
in_to_outs
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
74bc55c2
...
@@ -52,16 +52,29 @@ bool HasCircleHelper(
...
@@ -52,16 +52,29 @@ bool HasCircleHelper(
ir
::
Node
*
node
,
ir
::
Node
*
node
,
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
,
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
,
std
::
unordered_set
<
ir
::
Node
*>
*
visited
,
std
::
unordered_set
<
ir
::
Node
*>
*
visited
,
std
::
unordered_set
<
ir
::
Node
*>
*
in_trace
)
{
std
::
unordered_set
<
ir
::
Node
*>
*
in_trace
,
std
::
vector
<
std
::
vector
<
ir
::
Node
*>>
*
circles
)
{
if
(
visited
->
find
(
node
)
==
visited
->
end
())
{
if
(
visited
->
find
(
node
)
==
visited
->
end
())
{
visited
->
insert
(
node
);
visited
->
insert
(
node
);
in_trace
->
insert
(
node
);
in_trace
->
insert
(
node
);
for
(
ir
::
Node
*
in
:
adj_list
.
at
(
node
))
{
for
(
ir
::
Node
*
in
:
adj_list
.
at
(
node
))
{
if
(
visited
->
find
(
in
)
==
visited
->
end
()
&&
if
(
visited
->
find
(
in
)
==
visited
->
end
()
&&
HasCircleHelper
(
in
,
adj_list
,
visited
,
in_trace
))
{
HasCircleHelper
(
in
,
adj_list
,
visited
,
in_trace
,
circles
))
{
return
true
;
return
true
;
}
else
if
(
in_trace
->
find
(
in
)
!=
in_trace
->
end
())
{
}
else
if
(
in_trace
->
find
(
in
)
!=
in_trace
->
end
())
{
if
(
circles
!=
nullptr
)
{
std
::
vector
<
ir
::
Node
*>
circle
;
circle
.
emplace_back
(
in
);
ir
::
Node
*
p
=
in
;
for
(
auto
&
adj
:
adj_list
.
at
(
p
))
{
if
(
in_trace
->
count
(
adj
))
{
circle
.
emplace_back
(
adj
);
p
=
adj
;
}
}
circles
->
emplace_back
(
circle
);
}
return
true
;
return
true
;
}
}
}
}
...
@@ -71,11 +84,12 @@ bool HasCircleHelper(
...
@@ -71,11 +84,12 @@ bool HasCircleHelper(
}
}
bool
HasCircleInternal
(
bool
HasCircleInternal
(
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
)
{
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
,
std
::
vector
<
std
::
vector
<
ir
::
Node
*>>
*
circles
)
{
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
in_trace
;
std
::
unordered_set
<
ir
::
Node
*>
in_trace
;
for
(
auto
&
adj
:
adj_list
)
{
for
(
auto
&
adj
:
adj_list
)
{
if
(
HasCircleHelper
(
adj
.
first
,
adj_list
,
&
visited
,
&
in_trace
))
{
if
(
HasCircleHelper
(
adj
.
first
,
adj_list
,
&
visited
,
&
in_trace
,
circles
))
{
return
true
;
return
true
;
}
}
}
}
...
@@ -84,13 +98,18 @@ bool HasCircleInternal(
...
@@ -84,13 +98,18 @@ bool HasCircleInternal(
}
// namespace
}
// namespace
bool
HasCircle
(
const
Graph
&
graph
)
{
bool
HasCircle
(
const
Graph
&
graph
)
{
return
HasCircleInternal
(
BuildOperationAdjList
(
graph
));
return
HasCircleInternal
(
BuildOperationAdjList
(
graph
),
nullptr
);
}
bool
FindCircleSubGraph
(
const
Graph
&
graph
,
std
::
vector
<
std
::
vector
<
ir
::
Node
*>>
*
circles
)
{
return
HasCircleInternal
(
BuildOperationAdjList
(
graph
),
circles
);
}
}
std
::
vector
<
ir
::
Node
*>
TopologySortOperations
(
const
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
TopologySortOperations
(
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
BuildOperationAdjList
(
graph
);
BuildOperationAdjList
(
graph
);
PADDLE_ENFORCE
(
!
HasCircleInternal
(
adj_list
));
PADDLE_ENFORCE
(
!
HasCircleInternal
(
adj_list
,
nullptr
));
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
vector
<
ir
::
Node
*>
ret
;
std
::
vector
<
ir
::
Node
*>
ret
;
for
(
auto
adj
:
adj_list
)
{
for
(
auto
adj
:
adj_list
)
{
...
...
paddle/fluid/framework/ir/graph_helper.h
浏览文件 @
74bc55c2
...
@@ -28,6 +28,11 @@ namespace ir {
...
@@ -28,6 +28,11 @@ namespace ir {
// Test if the graph contains circle.
// Test if the graph contains circle.
bool
HasCircle
(
const
Graph
&
graph
);
bool
HasCircle
(
const
Graph
&
graph
);
// Find All Circles for debugging,
// store all subgraph in circles.
bool
FindCircleSubGraph
(
const
Graph
&
graph
,
std
::
vector
<
std
::
vector
<
ir
::
Node
*>>
*
circles
);
size_t
GraphNum
(
const
Graph
&
graph
);
size_t
GraphNum
(
const
Graph
&
graph
);
// Topology Sort the operations in the graph from inputs to outputs.
// Topology Sort the operations in the graph from inputs to outputs.
...
...
paddle/fluid/framework/ir/graph_helper_test.cc
浏览文件 @
74bc55c2
...
@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
...
@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
// v4->outputs.push_back(o5);
// v4->outputs.push_back(o5);
}
}
TEST
(
GraphHelperTest
,
Circles
)
{
ProgramDesc
prog
;
Graph
g
(
prog
);
BuildCircleGraph
(
&
g
);
std
::
vector
<
std
::
vector
<
ir
::
Node
*>>
circles
;
ASSERT_TRUE
(
FindCircleSubGraph
(
g
,
&
circles
));
ASSERT_EQ
(
circles
.
size
(),
1UL
);
}
TEST
(
GraphHelperTest
,
GraphNum
)
{
TEST
(
GraphHelperTest
,
GraphNum
)
{
ProgramDesc
prog
;
ProgramDesc
prog
;
...
...
paddle/fluid/framework/op_info.h
浏览文件 @
74bc55c2
...
@@ -38,6 +38,7 @@ struct OpInfo {
...
@@ -38,6 +38,7 @@ struct OpInfo {
OpAttrChecker
*
checker_
{
nullptr
};
OpAttrChecker
*
checker_
{
nullptr
};
InferVarTypeFN
infer_var_type_
;
InferVarTypeFN
infer_var_type_
;
InferShapeFN
infer_shape_
;
InferShapeFN
infer_shape_
;
InferInplaceOpFN
infer_inplace_
;
bool
HasOpProtoAndChecker
()
const
{
bool
HasOpProtoAndChecker
()
const
{
return
proto_
!=
nullptr
&&
checker_
!=
nullptr
;
return
proto_
!=
nullptr
&&
checker_
!=
nullptr
;
...
...
paddle/fluid/framework/type_defs.h
浏览文件 @
74bc55c2
...
@@ -57,5 +57,8 @@ using InferVarTypeFN =
...
@@ -57,5 +57,8 @@ using InferVarTypeFN =
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
using
InplacePair
=
std
::
unordered_map
<
std
::
string
,
std
::
string
>
;
using
InferInplaceOpFN
=
std
::
function
<
InplacePair
(
const
OpDesc
&
,
BlockDesc
*
)
>
;
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/utils/benchmark_tester.cc
浏览文件 @
74bc55c2
...
@@ -34,6 +34,6 @@ TEST(Benchmark, PersistToFile) {
...
@@ -34,6 +34,6 @@ TEST(Benchmark, PersistToFile) {
benchmark
.
SetLatency
(
220
);
benchmark
.
SetLatency
(
220
);
benchmark
.
PersistToFile
(
"1.log"
);
benchmark
.
PersistToFile
(
"1.log"
);
benchmark
.
PersistToFile
(
"
1
.log"
);
benchmark
.
PersistToFile
(
"
2
.log"
);
benchmark
.
PersistToFile
(
"
1
.log"
);
benchmark
.
PersistToFile
(
"
3
.log"
);
}
}
paddle/fluid/operators/activation_op.cc
浏览文件 @
74bc55c2
...
@@ -551,8 +551,10 @@ namespace ops = paddle::operators;
...
@@ -551,8 +551,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker); \
::paddle::operators::OP_NAME##GradMaker, \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
74bc55c2
...
@@ -604,13 +604,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -604,13 +604,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
}
}
};
};
class
BatchNormInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
"Mean"
,
"MeanOut"
},
{
"Variance"
,
"VarianceOut"
},
{
"X"
,
"Y"
},
};
return
inplace_in_to_out
;
}
};
class
BatchNormGradInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
{
framework
::
GradVarName
(
"Y"
),
framework
::
GradVarName
(
"X"
)},
{
"SavedMean"
,
framework
::
GradVarName
(
"Scale"
)},
{
"SavedVariance"
,
framework
::
GradVarName
(
"Bias"
)},
};
return
inplace_in_to_out
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
batch_norm
,
ops
::
BatchNormOp
,
ops
::
BatchNormOpMaker
,
REGISTER_OPERATOR
(
batch_norm
,
ops
::
BatchNormOp
,
ops
::
BatchNormOpMaker
,
ops
::
BatchNormOpInferVarType
,
ops
::
BatchNormGradMaker
);
ops
::
BatchNormOpInferVarType
,
ops
::
BatchNormGradMaker
,
REGISTER_OPERATOR
(
batch_norm_grad
,
ops
::
BatchNormGradOp
);
ops
::
BatchNormInplaceInToOut
);
REGISTER_OPERATOR
(
batch_norm_grad
,
ops
::
BatchNormGradOp
,
ops
::
BatchNormGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
batch_norm
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
batch_norm
,
ops
::
BatchNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cc
浏览文件 @
74bc55c2
...
@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
...
@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER
(
elementwise_add
,
Add
);
REGISTER_ELEMWISE_GRAD_MAKER
(
elementwise_add
,
Add
);
REGISTER_ELEMWISE_EXPLICIT_OP
(
elementwise_add
,
"Add"
,
"Out = X + Y"
,
"Out"
,
REGISTER_ELEMWISE_EXPLICIT_OP
(
elementwise_add
,
"Add"
,
"Out = X + Y"
,
"Out"
,
"X"
);
"X"
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
elementwise_add
,
elementwise_add
,
ops
::
ElementwiseAddKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseAddKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
74bc55c2
...
@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
...
@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
}
}
};
};
class
ElementwiseOpInplace
:
public
framework
::
InplaceInToOut
{
public:
using
framework
::
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
{
"X"
,
"Out"
},
};
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
...
@@ -299,6 +313,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker); \
op_type##GradMaker, \
::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \
REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad)
::paddle::operators::ElementwiseOpExplicitGrad)
paddle/fluid/operators/flatten_op.cc
浏览文件 @
74bc55c2
...
@@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase {
...
@@ -267,6 +267,35 @@ class Flatten2GradOp : public framework::OperatorBase {
}
}
};
};
class
FlattenOpInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
"X"
,
"Out"
},
};
return
inplace_in_to_out
;
}
};
class
FlattenGradInplaceinToOut
:
public
framework
::
InplaceInToOut
{
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)},
};
return
inplace_in_to_out
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -275,10 +304,13 @@ USE_OP(reshape);
...
@@ -275,10 +304,13 @@ USE_OP(reshape);
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
flatten
,
ops
::
FlattenOp
,
ops
::
FlattenOpMaker
,
REGISTER_OPERATOR
(
flatten
,
ops
::
FlattenOp
,
ops
::
FlattenOpMaker
,
ops
::
FlattenOpInferShape
,
ops
::
FlattenOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
,
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
);
ops
::
FlattenOpInplaceInToOut
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
,
ops
::
FlattenGradInplaceinToOut
);
REGISTER_OPERATOR
(
flatten2
,
ops
::
Flatten2Op
,
ops
::
Flatten2OpMaker
,
REGISTER_OPERATOR
(
flatten2
,
ops
::
Flatten2Op
,
ops
::
Flatten2OpMaker
,
ops
::
Flatten2OpInferShape
,
ops
::
Flatten2GradOpMaker
);
ops
::
Flatten2OpInferShape
,
ops
::
Flatten2GradOpMaker
,
ops
::
FlattenOpInplaceInToOut
);
REGISTER_OPERATOR
(
flatten2_grad
,
ops
::
Flatten2GradOp
,
REGISTER_OPERATOR
(
flatten2_grad
,
ops
::
Flatten2GradOp
,
ops
::
Flatten2GradInferShape
);
ops
::
Flatten2GradInferShape
,
ops
::
FlattenGradInplaceinToOut
);
paddle/fluid/operators/reshape_op.cc
浏览文件 @
74bc55c2
...
@@ -327,14 +327,45 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
...
@@ -327,14 +327,45 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
}
}
};
};
class
ReshapeOpInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
"X"
,
"Out"
},
};
return
inplace_in_to_out
;
}
};
class
ReshapeGradInplaceInToOut
:
public
framework
::
InplaceInToOut
{
using
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
inplace_in_to_out
=
{
{
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
)},
};
return
inplace_in_to_out
;
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OPERATOR
(
reshape
,
ops
::
ReshapeOp
,
ops
::
ReshapeOpMaker
,
REGISTER_OPERATOR
(
reshape
,
ops
::
ReshapeOp
,
ops
::
ReshapeOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
,
REGISTER_OPERATOR
(
reshape_grad
,
ops
::
ReshapeGradOp
);
ops
::
ReshapeOpInplaceInToOut
);
REGISTER_OPERATOR
(
reshape_grad
,
ops
::
ReshapeGradOp
,
ops
::
ReshapeGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
int64_t
,
ops
::
ReshapeKernel
);
...
@@ -344,8 +375,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
...
@@ -344,8 +375,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops
::
ReshapeGradKernel
);
ops
::
ReshapeGradKernel
);
REGISTER_OPERATOR
(
reshape2
,
ops
::
Reshape2Op
,
ops
::
Reshape2OpMaker
,
REGISTER_OPERATOR
(
reshape2
,
ops
::
Reshape2Op
,
ops
::
Reshape2OpMaker
,
ops
::
Reshape2GradMaker
);
ops
::
Reshape2GradMaker
,
ops
::
ReshapeOpInplaceInToOut
);
REGISTER_OPERATOR
(
reshape2_grad
,
ops
::
Reshape2GradOp
);
REGISTER_OPERATOR
(
reshape2_grad
,
ops
::
Reshape2GradOp
,
ops
::
ReshapeGradInplaceInToOut
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape2
,
float
,
ops
::
ReshapeKernel
,
double
,
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape2
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
int64_t
,
ops
::
ReshapeKernel
);
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
74bc55c2
...
@@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
}
}
};
};
using
ScaleOpInplace
=
framework
::
SingleOpInplaceInToOut
;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
,
ops
::
ScaleGradMaker
,
REGISTER_OPERATOR
(
scale
,
ops
::
ScaleOp
,
ops
::
ScaleOpMaker
,
ops
::
ScaleGradMaker
,
ops
::
ScaleOpVarTypeInference
);
ops
::
ScaleOpVarTypeInference
,
ops
::
ScaleOpInplace
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
scale
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
ScaleKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
74bc55c2
...
@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
}
};
};
class
SoftmaxInplaceInToOut
:
public
framework
::
InplaceInToOut
{
public:
using
framework
::
InplaceInToOut
::
InplaceInToOut
;
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Apply
(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{
{
"X"
,
"Out"
},
};
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
74bc55c2
...
@@ -1096,6 +1096,10 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -1096,6 +1096,10 @@ All parameter, weight, gradient are variables in Paddle.
"memory_early_delete"
,
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_early_delete_
=
b
;
})
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
memory_early_delete_
=
b
;
})
.
def_property
(
"enable_inplace"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_inplace_
=
b
;
})
.
def
(
"_finalize_strategy_and_create_passes"
,
.
def
(
"_finalize_strategy_and_create_passes"
,
[](
BuildStrategy
&
self
)
->
std
::
shared_ptr
<
ir
::
PassBuilder
>
{
[](
BuildStrategy
&
self
)
->
std
::
shared_ptr
<
ir
::
PassBuilder
>
{
return
self
.
CreatePassesFromStrategy
(
true
);
return
self
.
CreatePassesFromStrategy
(
true
);
...
...
python/paddle/fluid/__init__.py
浏览文件 @
74bc55c2
...
@@ -158,7 +158,8 @@ def __bootstrap__():
...
@@ -158,7 +158,8 @@ def __bootstrap__():
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'sync_nccl_allreduce'
,
'limit_of_tmp_allocation'
,
'sync_nccl_allreduce'
,
'limit_of_tmp_allocation'
,
'times_excess_than_required_tmp_allocation'
'times_excess_than_required_tmp_allocation'
,
'enable_inplace_whitelist'
]
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
...
...
python/paddle/fluid/compiler.py
浏览文件 @
74bc55c2
...
@@ -174,6 +174,11 @@ class CompiledProgram(object):
...
@@ -174,6 +174,11 @@ class CompiledProgram(object):
self
.
_exec_strategy
.
num_threads
=
cpu_num
*
2
self
.
_exec_strategy
.
num_threads
=
cpu_num
*
2
trainers_endpoints
=
self
.
_program
.
_trainers_endpoints
trainers_endpoints
=
self
.
_program
.
_trainers_endpoints
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
self
.
_build_strategy
.
enable_inplace
=
False
if
self
.
_program
.
_is_mem_optimized
else
True
if
self
.
_build_strategy
.
num_trainers
>
1
and
trainers_endpoints
:
if
self
.
_build_strategy
.
num_trainers
>
1
and
trainers_endpoints
:
assert
self
.
_build_strategy
.
num_trainers
==
len
(
assert
self
.
_build_strategy
.
num_trainers
==
len
(
trainers_endpoints
),
"num_trainers == len(end_points)"
trainers_endpoints
),
"num_trainers == len(end_points)"
...
...
python/paddle/fluid/framework.py
浏览文件 @
74bc55c2
...
@@ -1725,6 +1725,19 @@ class Program(object):
...
@@ -1725,6 +1725,19 @@ class Program(object):
self
.
_trainers_endpoints
=
[]
self
.
_trainers_endpoints
=
[]
# the distributed lookup table names
# the distributed lookup table names
self
.
_distributed_lookup_table
=
None
self
.
_distributed_lookup_table
=
None
# @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler
self
.
__is_mem_optimized
=
False
@
property
def
_is_mem_optimized
(
self
):
# if the program is optimized, operator input/outputs
# maybe same, which conflict with save_inference_model.
return
self
.
__is_mem_optimized
@
_is_mem_optimized
.
setter
def
_is_mem_optimized
(
self
,
target
):
self
.
__is_mem_optimized
=
target
@
property
@
property
def
op_role
(
self
):
def
op_role
(
self
):
...
@@ -1744,7 +1757,7 @@ class Program(object):
...
@@ -1744,7 +1757,7 @@ class Program(object):
return
self
.
_current_role
return
self
.
_current_role
@
op_role
.
setter
@
op_role
.
setter
def
set_
op_role
(
self
,
role
):
def
op_role
(
self
,
role
):
self
.
_current_role
=
role
self
.
_current_role
=
role
@
property
@
property
...
...
python/paddle/fluid/io.py
浏览文件 @
74bc55c2
...
@@ -16,6 +16,7 @@ from __future__ import print_function
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
os
import
os
import
errno
import
errno
import
warnings
import
time
import
time
import
shutil
import
shutil
import
six
import
six
...
@@ -931,6 +932,13 @@ def save_inference_model(dirname,
...
@@ -931,6 +932,13 @@ def save_inference_model(dirname,
if
main_program
is
None
:
if
main_program
is
None
:
main_program
=
default_main_program
()
main_program
=
default_main_program
()
if
main_program
.
_is_mem_optimized
:
warnings
.
warn
(
"save_inference_model must put before you call memory_optimize.
\
the memory_optimize will modify the original program,
\
is not suitable for saving inference model
\
we save the original program as inference model."
,
RuntimeWarning
)
# fix the bug that the activation op's output as target will be pruned.
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# will affect the inference performance.
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
74bc55c2
...
@@ -146,6 +146,9 @@ class ParallelExecutor(object):
...
@@ -146,6 +146,9 @@ class ParallelExecutor(object):
# step4: get main_program, scope, local_scopes
# step4: get main_program, scope, local_scopes
main
=
main_program
if
main_program
\
main
=
main_program
if
main_program
\
else
framework
.
default_main_program
()
else
framework
.
default_main_program
()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass.
build_strategy
.
enable_inplace
=
False
if
main
.
_is_mem_optimized
else
True
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
scope
=
scope
if
scope
is
not
None
else
executor
.
global_scope
()
if
share_vars_from
and
not
isinstance
(
share_vars_from
,
if
share_vars_from
and
not
isinstance
(
share_vars_from
,
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
74bc55c2
...
@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
seed
=
None
,
seed
=
None
,
use_parallel_executor
=
True
,
use_parallel_executor
=
True
,
use_reduce
=
False
,
use_reduce
=
False
,
use_ir_memory_optimize
=
False
,
use_ir_memory_optimize
=
True
,
enable_inplace
=
True
,
fuse_elewise_add_act_ops
=
False
,
fuse_elewise_add_act_ops
=
False
,
fuse_relu_depthwise_conv
=
False
,
fuse_relu_depthwise_conv
=
False
,
optimizer
=
fluid
.
optimizer
.
Adam
,
optimizer
=
fluid
.
optimizer
.
Adam
,
...
@@ -60,7 +61,6 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -60,7 +61,6 @@ class TestParallelExecutorBase(unittest.TestCase):
main
.
random_seed
=
seed
main
.
random_seed
=
seed
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
if
optimizer
:
if
optimizer
:
optimizer
().
minimize
(
loss
)
optimizer
().
minimize
(
loss
)
...
@@ -80,7 +80,11 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -80,7 +80,11 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy
.
fuse_elewise_add_act_ops
=
fuse_elewise_add_act_ops
build_strategy
.
fuse_elewise_add_act_ops
=
fuse_elewise_add_act_ops
build_strategy
.
fuse_relu_depthwise_conv
=
fuse_relu_depthwise_conv
build_strategy
.
fuse_relu_depthwise_conv
=
fuse_relu_depthwise_conv
build_strategy
.
memory_optimize
=
use_ir_memory_optimize
build_strategy
.
memory_optimize
=
use_ir_memory_optimize
# python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way.
build_strategy
.
enable_inplace
=
False
if
memory_opt
else
enable_inplace
build_strategy
.
enable_sequential_execution
=
enable_sequential_execution
build_strategy
.
enable_sequential_execution
=
enable_sequential_execution
if
use_cuda
and
core
.
is_compiled_with_cuda
():
if
use_cuda
and
core
.
is_compiled_with_cuda
():
build_strategy
.
remove_unnecessary_lock
=
True
build_strategy
.
remove_unnecessary_lock
=
True
if
use_parallel_executor
:
if
use_parallel_executor
:
...
@@ -100,8 +104,7 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -100,8 +104,7 @@ class TestParallelExecutorBase(unittest.TestCase):
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
for
i
in
range
(
iter
):
for
i
in
range
(
iter
):
run_executor
(
run_executor
(
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[])
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[])
last_loss
,
=
run_executor
(
last_loss
,
=
run_executor
(
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
...
...
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
74bc55c2
...
@@ -25,6 +25,7 @@ import paddle.fluid.layers as layers
...
@@ -25,6 +25,7 @@ import paddle.fluid.layers as layers
import
paddle.fluid.optimizer
as
optimizer
import
paddle.fluid.optimizer
as
optimizer
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.io
import
save_inference_model
,
load_inference_model
from
paddle.fluid.io
import
save_inference_model
,
load_inference_model
from
paddle.fluid.transpiler
import
memory_optimize
class
TestBook
(
unittest
.
TestCase
):
class
TestBook
(
unittest
.
TestCase
):
...
@@ -87,5 +88,31 @@ class TestBook(unittest.TestCase):
...
@@ -87,5 +88,31 @@ class TestBook(unittest.TestCase):
self
.
assertEqual
(
expected
,
actual
)
self
.
assertEqual
(
expected
,
actual
)
class
TestSaveInferenceModel
(
unittest
.
TestCase
):
def
test_save_inference_model
(
self
):
MODEL_DIR
=
"./tmp/inference_model2"
init_program
=
Program
()
program
=
Program
()
# fake program without feed/fetch
with
program_guard
(
program
,
init_program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
2
],
dtype
=
'float32'
)
y
=
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
layers
.
mean
(
cost
)
place
=
core
.
CPUPlace
()
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
init_program
,
feed
=
{},
fetch_list
=
[])
memory_optimize
(
program
,
print_log
=
True
)
self
.
assertEqual
(
program
.
_is_mem_optimized
,
True
)
# will print warning message
save_inference_model
(
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
program
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
0 → 100644
浏览文件 @
74bc55c2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
parallel_executor_test_base
import
TestParallelExecutorBase
def
fc_with_batchnorm
(
use_feed
):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
hidden
=
img
for
_
in
range
(
3
):
hidden
=
fluid
.
layers
.
fc
(
hidden
,
size
=
200
,
act
=
'tanh'
,
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
)))
hidden
=
fluid
.
layers
.
batch_norm
(
input
=
hidden
)
prediction
=
fluid
.
layers
.
fc
(
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
return
loss
class
TestIrInplace
(
TestParallelExecutorBase
):
@
classmethod
def
setUpClass
(
cls
):
os
.
environ
[
'CPU_NUM'
]
=
str
(
4
)
def
_fc_with_batchnorm
(
self
,
ir_memory_optimize
,
enable_inplace
,
memory_opt
=
False
):
if
not
core
.
is_compiled_with_cuda
():
return
np
.
random
.
seed
(
5
)
img
=
np
.
random
.
random
(
size
=
[
32
,
784
]).
astype
(
np
.
float32
)
label
=
np
.
ones
(
shape
=
[
32
,
1
],
dtype
=
'int64'
)
self
.
check_network_convergence
(
fc_with_batchnorm
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
use_cuda
=
True
,
memory_opt
=
memory_opt
,
use_ir_memory_optimize
=
ir_memory_optimize
,
enable_inplace
=
enable_inplace
)
def
test_fc_with_batchnorm
(
self
,
delta
=
1e-3
):
loss00
=
self
.
_fc_with_batchnorm
(
False
,
False
)
loss10
=
self
.
_fc_with_batchnorm
(
True
,
False
)
loss01
=
self
.
_fc_with_batchnorm
(
False
,
True
)
loss11
=
self
.
_fc_with_batchnorm
(
True
,
True
)
self
.
assertAlmostEqual
(
loss00
,
loss10
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss01
,
delta
=
delta
)
self
.
assertAlmostEqual
(
loss00
,
loss11
,
delta
=
delta
)
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
浏览文件 @
74bc55c2
...
@@ -200,7 +200,7 @@ class TestResnet(TestParallelExecutorBase):
...
@@ -200,7 +200,7 @@ class TestResnet(TestParallelExecutorBase):
model
,
model
,
use_cuda
,
use_cuda
,
iter
=
20
,
iter
=
20
,
delta2
=
1e-
6
):
delta2
=
1e-
5
):
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
return
return
...
@@ -228,7 +228,7 @@ class TestResnet(TestParallelExecutorBase):
...
@@ -228,7 +228,7 @@ class TestResnet(TestParallelExecutorBase):
optimizer
=
optimizer
)
optimizer
=
optimizer
)
for
loss
in
zip
(
all_reduce_first_loss
,
reduce_first_loss
):
for
loss
in
zip
(
all_reduce_first_loss
,
reduce_first_loss
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
6
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
5
)
for
loss
in
zip
(
all_reduce_last_loss
,
reduce_last_loss
):
for
loss
in
zip
(
all_reduce_last_loss
,
reduce_last_loss
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
...
@@ -258,17 +258,17 @@ class TestResnet(TestParallelExecutorBase):
...
@@ -258,17 +258,17 @@ class TestResnet(TestParallelExecutorBase):
enable_sequential_execution
=
True
)
enable_sequential_execution
=
True
)
for
loss
in
zip
(
all_reduce_first_loss
,
all_reduce_first_loss_seq
):
for
loss
in
zip
(
all_reduce_first_loss
,
all_reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
6
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
5
)
for
loss
in
zip
(
all_reduce_last_loss
,
all_reduce_last_loss_seq
):
for
loss
in
zip
(
all_reduce_last_loss
,
all_reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
for
loss
in
zip
(
reduce_first_loss
,
reduce_first_loss_seq
):
for
loss
in
zip
(
reduce_first_loss
,
reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
6
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
5
)
for
loss
in
zip
(
reduce_last_loss
,
reduce_last_loss_seq
):
for
loss
in
zip
(
reduce_last_loss
,
reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
for
loss
in
zip
(
all_reduce_first_loss_seq
,
reduce_first_loss_seq
):
for
loss
in
zip
(
all_reduce_first_loss_seq
,
reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
6
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-
5
)
for
loss
in
zip
(
all_reduce_last_loss_seq
,
reduce_last_loss_seq
):
for
loss
in
zip
(
all_reduce_last_loss_seq
,
reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
...
@@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase):
...
@@ -277,7 +277,7 @@ class TestResnet(TestParallelExecutorBase):
use_cuda
=
True
,
use_cuda
=
True
,
use_reduce
=
False
,
use_reduce
=
False
,
iter
=
20
,
iter
=
20
,
delta2
=
1e-
6
):
delta2
=
1e-
5
):
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
if
use_cuda
and
not
core
.
is_compiled_with_cuda
():
return
return
...
@@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase):
...
@@ -308,7 +308,7 @@ class TestResnet(TestParallelExecutorBase):
optimizer
=
optimizer
)
optimizer
=
optimizer
)
self
.
assertAlmostEquals
(
self
.
assertAlmostEquals
(
np
.
mean
(
parallel_first_loss
),
single_first_loss
[
0
],
delta
=
1e-
6
)
np
.
mean
(
parallel_first_loss
),
single_first_loss
[
0
],
delta
=
1e-
5
)
self
.
assertAlmostEquals
(
self
.
assertAlmostEquals
(
np
.
mean
(
parallel_last_loss
),
single_last_loss
[
0
],
delta
=
delta2
)
np
.
mean
(
parallel_last_loss
),
single_last_loss
[
0
],
delta
=
delta2
)
...
...
python/paddle/fluid/transpiler/memory_optimization_transpiler.py
浏览文件 @
74bc55c2
...
@@ -540,6 +540,7 @@ def memory_optimize(input_program,
...
@@ -540,6 +540,7 @@ def memory_optimize(input_program,
if
skip_opt_set
is
not
None
:
if
skip_opt_set
is
not
None
:
skip_opt_set
=
set
(
map
(
to_name_str
,
skip_opt_set
))
skip_opt_set
=
set
(
map
(
to_name_str
,
skip_opt_set
))
cfgs
=
_get_cfgs
(
input_program
)
cfgs
=
_get_cfgs
(
input_program
)
input_program
.
_is_mem_optimized
=
True
for
cfg
in
cfgs
:
for
cfg
in
cfgs
:
cfg
.
memory_optimize
(
skip_opt_set
=
skip_opt_set
,
level
=
level
)
cfg
.
memory_optimize
(
skip_opt_set
=
skip_opt_set
,
level
=
level
)
...
@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
...
@@ -559,5 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
None
None
"""
"""
cfgs
=
_get_cfgs
(
input_program
)
cfgs
=
_get_cfgs
(
input_program
)
input_program
.
_is_mem_optimized
=
True
for
cfg
in
cfgs
:
for
cfg
in
cfgs
:
cfg
.
release_memory
(
skip_opt_set
=
skip_opt_set
)
cfg
.
release_memory
(
skip_opt_set
=
skip_opt_set
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录