Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
41a8af2b
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
330
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
41a8af2b
编写于
8月 16, 2019
作者:
X
xiebaiyuan
提交者:
Jiaying Zhao
8月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove mali . && optimise memory use of super resolution (#1794)
close
#1791
上级
225ec4e0
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
343 addition
and
31 deletion
+343
-31
src/common/types.h
src/common/types.h
+0
-1
src/framework/cl/cl_image.h
src/framework/cl/cl_image.h
+49
-5
src/framework/executor.cpp
src/framework/executor.cpp
+13
-11
src/framework/executor.h
src/framework/executor.h
+1
-0
src/framework/loader.cpp
src/framework/loader.cpp
+0
-2
src/framework/operator.cpp
src/framework/operator.cpp
+0
-1
src/io/api_paddle_mobile.cc
src/io/api_paddle_mobile.cc
+0
-2
src/io/paddle_mobile.cpp
src/io/paddle_mobile.cpp
+0
-1
src/io/paddle_test_inference_api.cpp
src/io/paddle_test_inference_api.cpp
+0
-1
src/operators/op_param.cpp
src/operators/op_param.cpp
+0
-6
src/operators/slice_op.cpp
src/operators/slice_op.cpp
+1
-1
src/pass/memory_optimize_super.cpp
src/pass/memory_optimize_super.cpp
+209
-0
src/pass/memory_optimize_super.h
src/pass/memory_optimize_super.h
+70
-0
未找到文件。
src/common/types.h
浏览文件 @
41a8af2b
...
...
@@ -53,7 +53,6 @@ struct DeviceType {};
typedef
DeviceType
<
kCPU
>
CPU
;
typedef
DeviceType
<
kFPGA
>
FPGA
;
typedef
DeviceType
<
kGPU_MALI
>
GPU_MALI
;
typedef
DeviceType
<
kGPU_CL
>
GPU_CL
;
//! data type
...
...
src/framework/cl/cl_image.h
浏览文件 @
41a8af2b
...
...
@@ -145,21 +145,61 @@ class CLImage {
initialized_
=
true
;
DLOG
<<
" end init cl image"
;
}
// create fake size cl_mem for mem share
void
InitFakeSizeImage
(
cl_context
context
,
cl_command_queue
command_queue
,
const
DDim
&
need_dims
,
const
DDim
&
real_dims
)
{
PADDLE_MOBILE_ENFORCE
(
tensor_data_
==
nullptr
,
" empty image tensor data shouldn't have value"
);
void
InitEmptyWithImageDim
(
cl_context
context
,
cl_command_queue
command_queue
,
const
DDim
&
image_dims
)
{
DLOG
<<
" to get image dims "
;
image_dims_
=
image_dims
;
DLOG
<<
" end get image dims "
<<
image_dims_
;
CLImageConverterNormal
*
normal_converter
=
new
CLImageConverterNormal
();
real_image_dims
=
normal_converter
->
InitImageDimInfoWith
(
real_dims
);
real_tensor_dims
=
real_dims
;
image_dims_
=
normal_converter
->
InitImageDimInfoWith
(
need_dims
);
InitCLImage
(
context
,
image_dims_
[
0
],
image_dims_
[
1
],
nullptr
);
tensor_dims_
=
need_dims
;
command_queue_
=
command_queue
;
image_converter_
=
normal_converter
;
cl_event_
=
CLEngine
::
Instance
()
->
CreateEvent
(
context
);
initialized_
=
true
;
DLOG
<<
" end init cl image"
;
}
void
InitWithExitedMem
(
cl_context
context
,
cl_command_queue
command_queue
,
DDim
need_dims
,
CLImage
&
src
)
{
CLImageConverterNormal
*
normal_converter
=
new
CLImageConverterNormal
();
real_image_dims
=
normal_converter
->
InitImageDimInfoWith
(
src
.
dims
());
real_tensor_dims
=
src
.
dims
();
image_dims_
=
normal_converter
->
InitImageDimInfoWith
(
need_dims
);
// InitCLImage(context, image_dims_[0], image_dims_[1], nullptr);
if
(
cl_image_
!=
src
.
cl_image_
)
{
cl_image_
.
reset
(
src
.
cl_image_
.
get
());
}
tensor_dims_
=
need_dims
;
command_queue_
=
command_queue
;
image_converter_
=
normal_converter
;
cl_event_
=
CLEngine
::
Instance
()
->
CreateEvent
(
context
);
initialized_
=
true
;
DLOG
<<
" end init cl image"
;
}
/*! The internal of two tensors share the same memory block. */
inline
CLImage
&
ShareHolderWith
(
const
CLImage
&
src
)
{
PADDLE_MOBILE_ENFORCE
(
src
.
cl_image_
!=
nullptr
,
"Tensor holds no memory. Call Tensor::mutable_data first."
)
if
(
cl_image_
!=
src
.
cl_image_
)
{
cl_image_
.
reset
(
src
.
cl_image_
.
get
());
}
return
*
this
;
}
cl_mem
GetCLImage
()
const
{
return
cl_image_
.
get
();
}
const
DDim
&
ImageDims
()
const
{
return
image_dims_
;
}
...
...
@@ -238,6 +278,10 @@ class CLImage {
std
::
unique_ptr
<
_cl_event
,
CLEventDeleter
>
cl_event_
;
DDim
tensor_dims_
;
DDim
image_dims_
;
// real image dims usually it is same as image_dims
DDim
real_image_dims
;
// real tensor dims usually it is same as tensor dims
DDim
real_tensor_dims
;
float
*
tensor_data_
=
nullptr
;
cl_context
context_
;
cl_command_queue
command_queue_
;
...
...
src/framework/executor.cpp
浏览文件 @
41a8af2b
...
...
@@ -33,6 +33,7 @@ limitations under the License. */
#include "pass/model_obfuscate.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_image.h"
#include "pass/memory_optimize_super.h"
#endif
namespace
paddle_mobile
{
...
...
@@ -55,7 +56,7 @@ Executor<Device, T>::Executor(const Program<Device> &program,
use_optimize_
(
use_optimize
),
lod_mode_
(
lod_mode
),
config_
(
config
)
{
DLOG
<<
"executor in lod mode: "
<<
lod_mode
_
;
DLOG
<<
"executor in lod mode: "
<<
lod_mode
;
Variable
*
variable_ptr
=
program_
.
scope
->
Var
(
"batch_size"
);
variable_ptr
->
SetValue
<
int
>
(
batch_size
);
...
...
@@ -805,27 +806,30 @@ void Executor<GPU_CL, float>::SetInput(const Tensor &input,
index
=
feed_indices_
.
find
(
var_name
)
->
second
;
}
auto
*
feed_var
=
program_
.
scope
->
Var
(
"feed"
);
framework
::
LoDTensor
*
targe
t_tensor
=
framework
::
LoDTensor
*
inpu
t_tensor
=
&
(
feed_var
->
template
GetMutable
<
framework
::
LoDTensorArray
>()
->
at
(
index
));
DLOG
<<
"config_.load_when_predict "
<<
config_
.
load_when_predict
;
DLOG
<<
"target_tensor->IsInitialized() "
<<
targe
t_tensor
->
IsInitialized
();
DLOG
<<
"target_tensor->dims() "
<<
targe
t_tensor
->
dims
();
DLOG
<<
"target_tensor->IsInitialized() "
<<
inpu
t_tensor
->
IsInitialized
();
DLOG
<<
"target_tensor->dims() "
<<
inpu
t_tensor
->
dims
();
DLOG
<<
"input.dims() "
<<
input
.
dims
();
DLOG
<<
"input_dim_last_ "
<<
input_dim_last_
;
if
(
config_
.
load_when_predict
)
{
if
(
input_dim_last_
!=
input
.
dims
())
{
DLOG
<<
"SetInput ---- > resize1"
;
target_tensor
->
Resize
(
input
.
dims
());
target_tensor
->
mutable_data
<
float
>
();
InitNoPersistableMemory
(
*
target_tensor
);
input_tensor
->
Resize
(
input
.
dims
());
input_tensor
->
mutable_data
<
float
>
();
// InitNoPersistableMemory(*input_tensor);
pass
::
MemoryOptPassSuper
()(
program_desc_
.
get
(),
program_
.
scope
.
get
(),
config_
.
memory_optimization_level
,
input
.
dims
());
}
}
else
{
DLOG
<<
"SetInput ---- > resize2"
;
targe
t_tensor
->
Resize
(
input
.
dims
());
inpu
t_tensor
->
Resize
(
input
.
dims
());
DLOG
<<
"SetInput ---- > ShareDataWith"
;
}
targe
t_tensor
->
ShareDataWith
(
input
);
inpu
t_tensor
->
ShareDataWith
(
input
);
if
(
feed_indices_
.
size
()
==
1
)
{
input_dim_has_changed_
=
input_dim_last_
!=
input
.
dims
();
}
...
...
@@ -1063,7 +1067,5 @@ template class Executor<FPGA, float>;
template
class
Executor
<
GPU_CL
,
float
>;
template
class
Executor
<
GPU_MALI
,
float
>;
}
// namespace framework
}
// namespace paddle_mobile
src/framework/executor.h
浏览文件 @
41a8af2b
...
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "framework/program/program.h"
#include "framework/tensor.h"
#include "framework/type_trait.h"
#include "pass/memory_optimize.h"
namespace
paddle_mobile
{
namespace
framework
{
...
...
src/framework/loader.cpp
浏览文件 @
41a8af2b
...
...
@@ -284,8 +284,6 @@ template class Loader<CPU, float>;
template
class
Loader
<
FPGA
,
float
>;
template
class
Loader
<
GPU_MALI
,
float
>;
template
class
Loader
<
GPU_CL
,
float
>;
}
// namespace framework
...
...
src/framework/operator.cpp
浏览文件 @
41a8af2b
...
...
@@ -148,7 +148,6 @@ void OperatorBase<Dtype>::InsertTensors() {
template
class
OperatorBase
<
CPU
>;
template
class
OperatorBase
<
FPGA
>;
template
class
OperatorBase
<
GPU_MALI
>;
template
class
OperatorBase
<
GPU_CL
>;
}
// namespace framework
...
...
src/io/api_paddle_mobile.cc
浏览文件 @
41a8af2b
...
...
@@ -242,8 +242,6 @@ CreatePaddlePredictor<PaddleMobileConfig, PaddleEngineKind::kPaddleMobile>(
x
.
reset
(
new
PaddleMobilePredictor
<
CPU
,
float
>
(
config
));
}
else
if
(
config
.
device
==
PaddleMobileConfig
::
kFPGA
)
{
x
.
reset
(
new
PaddleMobilePredictor
<
FPGA
,
float
>
(
config
));
}
else
if
(
config
.
device
==
PaddleMobileConfig
::
kGPU_MALI
)
{
x
.
reset
(
new
PaddleMobilePredictor
<
GPU_MALI
,
float
>
(
config
));
}
else
if
(
config
.
device
==
PaddleMobileConfig
::
kGPU_CL
)
{
x
.
reset
(
new
PaddleMobilePredictor
<
GPU_CL
,
float
>
(
config
));
}
else
{
...
...
src/io/paddle_mobile.cpp
浏览文件 @
41a8af2b
...
...
@@ -525,7 +525,6 @@ int PaddleMobile<Device, T>::readText(
template
class
PaddleMobile
<
CPU
,
float
>;
template
class
PaddleMobile
<
FPGA
,
float
>;
template
class
PaddleMobile
<
GPU_MALI
,
float
>;
template
class
PaddleMobile
<
GPU_CL
,
float
>;
}
// namespace paddle_mobile
src/io/paddle_test_inference_api.cpp
浏览文件 @
41a8af2b
...
...
@@ -30,7 +30,6 @@ double PaddleTester<Device, T>::CaculatePredictTime(std::string *cl_path) {
}
template
class
PaddleTester
<
CPU
,
float
>;
template
class
PaddleTester
<
FPGA
,
float
>;
template
class
PaddleTester
<
GPU_MALI
,
float
>;
template
class
PaddleTester
<
GPU_CL
,
float
>;
...
...
src/operators/op_param.cpp
浏览文件 @
41a8af2b
...
...
@@ -41,37 +41,31 @@ Print &operator<<(Print &printer, const ConvParam<CPU> &conv_param) {
template
class
ConvParam
<
CPU
>;
template
class
ConvParam
<
FPGA
>;
template
class
ConvParam
<
GPU_MALI
>;
#endif
#ifdef ELEMENTWISEADD_OP
template
class
ElementwiseAddParam
<
CPU
>;
template
class
ElementwiseAddParam
<
FPGA
>;
template
class
ElementwiseAddParam
<
GPU_MALI
>;
#endif
#ifdef ELEMENTWISEMUL_OP
template
class
ElementwiseMulParam
<
CPU
>;
template
class
ElementwiseMulParam
<
FPGA
>;
template
class
ElementwiseMulParam
<
GPU_MALI
>;
#endif
#ifdef MUL_OP
template
class
MulParam
<
CPU
>;
template
class
MulParam
<
FPGA
>;
template
class
MulParam
<
GPU_MALI
>;
#endif
#ifdef CONCAT_OP
template
class
ConcatParam
<
CPU
>;
template
class
ConcatParam
<
FPGA
>;
template
class
ConcatParam
<
GPU_MALI
>;
#endif
#ifdef LRN_OP
template
class
LrnParam
<
CPU
>;
template
class
LrnParam
<
FPGA
>;
template
class
LrnParam
<
GPU_MALI
>;
#endif
#ifdef FUSION_CONVADD_OP
...
...
src/operators/slice_op.cpp
浏览文件 @
41a8af2b
...
...
@@ -84,7 +84,7 @@ void SliceOp<Dtype, T>::InferShape() const {
}
}
output
->
Resize
(
out_dims
);
#if
def PADDLE_MOBILE_CPU
#if
!defined(PADDLE_MOBILE_CL) && defined(PADDLE_MOBILE_CPU)
if
(
axes
[
0
]
!=
0
)
{
output
->
set_lod
(
input
->
lod
());
}
...
...
src/pass/memory_optimize_super.cpp
0 → 100644
浏览文件 @
41a8af2b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_MOBILE_CL
#include "pass/memory_optimize_super.h"
#include <algorithm>
#include "framework/cl/cl_image.h"
#include "framework/lod_tensor.h"
namespace
paddle_mobile
{
namespace
pass
{
// Records every variable descriptor declared in `block` so that
// IsPersistable() can later look a variable up by name. Entries
// accumulate across blocks — the map is deliberately not cleared
// between calls (see the commented-out clear below).
void MemoryOptPassSuper::AppendBlockVars(const framework::BlockDesc *block) {
  // block_vars_.clear();
  for (const auto var : block->Vars()) {
    block_vars_[var->Name()] = var.get();
  }
}
// Returns true when `name` refers to a persistable variable (registered
// via AppendBlockVars) — such variables are excluded from memory reuse.
// Names not found in block_vars_ are treated as not persistable.
bool MemoryOptPassSuper::IsPersistable(const std::string name) {
  const auto it = block_vars_.find(name);
  if (it != block_vars_.end()) {
    return it->second->Persistable();
  }
  return false;
}
// Returns the reference-counting node for variable `name`. On first
// call the node is heap-allocated with count 1 and registered in
// created_nodes_ (which owns it; nodes are deleted in the pass
// destructor). Subsequent calls bump the reference count and return
// the existing node.
ClVarNode *MemoryOptPassSuper::CreateNode(const std::string name) {
  auto it = created_nodes_.find(name);
  if (it != created_nodes_.end()) {
    ++(it->second->count);
    return it->second;
  }
  ClVarNode *var = new ClVarNode;
  var->name = name;
  var->count = 1;
  var->visited = false;
  created_nodes_[name] = var;
  return var;
}
// Entry point of the pass. For each block of `program`:
//   1. registers the block's variables (AppendBlockVars);
//   2. collects the inputs of "feed" ops into `exclude_var_names` when
//      the optimization level is MemoryOptimizationWithoutFeeds, so
//      feed targets are never reused;
//   3. pushes a reference-counted node onto `analysis_nodes_` for every
//      non-persistable, non-excluded input/output of every op, marking
//      inputs of "fetch" ops so their memory is never handed out;
//   4. pops the stack to group nodes into reuse lists (a node joins a
//      list whose last member has reached count 0 and is not a fetch
//      target; otherwise it starts a new list);
//   5. calls ShareData() to rebind the CL images inside each list.
// `target_dims` is the current input shape; its trailing two dims are
// applied to every reused image in ShareData().
void MemoryOptPassSuper::operator()(
    const framework::ProgramDesc *program, framework::Scope *scope,
    MemoryOptimizationLevel memory_optimization_level,
    framework::DDim target_dims) {
  const auto &blocks = program->Blocks();
  for (const auto &block : blocks) {
    // access all variables in each block
    AppendBlockVars(block.get());

    reused_nodes_.clear();
    // collect all not persistable variables, and accumulate
    // it's reference count
    std::stack<ClVarNode *> empty_var_nodes;
    analysis_nodes_.swap(empty_var_nodes);  // reset the analysis stack

    std::vector<std::string> exclude_var_names;
    for (const auto &op : block->Ops()) {
      for (const auto &inputs : op->GetInputs()) {
        for (const auto &input : inputs.second) {
          if (!IsPersistable(input)) {
            if (memory_optimization_level == MemoryOptimizationWithoutFeeds) {
              if (op->Type() == "feed") {
                exclude_var_names.push_back(input);
              }
            }
          }
        }
      }
    }

    std::vector<ClVarNode *> fetch_var_nodes;
    for (const auto &op : block->Ops()) {
      DLOG << "op_desc->Type(): " << op->Type();
      // NOTE(review): each op's outputs are pushed twice — once before
      // and once after its inputs. This doubles the outputs' reference
      // counts; presumably intentional so an op's outputs stay "live"
      // across its own execution in the LIFO analysis — confirm.
      for (const auto &outputs : op->GetOutputs()) {
        for (const auto &output : outputs.second) {
          if (!IsPersistable(output) &&
              std::find(exclude_var_names.begin(), exclude_var_names.end(),
                        output) == exclude_var_names.end()) {
            DLOG << "output: " << output;
            ClVarNode *node = CreateNode(output);
            analysis_nodes_.push(node);
          }
        }
      }
      for (const auto &inputs : op->GetInputs()) {
        for (const auto &input : inputs.second) {
          if (!IsPersistable(input) &&
              std::find(exclude_var_names.begin(), exclude_var_names.end(),
                        input) == exclude_var_names.end()) {
            DLOG << "input: " << input;
            ClVarNode *node = CreateNode(input);
            analysis_nodes_.push(node);
            if (op->Type() == "fetch") {
              fetch_var_nodes.push_back(node);
            }
          }
        }
      }
      for (const auto &outputs : op->GetOutputs()) {
        for (const auto &output : outputs.second) {
          if (!IsPersistable(output) &&
              std::find(exclude_var_names.begin(), exclude_var_names.end(),
                        output) == exclude_var_names.end()) {
            DLOG << "output: " << output;
            ClVarNode *node = CreateNode(output);
            analysis_nodes_.push(node);
          }
        }
      }
    }

    // apply optimize
    while (!analysis_nodes_.empty()) {
      auto *node = analysis_nodes_.top();
      analysis_nodes_.pop();
      // only not visited node can reuse memory between other nodes
      // with 0 count which indicate they will not be used any more
      if (!node->visited) {
        bool reused = false;
        // find out a possable reuse list
        for (auto &list : reused_nodes_) {
          if (list.back()->count == 0 &&
              std::find(fetch_var_nodes.begin(), fetch_var_nodes.end(),
                        list.back()) == fetch_var_nodes.end()) {
            list.push_back(node);
            reused = true;
            break;
          }
        }
        // create new list if can't find a reused list
        if (!reused) {
          std::vector<ClVarNode *> list;
          list.push_back(node);
          reused_nodes_.push_back(std::move(list));
        }
      }
      node->visited = true;
      node->count -= 1;
    }
    // shared data within all variables in the same reused list
    ShareData(scope, memory_optimization_level, target_dims);
  }
}
// Rebinds the CL image memory of every variable in each reuse list.
// For each list: picks the tensor with the largest numel() as the
// backing "reuse" tensor, resizes it to {its N, its C, target H,
// target W} via InitFakeSizeImage (NCHW layout assumed — TODO
// confirm), then makes every list member alias that image through
// InitWithExitedMem with its own N/C and the target H/W.
// `memory_optimization_level` is currently unused here.
void MemoryOptPassSuper::ShareData(
    framework::Scope *scope, MemoryOptimizationLevel memory_optimization_level,
    framework::DDim target_dims) const {
  // shared data within all variables in the same reused list
  for (const auto &list : reused_nodes_) {
    DLOG << "\n";
    DLOG << "gpu . share memory within these variables";
    // find max dims
    int64_t max_numl = -1;
    framework::CLImage *reuse_tensor = nullptr;
    DLOG << "resused nodes group ----------";
    for (const auto &node : list) {
      auto *var = scope->Var(node->name);
      auto *tensor = var->template GetMutable<framework::CLImage>();
      const int64_t numl = tensor->numel();
      if (max_numl < numl) {
        max_numl = numl;
        reuse_tensor = tensor;
      }
      DLOG << node->name << " ----dims: " << tensor->dims()
           << "----numl----: " << numl;
    }
    if (reuse_tensor == nullptr) {
      // NOTE(review): this returns from the whole function, skipping
      // any remaining reuse lists — a `continue` may have been
      // intended. In practice every list holds at least one node, so
      // reuse_tensor should normally be set — confirm.
      return;
    }
    const framework::DDim &dims = reuse_tensor->dims();
    // GetCLScpoe is the project API's spelling (sic).
    cl_context context = scope->GetCLScpoe()->Context();
    cl_command_queue command_queue = scope->GetCLScpoe()->CommandQueue();
    framework::DDim reshaped_dim = framework::make_ddim(
        {dims[0], dims[1], target_dims[2], target_dims[3]});
    DLOG << "target dims : " << target_dims;
    DLOG << "reshaped_dim : " << reshaped_dim;
    reuse_tensor->InitFakeSizeImage(context, command_queue, reshaped_dim,
                                    reshaped_dim);
    for (const auto &node : list) {
      auto *var = scope->Var(node->name);
      auto *tensor = var->template GetMutable<framework::CLImage>();
      const framework::DDim &temp_dim = tensor->dims();
      framework::DDim need_dims = framework::make_ddim(
          {temp_dim[0], temp_dim[1], target_dims[2], target_dims[3]});
      tensor->InitWithExitedMem(context, command_queue, need_dims,
                                *reuse_tensor);
    }
  }
}
}
// namespace pass
}
// namespace paddle_mobile
#endif
src/pass/memory_optimize_super.h
0 → 100644
浏览文件 @
41a8af2b
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_MOBILE_CL
#pragma once
#include <stack>
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/lod_tensor.h"
#include "framework/program/program.h"
#include "pass/pass_base.h"
// use for super resulotion to be extend for all opencl
namespace
paddle_mobile
{
namespace
pass
{
// Reference-counting record for one non-persistable variable of the
// program; MemoryOptPassSuper uses these to decide when a variable's
// CL image memory can be handed to another variable.
typedef struct {
  std::string name;  // variable name
  int count;         // reference count
  bool visited;      // already assigned to a reuse list
} ClVarNode;
// MemoryOptPass will analyze the program, and reuse memory between
// variables as much as possible. This "Super" variant targets the
// OpenCL (CL image) path and re-runs on every input-shape change —
// written for super-resolution models, to be extended to all OpenCL
// models (see file header note).
class MemoryOptPassSuper : public PassBase {
 public:
  MemoryOptPassSuper() {}
  // Releases every ClVarNode allocated by CreateNode().
  virtual ~MemoryOptPassSuper() {
    for (auto &it : created_nodes_) {
      delete it.second;
    }
  }

  // Runs the pass: builds reference counts for all non-persistable
  // variables of `program`, groups them into reuse lists, and shares
  // the backing CL images inside `scope`. The trailing two entries of
  // `dims` (the current input shape) are applied to every reused
  // image.
  void operator()(const framework::ProgramDesc *program,
                  framework::Scope *scope,
                  MemoryOptimizationLevel memory_optimization_level,
                  framework::DDim dims);

  // Registers all variable descriptors of `block` for IsPersistable().
  void AppendBlockVars(const framework::BlockDesc *block);

  // True when the named variable is persistable (excluded from reuse).
  bool IsPersistable(const std::string name);

  // Returns the node for `name`, creating it with count 1 on first
  // use and bumping its reference count on later calls.
  ClVarNode *CreateNode(const std::string name);

  // Rebinds the CL images of every variable in each reuse list to the
  // largest tensor of that list, resized with the trailing dims of
  // `dims`.
  void ShareData(framework::Scope *scope,
                 MemoryOptimizationLevel memory_optimization_level,
                 framework::DDim dims) const;

 private:
  std::stack<ClVarNode *> analysis_nodes_;              // LIFO of nodes to place
  std::vector<std::vector<ClVarNode *>> reused_nodes_;  // memory-reuse groups
  std::unordered_map<std::string, ClVarNode *> created_nodes_;  // owned nodes
  std::unordered_map<std::string, framework::VarDesc *>
      block_vars_;  // variable name -> descriptor
};
}
// namespace pass
}
// namespace paddle_mobile
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录