Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
546d4da8
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
546d4da8
编写于
5月 18, 2020
作者:
H
huzhiqiang
提交者:
GitHub
5月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Framework][ModelType] Add Shape&Precision information into optimized model (#3643)
上级
a24d4dd1
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
152 addition
and
39 deletion
+152
-39
lite/core/program.cc
lite/core/program.cc
+35
-28
lite/model_parser/compatible_pb.cc
lite/model_parser/compatible_pb.cc
+23
-8
lite/model_parser/compatible_pb_test.cc
lite/model_parser/compatible_pb_test.cc
+4
-0
lite/model_parser/cpp/var_desc.h
lite/model_parser/cpp/var_desc.h
+6
-0
lite/model_parser/desc_apis.h
lite/model_parser/desc_apis.h
+4
-0
lite/model_parser/naive_buffer/var_desc.cc
lite/model_parser/naive_buffer/var_desc.cc
+51
-0
lite/model_parser/naive_buffer/var_desc.h
lite/model_parser/naive_buffer/var_desc.h
+7
-0
lite/model_parser/pb/var_desc.cc
lite/model_parser/pb/var_desc.cc
+21
-2
lite/model_parser/pb/var_desc.h
lite/model_parser/pb/var_desc.h
+1
-1
未找到文件。
lite/core/program.cc
浏览文件 @
546d4da8
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "lite/core/program.h"
#include "lite/core/program.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_map>
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
...
@@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
...
@@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
auto
*
scope
=
op
->
scope
();
auto
*
scope
=
op
->
scope
();
auto
in_names
=
op
->
op_info
()
->
input_names
();
auto
in_names
=
op
->
op_info
()
->
input_names
();
auto
out_names
=
op
->
op_info
()
->
output_names
();
auto
out_names
=
op
->
op_info
()
->
output_names
();
for
(
auto
&
in_name
:
in_names
)
{
auto
it
=
origin_var_maps
.
find
(
in_name
);
std
::
vector
<
std
::
string
>
var_names
;
var_names
.
insert
(
var_names
.
end
(),
in_names
.
begin
(),
in_names
.
end
());
var_names
.
insert
(
var_names
.
end
(),
out_names
.
begin
(),
out_names
.
end
());
std
::
sort
(
var_names
.
begin
(),
var_names
.
end
());
var_names
.
erase
(
std
::
unique
(
var_names
.
begin
(),
var_names
.
end
()),
var_names
.
end
());
for
(
auto
&
var_name
:
var_names
)
{
auto
it
=
origin_var_maps
.
find
(
var_name
);
if
(
it
!=
origin_var_maps
.
end
())
{
if
(
it
!=
origin_var_maps
.
end
())
{
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
v
->
SetName
((
it
->
second
).
Name
());
v
->
SetName
((
it
->
second
).
Name
());
v
->
SetType
((
it
->
second
).
GetType
());
v
->
SetType
((
it
->
second
).
GetType
());
v
->
SetPersistable
((
it
->
second
).
Persistable
());
v
->
SetPersistable
((
it
->
second
).
Persistable
());
if
((
it
->
second
).
Name
()
!=
"feed"
&&
(
it
->
second
).
Name
()
!=
"fetch"
)
{
v
->
SetShape
((
it
->
second
).
GetShape
());
v
->
SetDataType
((
it
->
second
).
GetDataType
());
}
}
else
{
}
else
{
// New created vars must be LOD_TENSOR
// New created vars must be LOD_TENSOR
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
v
->
SetName
(
in
_name
);
v
->
SetName
(
var
_name
);
v
->
SetType
(
cpp
::
VarDesc
::
Type
::
LOD_TENSOR
);
v
->
SetType
(
cpp
::
VarDesc
::
Type
::
LOD_TENSOR
);
std
::
string
in_arg_name
;
std
::
string
in_arg_name
;
op
->
op_info
()
->
GetInputArgname
(
in
_name
,
&
in_arg_name
);
op
->
op_info
()
->
GetInputArgname
(
var
_name
,
&
in_arg_name
);
auto
type
=
kernel
->
GetInputDeclType
(
in_arg_name
);
auto
type
=
kernel
->
GetInputDeclType
(
in_arg_name
);
if
(
type
->
IsTensor
())
{
if
(
type
->
IsTensor
())
{
auto
tensor
=
scope
->
FindVar
(
in
_name
)
->
GetMutable
<
Tensor
>
();
auto
tensor
=
scope
->
FindVar
(
var
_name
)
->
GetMutable
<
Tensor
>
();
v
->
SetPersistable
(
tensor
->
persistable
());
v
->
SetPersistable
(
tensor
->
persistable
());
}
else
{
if
((
it
->
second
).
Name
()
!=
"feed"
&&
(
it
->
second
).
Name
()
!=
"fetch"
)
{
CHECK
(
false
)
<<
"unsupported var type"
;
v
->
SetShape
(
tensor
->
dims
().
data
());
}
switch
(
tensor
->
precision
())
{
}
#define SET_DATATYPE(precision__, data_type) \
}
case PrecisionType::precision__: \
v->SetDataType(data_type); \
break
for
(
auto
&
out_name
:
out_names
)
{
SET_DATATYPE
(
kFloat
,
VarDescAPI
::
VarDataType
::
FP32
);
auto
it
=
origin_var_maps
.
find
(
out_name
);
SET_DATATYPE
(
kInt8
,
VarDescAPI
::
VarDataType
::
INT8
);
if
(
it
!=
origin_var_maps
.
end
())
{
SET_DATATYPE
(
kInt16
,
VarDescAPI
::
VarDataType
::
INT16
);
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
SET_DATATYPE
(
kInt32
,
VarDescAPI
::
VarDataType
::
INT32
);
v
->
SetName
((
it
->
second
).
Name
());
SET_DATATYPE
(
kInt64
,
VarDescAPI
::
VarDataType
::
INT64
);
v
->
SetType
((
it
->
second
).
GetType
());
#undef SET_DATATYPE
v
->
SetPersistable
((
it
->
second
).
Persistable
());
default:
}
else
{
LOG
(
FATAL
)
<<
"unknown precision type"
;
// New created vars must be LOD_TENSOR
}
auto
*
v
=
main_block
.
AddVar
<
cpp
::
VarDesc
>
();
}
v
->
SetName
(
out_name
);
v
->
SetType
(
cpp
::
VarDesc
::
Type
::
LOD_TENSOR
);
std
::
string
out_arg_name
;
op
->
op_info
()
->
GetOutputArgname
(
out_name
,
&
out_arg_name
);
auto
type
=
kernel
->
GetOutputDeclType
(
out_arg_name
);
if
(
type
->
IsTensor
())
{
auto
tensor
=
scope
->
FindVar
(
out_name
)
->
GetMutable
<
Tensor
>
();
v
->
SetPersistable
(
tensor
->
persistable
());
}
else
{
}
else
{
CHECK
(
false
)
<<
"unsupported var type"
;
CHECK
(
false
)
<<
"unsupported var type"
;
}
}
...
...
lite/model_parser/compatible_pb.cc
浏览文件 @
546d4da8
...
@@ -30,13 +30,17 @@ namespace paddle {
...
@@ -30,13 +30,17 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
/// For VarDesc transfrom
/// For VarDesc transfrom
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
template <> \
template <> \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
T *any_desc) { \
T *any_desc) { \
any_desc->SetName(cpp_desc.Name()); \
any_desc->SetName(cpp_desc.Name()); \
any_desc->SetType(cpp_desc.GetType()); \
any_desc->SetType(cpp_desc.GetType()); \
any_desc->SetPersistable(cpp_desc.Persistable()); \
any_desc->SetPersistable(cpp_desc.Persistable()); \
if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \
any_desc->SetShape(cpp_desc.GetShape()); \
any_desc->SetDataType(cpp_desc.GetDataType()); \
} \
}
}
#ifndef LITE_ON_TINY_PUBLISH
#ifndef LITE_ON_TINY_PUBLISH
...
@@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
...
@@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
cpp_desc
->
SetName
(
any_desc
.
Name
());
cpp_desc
->
SetName
(
any_desc
.
Name
());
cpp_desc
->
SetType
(
any_desc
.
GetType
());
cpp_desc
->
SetType
(
any_desc
.
GetType
());
cpp_desc
->
SetPersistable
(
any_desc
.
Persistable
());
cpp_desc
->
SetPersistable
(
any_desc
.
Persistable
());
cpp_desc
->
SetDataType
(
any_desc
.
GetDataType
());
if
(
any_desc
.
Name
()
!=
"feed"
&&
any_desc
.
Name
()
!=
"fetch"
)
{
cpp_desc
->
SetDataType
(
any_desc
.
GetDataType
());
cpp_desc
->
SetShape
(
any_desc
.
GetShape
());
}
}
}
#endif
#endif
...
@@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
...
@@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
cpp_desc
->
SetName
(
any_desc
.
Name
());
cpp_desc
->
SetName
(
any_desc
.
Name
());
cpp_desc
->
SetType
(
any_desc
.
GetType
());
cpp_desc
->
SetType
(
any_desc
.
GetType
());
cpp_desc
->
SetPersistable
(
any_desc
.
Persistable
());
cpp_desc
->
SetPersistable
(
any_desc
.
Persistable
());
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
/* if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}*/
}
}
/// For OpDesc transform
/// For OpDesc transform
...
...
lite/model_parser/compatible_pb_test.cc
浏览文件 @
546d4da8
...
@@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) {
...
@@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) {
desc
->
SetName
(
"X"
);
desc
->
SetName
(
"X"
);
desc
->
SetPersistable
(
true
);
desc
->
SetPersistable
(
true
);
desc
->
SetType
(
VarDescAPI
::
Type
::
LOD_TENSOR
);
desc
->
SetType
(
VarDescAPI
::
Type
::
LOD_TENSOR
);
desc
->
SetShape
({
1
,
3
,
224
,
224
});
desc
->
SetDataType
(
VarDescAPI
::
VarDataType
::
FP32
);
}
}
template
<
typename
VarDescType
>
template
<
typename
VarDescType
>
...
@@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) {
...
@@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) {
desc
->
SetName
(
"Y"
);
desc
->
SetName
(
"Y"
);
desc
->
SetPersistable
(
false
);
desc
->
SetPersistable
(
false
);
desc
->
SetType
(
VarDescAPI
::
Type
::
SELECTED_ROWS
);
desc
->
SetType
(
VarDescAPI
::
Type
::
SELECTED_ROWS
);
desc
->
SetShape
({
1
,
3
,
224
,
224
});
desc
->
SetDataType
(
VarDescAPI
::
VarDataType
::
FP32
);
}
}
template
<
typename
VarDescType
>
template
<
typename
VarDescType
>
...
...
lite/model_parser/cpp/var_desc.h
浏览文件 @
546d4da8
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <string>
#include <string>
#include <vector>
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/desc_apis.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI {
...
@@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI {
void
SetDataType
(
Type
data_type
)
{
data_type_
=
data_type
;
}
void
SetDataType
(
Type
data_type
)
{
data_type_
=
data_type
;
}
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
shape_
=
dims
;
}
std
::
vector
<
int64_t
>
GetShape
()
const
{
return
shape_
;
}
private:
private:
std
::
string
name_
;
std
::
string
name_
;
Type
type_
;
Type
type_
;
Type
data_type_
;
Type
data_type_
;
bool
persistable_
;
bool
persistable_
;
std
::
vector
<
int64_t
>
shape_
;
};
};
}
// namespace cpp
}
// namespace cpp
...
...
lite/model_parser/desc_apis.h
浏览文件 @
546d4da8
...
@@ -76,6 +76,10 @@ class VarDescAPI {
...
@@ -76,6 +76,10 @@ class VarDescAPI {
virtual
bool
Persistable
()
const
=
0
;
virtual
bool
Persistable
()
const
=
0
;
// Set var to be persistable or not
// Set var to be persistable or not
virtual
void
SetPersistable
(
bool
persistable
)
=
0
;
virtual
void
SetPersistable
(
bool
persistable
)
=
0
;
// Get var's shape
virtual
std
::
vector
<
int64_t
>
GetShape
()
const
=
0
;
// Set var's shape
virtual
void
SetShape
(
const
std
::
vector
<
int64_t
>&
dims
)
=
0
;
};
};
/*
/*
...
...
lite/model_parser/naive_buffer/var_desc.cc
浏览文件 @
546d4da8
...
@@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() {
...
@@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() {
return
builder
;
return
builder
;
}
}
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void
VarDesc
::
SetDataType
(
VarDescAPI
::
VarDataType
data_type
)
{
/* using data_type_builder_t = EnumBuilder<proto::VarDataType>;
auto data_type_builder =
desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<data_type_builder_t>("data_type");
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::VarDataType::type__: \
data_type_builder->set(proto::VarDataType::type__); \
break
switch (data_type) {
// Only support primary data type now.
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var data type";
}
#undef SET_DATA_TYPE_CASE_ITEM
*/
}
// Get var's shape
std
::
vector
<
int64_t
>
VarDesc
::
GetShape
()
const
{
using
data_type_builder_t
=
ListBuilder
<
Int64Builder
>
;
auto
out_builder
=
desc_
->
GetField
<
proto
::
TensorDesc
>
(
"tensor_desc"
)
.
GetField
<
data_type_builder_t
>
(
"dims"
);
return
RepeatedToVector
<
int64_t
,
Int64Builder
>
(
out_builder
);
}
// Set var's shape
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void
VarDesc
::
SetShape
(
const
std
::
vector
<
int64_t
>&
dims
)
{
/* using out_builder_type = ListBuilder<Int64Builder>;
auto out_builder = desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<out_builder_type>("dims");
CHECK(out_builder);
VectorToRepeated<int64_t, Int64Builder>(dims, out_builder);*/
}
}
// namespace naive_buffer
}
// namespace naive_buffer
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
lite/model_parser/naive_buffer/var_desc.h
浏览文件 @
546d4da8
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/naive_buffer_wrapper_helper.h"
#include "lite/model_parser/naive_buffer/proto/framework.nb.h"
#include "lite/model_parser/naive_buffer/proto/framework.nb.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI {
...
@@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI {
void
SetPersistable
(
bool
persistable
)
override
;
void
SetPersistable
(
bool
persistable
)
override
;
void
SetDataType
(
VarDescAPI
::
VarDataType
data_type
);
VarDescAPI
::
VarDataType
GetDataType
()
const
;
VarDescAPI
::
VarDataType
GetDataType
()
const
;
// Get var's shape
std
::
vector
<
int64_t
>
GetShape
()
const
;
// Set var's shape
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
);
private:
private:
const
proto
::
VarType
&
GetVarType
()
const
;
const
proto
::
VarType
&
GetVarType
()
const
;
proto
::
VarType
*
GetMutableVarType
();
proto
::
VarType
*
GetMutableVarType
();
...
...
lite/model_parser/pb/var_desc.cc
浏览文件 @
546d4da8
...
@@ -130,8 +130,27 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
...
@@ -130,8 +130,27 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return
res
;
return
res
;
}
}
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
void
VarDesc
::
SetDataType
(
VarDescAPI
::
VarDataType
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::Type::type__: \
mutable_tensor_desc()->set_data_type(framework::proto::VarType::type__); \
break;
switch
(
data_type
)
{
SET_DATA_TYPE_CASE_ITEM
(
BOOL
);
SET_DATA_TYPE_CASE_ITEM
(
SIZE_T
);
SET_DATA_TYPE_CASE_ITEM
(
UINT8
);
SET_DATA_TYPE_CASE_ITEM
(
INT8
);
SET_DATA_TYPE_CASE_ITEM
(
INT16
);
SET_DATA_TYPE_CASE_ITEM
(
INT32
);
SET_DATA_TYPE_CASE_ITEM
(
INT64
);
SET_DATA_TYPE_CASE_ITEM
(
FP16
);
SET_DATA_TYPE_CASE_ITEM
(
FP32
);
SET_DATA_TYPE_CASE_ITEM
(
FP64
);
default:
LOG
(
FATAL
)
<<
"Unknown var type: "
<<
static_cast
<
int
>
(
data_type
);
}
#undef SET_DATA_TYPE_CASE_ITEM
}
}
void
VarDesc
::
SetDataTypes
(
void
VarDesc
::
SetDataTypes
(
...
...
lite/model_parser/pb/var_desc.h
浏览文件 @
546d4da8
...
@@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI {
...
@@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI {
std
::
vector
<
std
::
vector
<
int64_t
>>
GetShapes
()
const
;
std
::
vector
<
std
::
vector
<
int64_t
>>
GetShapes
()
const
;
void
SetDataType
(
framework
::
proto
::
VarType
::
Type
data_type
);
void
SetDataType
(
VarDescAPI
::
VarData
Type
data_type
);
void
SetDataTypes
(
void
SetDataTypes
(
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
multiple_data_type
);
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
multiple_data_type
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录