未验证 提交 546d4da8 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][ModelType] Add Shape&Precision information into optimized model (#3643)

上级 a24d4dd1
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/core/program.h"
#include <algorithm>
#include <unordered_map>
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
......@@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
auto* scope = op->scope();
auto in_names = op->op_info()->input_names();
auto out_names = op->op_info()->output_names();
for (auto& in_name : in_names) {
auto it = origin_var_maps.find(in_name);
std::vector<std::string> var_names;
var_names.insert(var_names.end(), in_names.begin(), in_names.end());
var_names.insert(var_names.end(), out_names.begin(), out_names.end());
std::sort(var_names.begin(), var_names.end());
var_names.erase(std::unique(var_names.begin(), var_names.end()),
var_names.end());
for (auto& var_name : var_names) {
auto it = origin_var_maps.find(var_name);
if (it != origin_var_maps.end()) {
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName((it->second).Name());
v->SetType((it->second).GetType());
v->SetPersistable((it->second).Persistable());
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
v->SetShape((it->second).GetShape());
v->SetDataType((it->second).GetDataType());
}
} else {
// New created vars must be LOD_TENSOR
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(in_name);
v->SetName(var_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string in_arg_name;
op->op_info()->GetInputArgname(in_name, &in_arg_name);
op->op_info()->GetInputArgname(var_name, &in_arg_name);
auto type = kernel->GetInputDeclType(in_arg_name);
if (type->IsTensor()) {
auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>();
auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
} else {
CHECK(false) << "unsupported var type";
}
}
}
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
v->SetShape(tensor->dims().data());
switch (tensor->precision()) {
#define SET_DATATYPE(precision__, data_type) \
case PrecisionType::precision__: \
v->SetDataType(data_type); \
break
for (auto& out_name : out_names) {
auto it = origin_var_maps.find(out_name);
if (it != origin_var_maps.end()) {
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName((it->second).Name());
v->SetType((it->second).GetType());
v->SetPersistable((it->second).Persistable());
} else {
// New created vars must be LOD_TENSOR
auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(out_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string out_arg_name;
op->op_info()->GetOutputArgname(out_name, &out_arg_name);
auto type = kernel->GetOutputDeclType(out_arg_name);
if (type->IsTensor()) {
auto tensor = scope->FindVar(out_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
#undef SET_DATATYPE
default:
LOG(FATAL) << "unknown precision type";
}
}
} else {
CHECK(false) << "unsupported var type";
}
......
......@@ -30,13 +30,17 @@ namespace paddle {
namespace lite {
/// For VarDesc transfrom
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
template <> \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
T *any_desc) { \
any_desc->SetName(cpp_desc.Name()); \
any_desc->SetType(cpp_desc.GetType()); \
any_desc->SetPersistable(cpp_desc.Persistable()); \
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
template <> \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
T *any_desc) { \
any_desc->SetName(cpp_desc.Name()); \
any_desc->SetType(cpp_desc.GetType()); \
any_desc->SetPersistable(cpp_desc.Persistable()); \
if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \
any_desc->SetShape(cpp_desc.GetShape()); \
any_desc->SetDataType(cpp_desc.GetDataType()); \
} \
}
#ifndef LITE_ON_TINY_PUBLISH
......@@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
cpp_desc->SetDataType(any_desc.GetDataType());
if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}
}
#endif
......@@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
/* if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}*/
}
/// For OpDesc transform
......
......@@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) {
desc->SetName("X");
desc->SetPersistable(true);
desc->SetType(VarDescAPI::Type::LOD_TENSOR);
desc->SetShape({1, 3, 224, 224});
desc->SetDataType(VarDescAPI::VarDataType::FP32);
}
template <typename VarDescType>
......@@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) {
desc->SetName("Y");
desc->SetPersistable(false);
desc->SetType(VarDescAPI::Type::SELECTED_ROWS);
desc->SetShape({1, 3, 224, 224});
desc->SetDataType(VarDescAPI::VarDataType::FP32);
}
template <typename VarDescType>
......
......@@ -14,6 +14,7 @@
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/desc_apis.h"
namespace paddle {
......@@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI {
void SetDataType(Type data_type) { data_type_ = data_type; }
void SetShape(const std::vector<int64_t> &dims) { shape_ = dims; }
std::vector<int64_t> GetShape() const { return shape_; }
private:
std::string name_;
Type type_;
Type data_type_;
bool persistable_;
std::vector<int64_t> shape_;
};
} // namespace cpp
......
......@@ -76,6 +76,10 @@ class VarDescAPI {
virtual bool Persistable() const = 0;
// Set var to be persistable or not
virtual void SetPersistable(bool persistable) = 0;
// Get var's shape
virtual std::vector<int64_t> GetShape() const = 0;
// Set var's shape
virtual void SetShape(const std::vector<int64_t>& dims) = 0;
};
/*
......
......@@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() {
return builder;
}
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) {
/* using data_type_builder_t = EnumBuilder<proto::VarDataType>;
auto data_type_builder =
desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<data_type_builder_t>("data_type");
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::VarDataType::type__: \
data_type_builder->set(proto::VarDataType::type__); \
break
switch (data_type) {
// Only support primary data type now.
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var data type";
}
#undef SET_DATA_TYPE_CASE_ITEM
*/
}
// Get var's shape
std::vector<int64_t> VarDesc::GetShape() const {
using data_type_builder_t = ListBuilder<Int64Builder>;
auto out_builder = desc_->GetField<proto::TensorDesc>("tensor_desc")
.GetField<data_type_builder_t>("dims");
return RepeatedToVector<int64_t, Int64Builder>(out_builder);
}
// Set var's shape
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void VarDesc::SetShape(const std::vector<int64_t>& dims) {
/* using out_builder_type = ListBuilder<Int64Builder>;
auto out_builder = desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<out_builder_type>("dims");
CHECK(out_builder);
VectorToRepeated<int64_t, Int64Builder>(dims, out_builder);*/
}
} // namespace naive_buffer
} // namespace lite
} // namespace paddle
......@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/naive_buffer_wrapper_helper.h"
#include "lite/model_parser/naive_buffer/proto/framework.nb.h"
namespace paddle {
......@@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI {
void SetPersistable(bool persistable) override;
void SetDataType(VarDescAPI::VarDataType data_type);
VarDescAPI::VarDataType GetDataType() const;
// Get var's shape
std::vector<int64_t> GetShape() const;
// Set var's shape
void SetShape(const std::vector<int64_t> &dims);
private:
const proto::VarType &GetVarType() const;
proto::VarType *GetMutableVarType();
......
......@@ -130,8 +130,27 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return res;
}
void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::Type::type__: \
mutable_tensor_desc()->set_data_type(framework::proto::VarType::type__); \
break;
switch (data_type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var type: " << static_cast<int>(data_type);
}
#undef SET_DATA_TYPE_CASE_ITEM
}
void VarDesc::SetDataTypes(
......
......@@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI {
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(framework::proto::VarType::Type data_type);
void SetDataType(VarDescAPI::VarDataType data_type);
void SetDataTypes(
const std::vector<framework::proto::VarType::Type> &multiple_data_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册