Commit b62c0d94 authored by Ruilong Liu, committed by GitHub

Merge pull request #642 from xiebaiyuan/develop

Add a quantification script that can be compiled and run on x86, and add documentation explaining how to use it.
@@ -9,7 +9,6 @@ option(LOG_PROFILE "log profile" ON)
option(CPU "armv7 with neon" ON)
option(MALI_GPU "mali gpu" OFF)
option(FPGA "fpga" OFF)
option(QUANTI "quantification" OFF)
file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
@@ -163,7 +162,4 @@ if(DEBUGING)
endif()
endif()
if (QUANTI)
add_subdirectory(tools/quantification)
endif ()
# Quantification: Model Quantization and Dequantization
## Background
Models trained from some networks, such as AlexNet, are large and therefore a poor fit for mobile devices.
## Ways to shrink an oversized model
1. Choose a model architecture designed for mobile, such as mobilenet, googlenet, yolo, or squeezenet;
2. Use the quantification tool we provide, which shrinks a float32 model to roughly 1/4 of its original size with almost no loss of accuracy; the sketch after this list shows the underlying math.
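The conversion maps every float32 weight to a single byte using per-tensor min/max affine quantization (this is what `LoadWithDump` in convert.cpp does). The sketch below is ours for illustration; `Quantize` and `Dequantize` are hypothetical helper names, not functions in the tool:
```c++
#include <cmath>
#include <cstdint>

// Quantize: map value in [min_value, max_value] onto [0, 255].
// Assumes max_value > min_value, as the tool does.
uint8_t Quantize(float value, float min_value, float max_value) {
  return static_cast<uint8_t>(
      std::round((value - min_value) / (max_value - min_value) * 255));
}

// Dequantize: approximate inverse applied when the model is loaded.
float Dequantize(uint8_t factor, float min_value, float max_value) {
  return min_value + factor * (max_value - min_value) / 255.0f;
}
```
Each tensor is stored as its min, its max, and one byte per element instead of four, which is where the roughly 4x size reduction comes from.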
- - - - -
## The quantification tool
### Tool locations:
- [Quantification tool directory](https://github.com/PaddlePaddle/paddle-mobile/tree/develop/tools/quantification)
- [Model conversion tool](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/tools/quantification/convert.cpp)
#### Usage
- [Tool usage](https://github.com/PaddlePaddle/paddle-mobile/blob/develop/tools/quantification/README.md)
## How to load a quantized model
A `quantification` parameter has been added to the Load method; it defaults to false. To load a quantized model, pass true for it.
[Source code](https://github.com/PaddlePaddle/paddle-mobile/blob/55302b33ea3bd68c9797d8f65e527544792b8095/src/io/paddle_mobile.h)
```c++
bool Load(const std::string &dirname, bool optimize = false,
bool quantification = false, int batch_size = 1);
```
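A minimal loading sketch, assuming the `PaddleMobile` facade from the header linked above and a quantized model produced by the tool in `./googlenet_min` (path and template arguments are illustrative):
```c++
#include "io/paddle_mobile.h"

int main() {
  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
  // optimize = true, quantification = true, so the loader
  // dequantizes the uint8 parameters back to float32 on load.
  paddle_mobile.Load("./googlenet_min", true, true);
  return 0;
}
```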
- - - - -
set(dir ${CMAKE_CURRENT_SOURCE_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${dir}/build")
cmake_minimum_required(VERSION 3.6)
project(quali)
add_definitions(-DENABLE_EXCEPTION)
set(CMAKE_CXX_STANDARD 11)
file(GLOB_RECURSE QULIFICATON_CC src/*.cc src/*.cpp src/*.c src/*.mm)
file(GLOB_RECURSE QULIFICATON_H src/*.h)
include_directories(. src/)
#add_library(paddle-mobile SHARED ${QULIFICATON_CC} ${QULIFICATON_H} convert.cpp)
add_executable(quantify convert.cpp ${QULIFICATON_CC} ${QULIFICATON_H})
# Model Quantification Script
#### Usage guide
1. Start from the paddle-mobile project root (e.g. ~/PaddleProject/paddle-mobile)
2. cd into the tools/quantification/ directory
3. Build with cmake
``` sh
cmake .
make
```
4. Run the quantification script
```sh
./quantify (0: separated, 1: combined) (input path) (output path)
# quantify the separated googlenet model from /Users/xiebaiyuan/PaddleProject/quali/models/googlenet to ./googlenet_min
./quantify 0 /Users/xiebaiyuan/PaddleProject/quali/models/googlenet ./googlenet_min
```
*Note:*
*In the quantification tool:*
*1. For separated models, the model file is expected to be named "__model__";*
*2. For combined models, the model file is expected to be named "model" and the parameter file "params";*
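A separated model directory is therefore expected to look roughly like this (a hypothetical layout; the actual parameter file names come from the network's variables):
```
models/googlenet/
├── __model__         # program (graph) description, read by loadParams
├── conv1_weights     # one raw binary file per persistable parameter
├── conv1_bias
└── ...
```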
##### Putting it all together:
Using the separated (non-combined) googlenet model as an example:
```sh
cd tools/quantification/
cmake .
make
./quantify 0 /Users/xiebaiyuan/PaddleProject/quali/models/googlenet ./googlenet_min
```
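If the conversion succeeds, the output directory holds one quantized file per parameter and should come out to roughly a quarter of the input model's size, which you can eyeball with a quick check like:
```sh
du -sh /Users/xiebaiyuan/PaddleProject/quali/models/googlenet ./googlenet_min
```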
#include "io/paddle_mobile.h"
#include "src/enforce.h"
#include "src/var_desc.h"
#include "src/program_desc.h"
#include <cstdlib>
using std::string;
#include <string>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
#include "src/framework.pb-c.h"
#include "src/protobuf-c.h"
#include <fstream>
#include <iostream>
static const std::string g_googlenet_combine = "../models/googlenet_combine";
static const std::string g_googlenet = "../models/googlenet";
using paddle_mobile::Executor;
using paddle_mobile::framework::Program;
char *Get_binary_data(std::string filename) {
const size_t kSize64 = sizeof(uint64_t);
const size_t kSize32 = sizeof(uint32_t);
char *Get_binary_data(const std::string &filename) {
  FILE *file = fopen(filename.c_str(), "rb");
  PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
                        filename.c_str());
  fseek(file, 0, SEEK_END);
  int64_t size = ftell(file);
  PADDLE_MOBILE_ENFORCE(size > 0, "size is too small");
  rewind(file);
  auto *data = new char[size];
  size_t bytes_read = fread(data, 1, static_cast<size_t>(size), file);
  PADDLE_MOBILE_ENFORCE(bytes_read == static_cast<size_t>(size),
                        "read binary file bytes do not match with fseek");
  fclose(file);
  return data;
}
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp;
fp = fopen(file_name, "rb");
PADDLE_MOBILE_ENFORCE(fp != nullptr, " %s open failed !", file_name);
fseek(fp, 0, SEEK_END);
auto size = static_cast<size_t>(ftell(fp));
rewind(fp);
*out = reinterpret_cast<uint8_t *>(malloc(size));
size_t cur_len = 0;
size_t nread;
while ((nread = fread(*out + cur_len, 1, size - cur_len, fp)) != 0) {
cur_len += nread;
}
fclose(fp);
return cur_len;
}
// Unpack the protobuf-encoded program description (the __model__ or model
// file) into a ProgramDesc.
std::shared_ptr<ProgramDesc> loadParams(const std::string &model_path) {
  PaddleMobile__Framework__Proto__ProgramDesc *c_program;
  uint8_t *buf = nullptr;
  size_t read_size = ReadBuffer(model_path.c_str(), &buf);
  PADDLE_MOBILE_ENFORCE(buf != nullptr, "read from __model__ is null");
  c_program = paddle_mobile__framework__proto__program_desc__unpack(
      nullptr, read_size, buf);
  PADDLE_MOBILE_ENFORCE(c_program != nullptr, "program is null");
  auto originProgramDesc = std::make_shared<ProgramDesc>(c_program);
  return originProgramDesc;
}

// Read one parameter from dataP and write its quantized form to out_file.
// dataP is taken by reference so that, for combined models, the caller's
// cursor advances past the parameter just consumed.
void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc,
                  char *&dataP, FILE *out_file) {
  // 1. version
  uint32_t version = *reinterpret_cast<uint32_t *>(dataP);
  // write version
  fwrite(&version, kSize32, 1, out_file);
  dataP += kSize32;
  // 2. LoD information: the level is read but a level of 0 is written,
  // since weight tensors carry no LoD
  auto *lod_level_ptr = new uint64_t();
  memcpy(lod_level_ptr, dataP, kSize64);
  uint64_t lod_level = 0;
  // write lod information
  fwrite(&lod_level, kSize64, 1, out_file);
  delete lod_level_ptr;
  dataP += kSize64;
  for (uint64_t i = 0; i < lod_level; ++i) {
    uint64_t size = *reinterpret_cast<uint64_t *>(dataP);
    // write lod size
    fwrite(&size, kSize64, 1, out_file);
    dataP += kSize64;
    std::vector<size_t> tmp(size / sizeof(size_t));
    for (unsigned long &k : tmp) {
      k = *reinterpret_cast<size_t *>(dataP);
      dataP += sizeof(size_t);
    }
    // write lod size vector (tmp.data(), not &tmp, which would dump the
    // vector object itself)
    fwrite(tmp.data(), sizeof(size_t), tmp.size(), out_file);
  }
  // 3. tensor version
  uint32_t tensor_version = *reinterpret_cast<uint32_t *>(dataP);
  // write tensor version
  fwrite(&tensor_version, kSize32, 1, out_file);
  dataP += kSize32;
  // 4. tensor desc: serialized proto bytes, copied through unchanged
  int32_t size = *reinterpret_cast<int32_t *>(dataP);
  // write tensor desc
  fwrite(&size, sizeof(int32_t), 1, out_file);
  dataP += sizeof(int32_t);
  std::unique_ptr<char[]> buf(new char[size]);
  for (int m = 0; m < size; ++m) {
    buf.get()[m] = dataP[m];
  }
  fwrite(buf.get(), sizeof(char), static_cast<size_t>(size), out_file);
  dataP += (sizeof(char) * size);
  const paddle_mobile::framework::TensorDesc &desc = var_desc.Tensor_desc();
  int memory_size = 1;
  for (auto l : desc.Dims()) {
    memory_size *= l;
  }
  void *memory = nullptr;
  int type_size = 0;
  switch (desc.DataType()) {
    case paddle_mobile::framework::VARTYPE_TYPE_FP16:
@@ -93,7 +137,6 @@
      break;
    case paddle_mobile::framework::VARTYPE_TYPE_FP32:
      type_size = 4;
      break;
    case paddle_mobile::framework::VARTYPE_TYPE_FP64:
      type_size = 8;
@@ -110,91 +153,121 @@
    default:
      break;
  }
  // copy the raw tensor bytes out of the parameter blob
  size_t tensorSize = sizeof(char) * memory_size * type_size;
  memory = new char[tensorSize];
  for (size_t n = 0; n < tensorSize; ++n) {
    static_cast<char *>(memory)[n] = dataP[n];
  }
  dataP += tensorSize;
  // for float32: find the tensor's range, then map every value to a byte
  float min_value = std::numeric_limits<float>::max();
  // lowest(), not min(): min() is the smallest positive float, which would
  // give a wrong maximum for all-negative tensors
  float max_value = std::numeric_limits<float>::lowest();
  for (int k = 0; k < memory_size; ++k) {
    min_value = std::min(min_value, static_cast<float *>(memory)[k]);
    max_value = std::max(max_value, static_cast<float *>(memory)[k]);
  }
  fwrite(&min_value, sizeof(float), 1, out_file);
  fwrite(&max_value, sizeof(float), 1, out_file);
  for (int g = 0; g < memory_size; ++g) {
    float value = static_cast<float *>(memory)[g];
    auto factor = static_cast<uint8_t>(
        round((value - min_value) / (max_value - min_value) * 255));
    fwrite(&factor, sizeof(uint8_t), 1, out_file);
  }
}
void quantificate_combined(const std::string &model_path,
                           const std::string &param_path,
                           const std::string &param_min_path) {
  auto program = loadParams(model_path);
  char *origin_data = Get_binary_data(param_path);
  char *data = origin_data;
  FILE *out_file = fopen(param_min_path.c_str(), "wb");
  for (const auto &block : program->Blocks()) {
    for (const auto &var_desc : block->Vars()) {
      if (var_desc->Persistable()) {
        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
          continue;
        }
        // data walks through the combined params blob one parameter at a
        // time (advanced inside LoadWithDump)
        LoadWithDump(*var_desc, data, out_file);
      }
    }
  }
  fclose(out_file);
  delete[] origin_data;
}
void quantificate_seperated(const std::string &model_dir,
                            const std::string &param_min_path) {
  auto program = loadParams(model_dir + "/__model__");
  std::string shell_command = "mkdir " + param_min_path;
  system(shell_command.c_str());
  for (const auto &block : program->Blocks()) {
    for (const auto &var_desc : block->Vars()) {
      if (var_desc->Persistable()) {
        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
          continue;
        }
        // each parameter lives in its own file named after the variable
        std::string file_name = param_min_path + "/" + var_desc->Name();
        FILE *out_file = fopen(file_name.c_str(), "wb");
        char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
        char *data = origin_data;
        LoadWithDump(*var_desc, data, out_file);
        delete[] origin_data;
        fclose(out_file);
      }
    }
  }
}
int main(int argc, char **argv) {
  const std::string kNoteEg =
      "( eg: ./quantify 1 your_combined_model_path output_path or "
      "./quantify 0 your_seperated_model_path output_path)";
  PADDLE_MOBILE_ENFORCE(argc > 1, "we need params. %s ", kNoteEg.c_str());
  std::string action_type = argv[1];
  PADDLE_MOBILE_ENFORCE(action_type == "1" || action_type == "0",
                        "only 0 or 1 supported, current is %s %s ",
                        action_type.c_str(), kNoteEg.c_str());
  PADDLE_MOBILE_ENFORCE(argc > 2, "we need your model path. %s ",
                        kNoteEg.c_str());
  std::string base_path = argv[2];
  PADDLE_MOBILE_ENFORCE(argc > 3, "we need your output path. %s ",
                        kNoteEg.c_str());
  std::string output_path = argv[3];
  if (action_type == "0") {
    // separated model: a directory with __model__ plus one file per
    // parameter
    quantificate_seperated(base_path, output_path);
    return 0;
  }
  if (action_type == "1") {
    // combined model: "model" and "params" files inside base_path
    std::string model_path = base_path + "/model";
    std::string param_path = base_path + "/params";
    quantificate_combined(model_path, param_path, output_path);
    return 0;
  }
  return -1;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by 谢柏渊 on 2018/7/25.
//
#include "src/block_desc_local.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "src/framework.pb-c.h"
std::vector<std::shared_ptr<paddle_mobile::framework::VarDesc>>
BlockDesc::Vars() const {
return vars_;
}
BlockDesc::BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc)
    : index_(desc->idx), parent_index_(desc->parent_idx) {
for (int i = 0; i < desc->n_vars; ++i) {
PaddleMobile__Framework__Proto__VarDesc *var_desc = desc->vars[i];
vars_.emplace_back(std::shared_ptr<paddle_mobile::framework::VarDesc>(
new paddle_mobile::framework::VarDesc(var_desc)));
}
std::sort(vars_.begin(), vars_.end(),
[](std::shared_ptr<paddle_mobile::framework::VarDesc> left,
std::shared_ptr<paddle_mobile::framework::VarDesc> right) {
return left->Name() < right->Name();
});
// for (int j = 0; j < desc->n_ops; ++j) {
// PaddleMobile__Framework__Proto__OpDesc *op_desc = desc->ops[j];
// ops_.emplace_back(new OpDesc(op_desc));
// }
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by 谢柏渊 on 2018/7/25.
//
#ifndef TOOLS_QUANTIFICATION_SRC_BLOCK_DESC_LOCAL_H_
#define TOOLS_QUANTIFICATION_SRC_BLOCK_DESC_LOCAL_H_
#include <vector>
#include "src/var_desc.h"
class BlockDesc {
public:
friend class Node;
friend class ProgramOptimize;
BlockDesc() {}
explicit BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc);
const int &ID() const { return index_; }
const bool &MultiThread() const { return multi_thread_; }
const int &Parent() const { return parent_index_; }
bool operator==(const BlockDesc &in_block) const {
return this->ID() == in_block.ID() && this->Parent() == in_block.Parent();
}
bool operator<(const BlockDesc &in_block) const {
return this->ID() < in_block.ID() && this->Parent() < in_block.Parent();
}
std::vector<std::shared_ptr<paddle_mobile::framework::VarDesc>> Vars() const;
private:
int index_;
bool multi_thread_;
int parent_index_;
std::vector<std::shared_ptr<paddle_mobile::framework::VarDesc>> vars_;
};
#endif // TOOLS_QUANTIFICATION_SRC_BLOCK_DESC_LOCAL_H_
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef ENABLE_EXCEPTION
#include <stdio.h>
#include <exception>
#include <string>
#endif
namespace paddle_mobile {
#ifdef ENABLE_EXCEPTION
struct PaddleMobileException : public std::exception {
const std::string exception_prefix = "paddle mobile C++ Exception: \n";
std::string message;
PaddleMobileException(const char *header, const char *detail,
const char *file, const int line) {
char buffer[1500];
snprintf(buffer, sizeof(buffer),
"%s| %s \n| [in file] : %s\n| [on line] : %d\n| [detail] : %s\n",
exception_prefix.c_str(), header, file, line, detail);
message = std::string(buffer);
}
const char *what() const noexcept { return message.c_str(); }
};
#define PADDLE_MOBILE_THROW_EXCEPTION(...) \
{ \
char buffer[1000]; \
snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
std::string detail(buffer); \
throw paddle_mobile::PaddleMobileException("Custom Exception", buffer, \
__FILE__, __LINE__); \
}
#define PADDLE_MOBILE_ENFORCE(stat, ...) \
{ \
if (stat) { \
} else { \
char buffer[1000]; \
snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
std::string detail(buffer); \
throw paddle_mobile::PaddleMobileException("paddle-mobile enforce", \
buffer, __FILE__, __LINE__); \
} \
}
#else
#define PADDLE_MOBILE_THROW_EXCEPTION(...)
#define PADDLE_MOBILE_ENFORCE(stat, ...)
#endif
} // namespace paddle_mobile
This diff is collapsed.
This diff is collapsed.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by 谢柏渊 on 2018/7/25.
//
#include "src/program_desc.h"
#include <vector>
ProgramDesc::ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc) {
for (int i = 0; i < desc->n_blocks; ++i) {
blocks_.emplace_back(std::make_shared<BlockDesc>(desc->blocks[i]));
}
}
const std::vector<std::shared_ptr<BlockDesc>> ProgramDesc::Blocks() {
return blocks_;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by 谢柏渊 on 2018/7/25.
//
#ifndef TOOLS_QUANTIFICATION_SRC_PROGRAM_DESC_H_
#define TOOLS_QUANTIFICATION_SRC_PROGRAM_DESC_H_
#include <memory>
#include <vector>
#include "src/block_desc_local.h"
#include "src/framework.pb-c.h"
class ProgramDesc {
public:
// friend class Node;
//
// friend class ProgramOptimize;
explicit ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc);
const std::vector<std::shared_ptr<BlockDesc>> Blocks();
private:
std::vector<std::shared_ptr<BlockDesc>> blocks_;
};
#endif // TOOLS_QUANTIFICATION_SRC_PROGRAM_DESC_H_
This diff is collapsed.
This diff is collapsed.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "src/framework.pb-c.h"
namespace paddle_mobile {
namespace framework {
enum VarType_Type {
VARTYPE_TYPE_BOOL = 0,
VARTYPE_TYPE_INT16 = 1,
VARTYPE_TYPE_INT32 = 2,
VARTYPE_TYPE_INT64 = 3,
VARTYPE_TYPE_FP16 = 4,
VARTYPE_TYPE_FP32 = 5,
VARTYPE_TYPE_FP64 = 6,
VARTYPE_TYPE_LOD_TENSOR = 7,
VARTYPE_TYPE_SELECTED_ROWS = 8,
VARTYPE_TYPE_FEED_MINIBATCH = 9,
VARTYPE_TYPE_FETCH_LIST = 10,
VARTYPE_TYPE_STEP_SCOPES = 11,
VARTYPE_TYPE_STEP_LOD_RANK_TABLE = 12,
VARTYPE_TYPE_STEP_LOD_TENSOR_ARRAY = 13,
VARTYPE_TYPE_STEP_PLACE_LIST = 14,
VARTYPE_TYPE_READER = 15,
VARTYPE_TYPE_CHANNEL = 16,
VARTYPE_TYPE_RAW = 17,
VARTYPE_TYPE_TUPLE = 18
};
class TensorDesc {
public:
TensorDesc() = default;
TensorDesc(const TensorDesc &desc) {
this->dims_ = desc.dims_;
this->data_type_ = desc.data_type_;
}
explicit TensorDesc(
PaddleMobile__Framework__Proto__VarType__TensorDesc *desc) {
for (int i = 0; i < desc->n_dims; ++i) {
int64_t d = desc->dims[i];
dims_.emplace_back(d);
}
data_type_ = (VarType_Type)desc->data_type;
}
std::vector<int64_t> Dims() const { return dims_; }
VarType_Type DataType() const { return data_type_; }
private:
std::vector<int64_t> dims_;
VarType_Type data_type_;
};
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "src/framework.pb-c.h"
#include "src/tensor_desc.h"
namespace paddle_mobile {
namespace framework {
class VarDesc {
public:
VarDesc(const VarDesc &var_desc) {
this->data_type_ = var_desc.data_type_;
this->name_ = var_desc.name_;
this->persistable_ = var_desc.persistable_;
this->tensor_desc_ = var_desc.tensor_desc_;
this->type_ = var_desc.type_;
}
explicit VarDesc(PaddleMobile__Framework__Proto__VarDesc *desc) {
type_ = (VarType_Type)desc->type->type;
name_ = std::string(desc->name);
persistable_ = static_cast<bool>(desc->persistable);
switch (type_) {
case VARTYPE_TYPE_SELECTED_ROWS:
tensor_desc_ = TensorDesc(desc->type->selected_rows);
break;
case VARTYPE_TYPE_LOD_TENSOR:
tensor_desc_ = TensorDesc(desc->type->lod_tensor->tensor);
break;
case VARTYPE_TYPE_STEP_LOD_TENSOR_ARRAY:
// desc->type->tensor_array->tensor->data_type;
tensor_desc_ = TensorDesc(desc->type->tensor_array->tensor);
break;
default:
break;
}
switch (type_) {
case VARTYPE_TYPE_CHANNEL:
data_type_ = (VarType_Type)desc->type->channel->data_type;
break;
default:
data_type_ = tensor_desc_.DataType();
break;
}
}
std::string Name() const { return name_; }
VarType_Type Type() const { return type_; }
bool Persistable() const { return persistable_; }
const TensorDesc &Tensor_desc() const { return tensor_desc_; }
private:
std::string name_;
bool persistable_;
TensorDesc tensor_desc_;
VarType_Type type_;
VarType_Type data_type_;
};
} // namespace framework
} // namespace paddle_mobile