Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
d51324bf
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d51324bf
编写于
8月 04, 2020
作者:
石
石晓伟
提交者:
GitHub
8月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
new class: ParamDesc, test=develop (#4009)
* add ParamDesc, test=develop * serialize tensor funcs, test=develop
上级
473db814
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
648 addition
and
45 deletion
+648
-45
cmake/external/flatbuffers.cmake
cmake/external/flatbuffers.cmake
+4
-1
lite/api/CMakeLists.txt
lite/api/CMakeLists.txt
+2
-2
lite/api/android/jni/native/CMakeLists.txt
lite/api/android/jni/native/CMakeLists.txt
+2
-2
lite/model_parser/base/apis.h
lite/model_parser/base/apis.h
+1
-0
lite/model_parser/base/param_desc.h
lite/model_parser/base/param_desc.h
+88
-0
lite/model_parser/base/traits.h
lite/model_parser/base/traits.h
+73
-0
lite/model_parser/base/var_desc.h
lite/model_parser/base/var_desc.h
+1
-31
lite/model_parser/flatbuffers/CMakeLists.txt
lite/model_parser/flatbuffers/CMakeLists.txt
+5
-4
lite/model_parser/flatbuffers/io.cc
lite/model_parser/flatbuffers/io.cc
+39
-0
lite/model_parser/flatbuffers/io.h
lite/model_parser/flatbuffers/io.h
+15
-0
lite/model_parser/flatbuffers/op_desc.h
lite/model_parser/flatbuffers/op_desc.h
+3
-2
lite/model_parser/flatbuffers/param.fbs
lite/model_parser/flatbuffers/param.fbs
+37
-0
lite/model_parser/flatbuffers/param_desc.cc
lite/model_parser/flatbuffers/param_desc.cc
+15
-0
lite/model_parser/flatbuffers/param_desc.h
lite/model_parser/flatbuffers/param_desc.h
+216
-0
lite/model_parser/flatbuffers/traits.h
lite/model_parser/flatbuffers/traits.h
+144
-0
lite/model_parser/flatbuffers/var_desc.h
lite/model_parser/flatbuffers/var_desc.h
+3
-3
未找到文件。
cmake/external/flatbuffers.cmake
浏览文件 @
d51324bf
...
...
@@ -100,7 +100,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
${
OPT
}
-o
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SRC_FBS_DIR
}
"
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SRC_FBS
}
"
DEPENDS flatbuffers
DEPENDS flatbuffers
${
SRC_FBS
}
COMMENT
"Run generation: '
${
GEN_HEADER
}
'"
)
register_generated_output
(
${
GEN_HEADER
}
)
add_custom_target
(
${
TARGET
}
ALL DEPENDS
${
GEN_HEADER
}
)
...
...
@@ -108,7 +108,10 @@ endfunction()
set
(
FRAMEWORK_FBS_DIR
"lite/model_parser/flatbuffers"
)
set
(
FRAMEWORK_SCHEMA_PATH
"
${
FRAMEWORK_FBS_DIR
}
/framework.fbs"
)
set
(
PARAM_SCHEMA_PATH
"
${
FRAMEWORK_FBS_DIR
}
/param.fbs"
)
compile_flatbuffers_schema_to_cpp_opt
(
framework_fbs_header
${
FRAMEWORK_SCHEMA_PATH
}
"--no-includes;--gen-compare;--force-empty"
)
compile_flatbuffers_schema_to_cpp_opt
(
param_fbs_header
${
PARAM_SCHEMA_PATH
}
"--no-includes;--gen-compare;--force-empty"
)
include_directories
(
${
FLATBUFFERS_INCLUDE_DIR
}
)
include_directories
(
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SRC_FBS_DIR
}
)
add_custom_target
(
fbs_headers ALL DEPENDS framework_fbs_header param_fbs_header
)
lite/api/CMakeLists.txt
浏览文件 @
d51324bf
...
...
@@ -16,7 +16,7 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH
lite_cc_library
(
paddle_full_api_shared SHARED SRCS paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc
DEPS paddle_api paddle_api_light paddle_api_full
)
target_sources
(
paddle_full_api_shared PUBLIC
${
__lite_cc_files
}
)
add_dependencies
(
paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry f
ramework_fbs_header
)
add_dependencies
(
paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry f
bs_headers
)
target_link_libraries
(
paddle_full_api_shared framework_proto op_registry
)
if
(
LITE_WITH_X86
)
add_dependencies
(
paddle_full_api_shared xxhash
)
...
...
@@ -72,7 +72,7 @@ else()
set
(
TARGET_COMIPILE_FLAGS
"
${
TARGET_COMIPILE_FLAGS
}
-flto"
)
endif
()
set_target_properties
(
paddle_light_api_shared PROPERTIES COMPILE_FLAGS
"
${
TARGET_COMIPILE_FLAGS
}
"
)
add_dependencies
(
paddle_light_api_shared op_list_h kernel_list_h f
ramework_fbs_header
)
add_dependencies
(
paddle_light_api_shared op_list_h kernel_list_h f
bs_headers
)
if
(
LITE_WITH_NPU
)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries
(
paddle_light_api_shared
${
npu_builder_libs
}
${
npu_runtime_libs
}
)
...
...
lite/api/android/jni/native/CMakeLists.txt
浏览文件 @
d51324bf
...
...
@@ -17,7 +17,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
# Unlike static library, module library has to link target to be able to work
# as a single .so lib.
target_link_libraries
(
paddle_lite_jni
${
lib_DEPS
}
${
arm_kernels
}
${
npu_kernels
}
)
add_dependencies
(
paddle_lite_jni f
ramework_fbs_header
)
add_dependencies
(
paddle_lite_jni f
bs_headers
)
if
(
LITE_WITH_NPU
)
# Strips the symbols of our protobuf functions to fix the conflicts during
# loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so)
...
...
@@ -32,7 +32,7 @@ else()
endif
()
set_target_properties
(
paddle_lite_jni PROPERTIES COMPILE_FLAGS
${
TARGET_COMIPILE_FLAGS
}
)
target_sources
(
paddle_lite_jni PUBLIC
${
__lite_cc_files
}
paddle_lite_jni.cc tensor_jni.cc
)
add_dependencies
(
paddle_lite_jni op_list_h kernel_list_h f
ramework_fbs_header
)
add_dependencies
(
paddle_lite_jni op_list_h kernel_list_h f
bs_headers
)
if
(
LITE_WITH_NPU
)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries
(
paddle_lite_jni
${
npu_builder_libs
}
${
npu_runtime_libs
}
)
...
...
lite/model_parser/base/apis.h
浏览文件 @
d51324bf
...
...
@@ -16,6 +16,7 @@
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/base/param_desc.h"
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/base/proto_desc.h"
#include "lite/model_parser/base/traits.h"
...
...
lite/model_parser/base/param_desc.h
0 → 100644
浏览文件 @
d51324bf
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/utils/cp_logging.h"
namespace
paddle
{
namespace
lite
{
class
ParamDescReadAPI
{
public:
virtual
std
::
string
Name
()
const
=
0
;
virtual
std
::
vector
<
int64_t
>
Dim
()
const
=
0
;
virtual
VarDataType
GetDataType
()
const
=
0
;
virtual
const
void
*
GetData
()
const
=
0
;
virtual
size_t
byte_size
()
const
=
0
;
virtual
~
ParamDescReadAPI
()
=
default
;
};
class
ParamDescWriteAPI
{
public:
virtual
void
SetName
(
const
std
::
string
&
name
)
{
NotImplemented
();
}
virtual
void
SetDim
(
const
std
::
vector
<
int64_t
>
&
dim
)
{
NotImplemented
();
}
virtual
void
SetDataType
(
VarDataType
data_type
)
{
NotImplemented
();
}
virtual
void
SetData
(
const
void
*
data
,
size_t
byte_size
)
{
NotImplemented
();
}
virtual
~
ParamDescWriteAPI
()
=
default
;
private:
void
NotImplemented
()
const
{
LOG
(
FATAL
)
<<
"ParamDescWriteAPI is not available in model read-only mode."
;
}
};
class
CombinedParamsDescReadAPI
{
public:
virtual
const
ParamDescReadAPI
*
GetParamDesc
(
size_t
idx
)
const
=
0
;
virtual
size_t
GetParamsSize
()
const
=
0
;
virtual
~
CombinedParamsDescReadAPI
()
=
default
;
};
class
CombinedParamsDescWriteAPI
{
public:
virtual
ParamDescWriteAPI
*
AddParamDesc
()
{
NotImplemented
();
return
nullptr
;
}
virtual
~
CombinedParamsDescWriteAPI
()
=
default
;
private:
void
NotImplemented
()
const
{
LOG
(
FATAL
)
<<
"CombinedParamsDescWriteAPI is not available in model "
"read-only mode."
;
}
};
// The reading and writing of the model are one-time and separate.
// This interface is a combination of reading and writing interfaces,
// which is used to support legacy interfaces.
class
ParamDescAPI
:
public
ParamDescReadAPI
,
public
ParamDescWriteAPI
{
public:
virtual
~
ParamDescAPI
()
=
default
;
};
class
CombinedParamsDescAPI
:
public
CombinedParamsDescReadAPI
,
public
CombinedParamsDescWriteAPI
{
public:
virtual
~
CombinedParamsDescAPI
()
=
default
;
};
}
// namespace lite
}
// namespace paddle
lite/model_parser/base/traits.h
浏览文件 @
d51324bf
...
...
@@ -16,6 +16,8 @@
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/utils/cp_logging.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -37,6 +39,77 @@ enum class OpAttrType {
UNK
,
};
enum
class
VarDataType
{
// Pod Types
BOOL
=
0
,
INT16
,
INT32
,
INT64
,
FP16
,
FP32
,
FP64
,
// Tensor<size_t> is used in C++.
SIZE_T
,
UINT8
,
INT8
,
// Other types that may need additional descriptions
LOD_TENSOR
,
SELECTED_ROWS
,
FEED_MINIBATCH
,
FETCH_LIST
,
STEP_SCOPES
,
LOD_RANK_TABLE
,
LOD_TENSOR_ARRAY
,
PLACE_LIST
,
READER
,
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW
,
TUPLE
};
inline
VarDataType
ConvertPrecisionType
(
lite_api
::
PrecisionType
type
)
{
#define CASE(ptype, vtype) \
case lite_api::PrecisionType::k##ptype: \
return lite::VarDataType::vtype; \
break
switch
(
type
)
{
CASE
(
Float
,
FP32
);
CASE
(
Int8
,
INT8
);
CASE
(
Int32
,
INT32
);
CASE
(
FP16
,
FP16
);
CASE
(
Bool
,
BOOL
);
CASE
(
Int64
,
INT64
);
CASE
(
Int16
,
INT16
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer VarType."
;
return
lite
::
VarDataType
();
}
#undef CASE
}
inline
lite_api
::
PrecisionType
ConvertPrecisionType
(
VarDataType
type
)
{
#define CASE(ptype, vtype) \
case lite::VarDataType::vtype: \
return lite_api::PrecisionType::k##ptype; \
break
switch
(
type
)
{
CASE
(
Float
,
FP32
);
CASE
(
Int8
,
INT8
);
CASE
(
Int32
,
INT32
);
CASE
(
FP16
,
FP16
);
CASE
(
Bool
,
BOOL
);
CASE
(
Int64
,
INT64
);
CASE
(
Int16
,
INT16
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer VarType."
;
return
lite_api
::
PrecisionType
();
}
#undef CASE
}
struct
Standard
{};
struct
Flatbuffers
{};
...
...
lite/model_parser/base/var_desc.h
浏览文件 @
d51324bf
...
...
@@ -16,42 +16,12 @@
#include <string>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/utils/cp_logging.h"
namespace
paddle
{
namespace
lite
{
enum
class
VarDataType
{
// Pod Types
BOOL
=
0
,
INT16
,
INT32
,
INT64
,
FP16
,
FP32
,
FP64
,
// Tensor<size_t> is used in C++.
SIZE_T
,
UINT8
,
INT8
,
// Other types that may need additional descriptions
LOD_TENSOR
,
SELECTED_ROWS
,
FEED_MINIBATCH
,
FETCH_LIST
,
STEP_SCOPES
,
LOD_RANK_TABLE
,
LOD_TENSOR_ARRAY
,
PLACE_LIST
,
READER
,
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW
,
TUPLE
};
class
VarDescReadAPI
{
public:
virtual
std
::
string
Name
()
const
=
0
;
...
...
lite/model_parser/flatbuffers/CMakeLists.txt
浏览文件 @
d51324bf
...
...
@@ -5,9 +5,10 @@ function(lite_fbs_library TARGET)
add_dependencies
(
${
TARGET
}
${
args_FBS_DEPS
}
)
endfunction
()
lite_fbs_library
(
fbs_op_desc SRCS op_desc.cc FBS_DEPS f
ramework_fbs_header
)
lite_fbs_library
(
fbs_var_desc SRCS var_desc.cc FBS_DEPS f
ramework_fbs_header
)
lite_fbs_library
(
fbs_block_desc SRCS block_desc.cc FBS_DEPS f
ramework_fbs_header
)
lite_fbs_library
(
fbs_op_desc SRCS op_desc.cc FBS_DEPS f
bs_headers
)
lite_fbs_library
(
fbs_var_desc SRCS var_desc.cc FBS_DEPS f
bs_headers
)
lite_fbs_library
(
fbs_block_desc SRCS block_desc.cc FBS_DEPS f
bs_headers
)
lite_cc_library
(
fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc
)
lite_cc_library
(
fbs_io SRCS io.cc DEPS fbs_program_desc
)
lite_fbs_library
(
fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers
)
lite_cc_library
(
fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc
)
lite_cc_test
(
test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc
)
lite/model_parser/flatbuffers/io.cc
浏览文件 @
d51324bf
...
...
@@ -13,9 +13,11 @@
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/flatbuffers/traits.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -33,6 +35,43 @@ void LoadModel(const std::string& path, ProgramDesc* prog) {
prog
->
Init
(
std
::
move
(
buf
));
}
void
SetParamWithTensor
(
const
std
::
string
&
name
,
const
lite
::
Tensor
&
tensor
,
ParamDescWriteAPI
*
prog
)
{
CHECK
(
prog
);
prog
->
SetName
(
name
);
prog
->
SetDim
(
tensor
.
dims
().
Vectorize
());
prog
->
SetDataType
(
lite
::
ConvertPrecisionType
(
tensor
.
precision
()));
prog
->
SetData
(
tensor
.
raw_data
(),
tensor
.
memory_size
());
}
void
SetTensorWithParam
(
lite
::
Tensor
*
tensor
,
const
ParamDescReadAPI
&
param
)
{
tensor
->
Resize
(
param
.
Dim
());
tensor
->
set_precision
(
lite
::
ConvertPrecisionType
(
param
.
GetDataType
()));
std
::
memcpy
(
tensor
->
mutable_data
(
param
.
byte_size
()),
param
.
GetData
(),
param
.
byte_size
());
}
void
SetCombinedParamsWithScope
(
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
params_name
,
CombinedParamsDescWriteAPI
*
params
)
{
for
(
const
auto
&
name
:
params_name
)
{
auto
*
param
=
params
->
AddParamDesc
();
auto
&
tensor
=
scope
.
FindVar
(
name
)
->
Get
<
lite
::
Tensor
>
();
SetParamWithTensor
(
name
,
tensor
,
param
);
}
}
void
SetScopeWithCombinedParams
(
lite
::
Scope
*
scope
,
const
CombinedParamsDescReadAPI
&
params
)
{
CHECK
(
scope
);
for
(
size_t
i
=
0
;
i
<
params
.
GetParamsSize
();
++
i
)
{
const
auto
&
param
=
*
params
.
GetParamDesc
(
i
);
auto
*
tensor
=
scope
->
Var
(
param
.
Name
())
->
GetMutable
<
lite
::
Tensor
>
();
SetTensorWithParam
(
tensor
,
param
);
}
}
}
// namespace fbs
}
// namespace lite
}
// namespace paddle
lite/model_parser/flatbuffers/io.h
浏览文件 @
d51324bf
...
...
@@ -15,6 +15,10 @@
#pragma once
#include <string>
#include <vector>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/model_parser/flatbuffers/param_desc.h"
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace
paddle
{
...
...
@@ -23,6 +27,17 @@ namespace fbs {
void
LoadModel
(
const
std
::
string
&
path
,
ProgramDesc
*
prog
);
void
SetParamWithTensor
(
const
std
::
string
&
name
,
const
lite
::
Tensor
&
tensor
,
ParamDescWriteAPI
*
prog
);
void
SetTensorWithParam
(
const
lite
::
Tensor
&
tensor
,
ParamDescReadAPI
*
prog
);
void
SetCombinedParamsWithScope
(
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
params_name
,
CombinedParamsDescWriteAPI
*
params
);
void
SetScopeWithCombinedParams
(
lite
::
Scope
*
scope
,
const
CombinedParamsDescReadAPI
&
params
);
}
// namespace fbs
}
// namespace lite
}
// namespace paddle
lite/model_parser/flatbuffers/op_desc.h
浏览文件 @
d51324bf
...
...
@@ -21,6 +21,7 @@
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
#include "lite/model_parser/flatbuffers/vector_view.h"
#include "lite/utils/all.h"
...
...
@@ -96,13 +97,13 @@ class OpDesc : public OpDescAPI {
OpDescAPI
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
override
{
const
auto
&
attr
=
desc_
->
attrs
()
->
LookupByKey
(
name
.
c_str
());
CHECK
(
attr
)
<<
"Can not find attr: "
<<
name
;
return
static_cast
<
OpDescAPI
::
AttrType
>
(
attr
->
type
());
return
ConvertAttrType
(
attr
->
type
());
}
OpDescAPI
::
AttrType
GetAttrType
(
size_t
idx
)
const
{
const
auto
&
attr
=
desc_
->
attrs
()
->
Get
(
idx
);
CHECK
(
attr
);
return
static_cast
<
OpDescAPI
::
AttrType
>
(
attr
->
type
());
return
ConvertAttrType
(
attr
->
type
());
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
override
{
...
...
lite/model_parser/flatbuffers/param.fbs
0 → 100644
浏览文件 @
d51324bf
include "framework.fbs";
namespace paddle.lite.fbs.proto;
table CombinedParamsDesc {
params:[paddle.lite.fbs.proto.ParamDesc];
}
namespace paddle.lite.fbs.proto.ParamDesc_;
table LoDTensorDesc {
lod_level:int;
lod:[long];
dim:[long];
data_type:paddle.lite.fbs.proto.VarType_.Type;
data:[byte];
}
table VersionDesc {
version:int;
model_version:int;
}
union VariableDesc {
LoDTensorDesc
}
namespace paddle.lite.fbs.proto;
table ParamDesc {
version:paddle.lite.fbs.proto.ParamDesc_.VersionDesc;
name:string (required, key);
variable:paddle.lite.fbs.proto.ParamDesc_.VariableDesc;
}
root_type paddle.lite.fbs.proto.ParamDesc;
root_type paddle.lite.fbs.proto.CombinedParamsDesc;
lite/model_parser/flatbuffers/param_desc.cc
0 → 100644
浏览文件 @
d51324bf
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/param_desc.h"
lite/model_parser/flatbuffers/param_desc.h
0 → 100644
浏览文件 @
d51324bf
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/param_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/param_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
namespace
paddle
{
namespace
lite
{
namespace
fbs
{
class
ParamDescView
:
public
ParamDescReadAPI
{
public:
explicit
ParamDescView
(
proto
::
ParamDesc
const
*
desc
)
:
desc_
(
desc
)
{
CHECK
(
desc_
);
CHECK
(
desc_
->
variable_type
()
==
proto
::
ParamDesc_
::
VariableDesc_LoDTensorDesc
);
tensor_desc_
=
desc_
->
variable_as
<
proto
::
ParamDesc_
::
LoDTensorDesc
>
();
}
std
::
string
Name
()
const
override
{
return
desc_
->
name
()
->
c_str
();
}
std
::
vector
<
int64_t
>
Dim
()
const
override
{
const
auto
&
dims
=
tensor_desc_
->
dim
();
std
::
vector
<
int64_t
>
dims_vec
;
dims_vec
.
reserve
(
dims
->
size
());
for
(
const
auto
&
dim
:
*
dims
)
{
dims_vec
.
push_back
(
dim
);
}
return
dims_vec
;
}
VarDataType
GetDataType
()
const
override
{
return
ConvertVarType
(
tensor_desc_
->
data_type
());
}
const
void
*
GetData
()
const
override
{
return
tensor_desc_
->
data
()
->
Data
();
}
size_t
byte_size
()
const
override
{
return
tensor_desc_
->
data
()
->
size
();
}
ParamDescView
()
=
delete
;
private:
proto
::
ParamDesc
const
*
desc_
;
proto
::
ParamDesc_
::
LoDTensorDesc
const
*
tensor_desc_
;
};
class
CombinedParamsDescView
:
public
CombinedParamsDescReadAPI
{
public:
CombinedParamsDescView
()
=
default
;
explicit
CombinedParamsDescView
(
const
std
::
vector
<
char
>&
buf
)
{
Init
(
buf
);
}
explicit
CombinedParamsDescView
(
std
::
vector
<
char
>&&
buf
)
{
Init
(
std
::
forward
<
std
::
vector
<
char
>>
(
buf
));
}
void
Init
(
const
std
::
vector
<
char
>&
buf
)
{
CHECK
(
buf
.
data
());
buf_
=
buf
;
InitParams
();
}
void
Init
(
std
::
vector
<
char
>&&
buf
)
{
CHECK
(
buf
.
data
());
buf_
=
std
::
move
(
buf
);
InitParams
();
}
void
InitParams
()
{
desc_
=
proto
::
GetCombinedParamsDesc
(
buf_
.
data
());
params_
.
reserve
(
GetParamsSize
());
for
(
size_t
idx
=
0
;
idx
<
GetParamsSize
();
++
idx
)
{
params_
.
push_back
(
ParamDescView
(
desc_
->
params
()
->
Get
(
idx
)));
}
}
const
ParamDescReadAPI
*
GetParamDesc
(
size_t
idx
)
const
override
{
CHECK
(
idx
<
GetParamsSize
());
return
&
params_
[
idx
];
}
size_t
GetParamsSize
()
const
override
{
return
params_
.
size
();
}
private:
std
::
vector
<
ParamDescView
>
params_
;
std
::
vector
<
char
>
buf_
;
proto
::
CombinedParamsDesc
const
*
desc_
;
};
class
ParamDesc
:
public
ParamDescAPI
{
public:
ParamDesc
()
:
owned_
(
true
),
desc_
(
new
proto
::
ParamDescT
())
{
desc_
->
variable
.
Set
(
proto
::
ParamDesc_
::
LoDTensorDescT
());
lod_tensor_
=
desc_
->
variable
.
AsLoDTensorDesc
();
CHECK
(
lod_tensor_
);
}
explicit
ParamDesc
(
proto
::
ParamDescT
*
desc
)
:
desc_
(
desc
)
{
lod_tensor_
=
desc_
->
variable
.
AsLoDTensorDesc
();
CHECK
(
lod_tensor_
);
}
std
::
string
Name
()
const
override
{
return
desc_
->
name
;
}
void
SetName
(
const
std
::
string
&
name
)
override
{
desc_
->
name
=
name
;
}
std
::
vector
<
int64_t
>
Dim
()
const
override
{
return
lod_tensor_
->
dim
;
}
void
SetDim
(
const
std
::
vector
<
int64_t
>&
dim
)
override
{
lod_tensor_
->
dim
=
dim
;
}
VarDataType
GetDataType
()
const
override
{
return
ConvertVarType
(
lod_tensor_
->
data_type
);
}
void
SetDataType
(
VarDataType
data_type
)
override
{
lod_tensor_
->
data_type
=
ConvertVarType
(
data_type
);
}
const
void
*
GetData
()
const
override
{
return
lod_tensor_
->
data
.
data
();
}
size_t
byte_size
()
const
override
{
return
lod_tensor_
->
data
.
size
();
}
void
SetData
(
const
void
*
data
,
size_t
byte_size
)
{
lod_tensor_
->
data
.
resize
(
byte_size
);
std
::
memcpy
(
lod_tensor_
->
data
.
data
(),
data
,
byte_size
);
}
const
proto
::
ParamDescT
*
raw_desc
()
const
{
return
desc_
;
}
~
ParamDesc
()
{
if
(
owned_
)
{
delete
desc_
;
}
}
private:
bool
owned_
{
false
};
proto
::
ParamDescT
*
desc_
{
nullptr
};
proto
::
ParamDesc_
::
LoDTensorDescT
*
lod_tensor_
{
nullptr
};
};
class
CombinedParamsDesc
:
public
CombinedParamsDescAPI
{
public:
CombinedParamsDesc
()
=
default
;
explicit
CombinedParamsDesc
(
const
std
::
vector
<
char
>&
buf
)
{
const
auto
*
raw_buf
=
proto
::
GetCombinedParamsDesc
(
buf
.
data
());
raw_buf
->
UnPackTo
(
&
desc_
);
SyncParams
();
}
const
ParamDescReadAPI
*
GetParamDesc
(
size_t
idx
)
const
override
{
return
&
params_
[
idx
];
}
size_t
GetParamsSize
()
const
override
{
return
desc_
.
params
.
size
();
}
ParamDescWriteAPI
*
AddParamDesc
()
override
{
desc_
.
params
.
push_back
(
std
::
unique_ptr
<
proto
::
ParamDescT
>
());
SyncParams
();
return
&
params_
[
params_
.
size
()
-
1
];
}
const
void
*
data
()
{
SyncBuffer
();
return
buf_
.
data
();
}
size_t
buf_size
()
{
SyncBuffer
();
return
buf_
.
size
();
}
private:
void
SyncParams
()
{
params_
.
resize
(
GetParamsSize
());
for
(
size_t
i
=
0
;
i
<
GetParamsSize
();
++
i
)
{
if
(
params_
[
i
].
raw_desc
()
!=
desc_
.
params
[
i
].
get
())
{
params_
[
i
]
=
ParamDesc
(
desc_
.
params
[
i
].
get
());
}
}
}
void
SyncBuffer
()
{
fbb_
.
Reset
();
flatbuffers
::
Offset
<
proto
::
CombinedParamsDesc
>
desc
=
proto
::
CombinedParamsDesc
::
Pack
(
fbb_
,
&
desc_
);
fbb_
.
Finish
(
desc
);
buf_
=
fbb_
.
Release
();
}
flatbuffers
::
DetachedBuffer
buf_
;
flatbuffers
::
FlatBufferBuilder
fbb_
;
proto
::
CombinedParamsDescT
desc_
;
std
::
vector
<
ParamDesc
>
params_
;
};
}
// namespace fbs
}
// namespace lite
}
// namespace paddle
lite/model_parser/flatbuffers/traits.h
0 → 100644
浏览文件 @
d51324bf
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/model_parser/base/traits.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
namespace
paddle
{
namespace
lite
{
namespace
fbs
{
inline
lite
::
VarDataType
ConvertVarType
(
proto
::
VarType_
::
Type
type
)
{
#define CASE(type) \
case proto::VarType_::Type_##type: \
return lite::VarDataType::type; \
break
switch
(
type
)
{
CASE
(
BOOL
);
CASE
(
INT16
);
CASE
(
INT32
);
CASE
(
INT64
);
CASE
(
FP16
);
CASE
(
FP32
);
CASE
(
FP64
);
CASE
(
LOD_TENSOR
);
CASE
(
SELECTED_ROWS
);
CASE
(
FEED_MINIBATCH
);
CASE
(
FETCH_LIST
);
CASE
(
STEP_SCOPES
);
CASE
(
LOD_RANK_TABLE
);
CASE
(
LOD_TENSOR_ARRAY
);
CASE
(
PLACE_LIST
);
CASE
(
READER
);
CASE
(
RAW
);
CASE
(
TUPLE
);
CASE
(
SIZE_T
);
CASE
(
UINT8
);
CASE
(
INT8
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer VarType."
;
return
lite
::
VarDataType
();
}
#undef CASE
}
inline
proto
::
VarType_
::
Type
ConvertVarType
(
lite
::
VarDataType
type
)
{
#define CASE(type) \
case lite::VarDataType::type: \
return proto::VarType_::Type_##type; \
break
switch
(
type
)
{
CASE
(
BOOL
);
CASE
(
INT16
);
CASE
(
INT32
);
CASE
(
INT64
);
CASE
(
FP16
);
CASE
(
FP32
);
CASE
(
FP64
);
CASE
(
LOD_TENSOR
);
CASE
(
SELECTED_ROWS
);
CASE
(
FEED_MINIBATCH
);
CASE
(
FETCH_LIST
);
CASE
(
STEP_SCOPES
);
CASE
(
LOD_RANK_TABLE
);
CASE
(
LOD_TENSOR_ARRAY
);
CASE
(
PLACE_LIST
);
CASE
(
READER
);
CASE
(
RAW
);
CASE
(
TUPLE
);
CASE
(
SIZE_T
);
CASE
(
UINT8
);
CASE
(
INT8
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer VarType."
;
return
proto
::
VarType_
::
Type
();
}
#undef CASE
}
inline
lite
::
OpAttrType
ConvertAttrType
(
proto
::
AttrType
type
)
{
#define CASE(type) \
case proto::AttrType_##type: \
return lite::OpAttrType::type; \
break
switch
(
type
)
{
CASE
(
INT
);
CASE
(
FLOAT
);
CASE
(
STRING
);
CASE
(
INTS
);
CASE
(
FLOATS
);
CASE
(
STRINGS
);
CASE
(
BOOLEAN
);
CASE
(
BOOLEANS
);
CASE
(
BLOCK
);
CASE
(
LONG
);
CASE
(
BLOCKS
);
CASE
(
LONGS
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer AttrType."
;
return
lite
::
OpAttrType
();
}
#undef CASE
}
inline
proto
::
AttrType
ConvertAttrType
(
lite
::
OpAttrType
type
)
{
#define CASE(type) \
case lite::OpAttrType::type: \
return proto::AttrType_##type; \
break
switch
(
type
)
{
CASE
(
INT
);
CASE
(
FLOAT
);
CASE
(
STRING
);
CASE
(
INTS
);
CASE
(
FLOATS
);
CASE
(
STRINGS
);
CASE
(
BOOLEAN
);
CASE
(
BOOLEANS
);
CASE
(
BLOCK
);
CASE
(
LONG
);
CASE
(
BLOCKS
);
CASE
(
LONGS
);
default:
LOG
(
FATAL
)
<<
"Illegal flatbuffer AttrType."
;
return
proto
::
AttrType
();
}
#undef CASE
}
}
// namespace fbs
}
// namespace lite
}
// namespace paddle
lite/model_parser/flatbuffers/var_desc.h
浏览文件 @
d51324bf
...
...
@@ -19,6 +19,7 @@
#include <vector>
#include "lite/model_parser/base/var_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
#include "lite/utils/all.h"
namespace
paddle
{
...
...
@@ -32,7 +33,7 @@ class VarDesc : public VarDescAPI {
std
::
string
Name
()
const
override
{
return
desc_
->
name
()
->
str
();
}
VarDescAPI
::
Type
GetType
()
const
override
{
return
static_cast
<
VarDescAPI
::
Type
>
(
desc_
->
type
()
->
type
());
return
ConvertVarType
(
desc_
->
type
()
->
type
());
}
bool
Persistable
()
const
override
{
return
desc_
->
persistable
();
}
...
...
@@ -50,8 +51,7 @@ class VarDesc : public VarDescAPI {
VarDescAPI
::
Type
GetDataType
()
const
{
CHECK
(
GetType
()
==
VarDescAPI
::
Type
::
LOD_TENSOR
);
return
static_cast
<
VarDescAPI
::
Type
>
(
desc_
->
type
()
->
lod_tensor
()
->
tensor
()
->
data_type
());
return
ConvertVarType
(
desc_
->
type
()
->
lod_tensor
()
->
tensor
()
->
data_type
());
}
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录