未验证 提交 546d4da8 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][ModelType] Add Shape&Precision information into optimized model (#3643)

上级 a24d4dd1
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/core/program.h" #include "lite/core/program.h"
#include <algorithm>
#include <unordered_map> #include <unordered_map>
#include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h" #include "lite/model_parser/cpp/op_desc.h"
...@@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
auto* scope = op->scope(); auto* scope = op->scope();
auto in_names = op->op_info()->input_names(); auto in_names = op->op_info()->input_names();
auto out_names = op->op_info()->output_names(); auto out_names = op->op_info()->output_names();
for (auto& in_name : in_names) {
auto it = origin_var_maps.find(in_name); std::vector<std::string> var_names;
var_names.insert(var_names.end(), in_names.begin(), in_names.end());
var_names.insert(var_names.end(), out_names.begin(), out_names.end());
std::sort(var_names.begin(), var_names.end());
var_names.erase(std::unique(var_names.begin(), var_names.end()),
var_names.end());
for (auto& var_name : var_names) {
auto it = origin_var_maps.find(var_name);
if (it != origin_var_maps.end()) { if (it != origin_var_maps.end()) {
auto* v = main_block.AddVar<cpp::VarDesc>(); auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName((it->second).Name()); v->SetName((it->second).Name());
v->SetType((it->second).GetType()); v->SetType((it->second).GetType());
v->SetPersistable((it->second).Persistable()); v->SetPersistable((it->second).Persistable());
if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
v->SetShape((it->second).GetShape());
v->SetDataType((it->second).GetDataType());
}
} else { } else {
// New created vars must be LOD_TENSOR // New created vars must be LOD_TENSOR
auto* v = main_block.AddVar<cpp::VarDesc>(); auto* v = main_block.AddVar<cpp::VarDesc>();
v->SetName(in_name); v->SetName(var_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR); v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string in_arg_name; std::string in_arg_name;
op->op_info()->GetInputArgname(in_name, &in_arg_name); op->op_info()->GetInputArgname(var_name, &in_arg_name);
auto type = kernel->GetInputDeclType(in_arg_name); auto type = kernel->GetInputDeclType(in_arg_name);
if (type->IsTensor()) { if (type->IsTensor()) {
auto tensor = scope->FindVar(in_name)->GetMutable<Tensor>(); auto tensor = scope->FindVar(var_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable()); v->SetPersistable(tensor->persistable());
} else { if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") {
CHECK(false) << "unsupported var type"; v->SetShape(tensor->dims().data());
} switch (tensor->precision()) {
} #define SET_DATATYPE(precision__, data_type) \
} case PrecisionType::precision__: \
v->SetDataType(data_type); \
break
for (auto& out_name : out_names) { SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32);
auto it = origin_var_maps.find(out_name); SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8);
if (it != origin_var_maps.end()) { SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16);
auto* v = main_block.AddVar<cpp::VarDesc>(); SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32);
v->SetName((it->second).Name()); SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64);
v->SetType((it->second).GetType()); #undef SET_DATATYPE
v->SetPersistable((it->second).Persistable()); default:
} else { LOG(FATAL) << "unknown precision type";
// New created vars must be LOD_TENSOR }
auto* v = main_block.AddVar<cpp::VarDesc>(); }
v->SetName(out_name);
v->SetType(cpp::VarDesc::Type::LOD_TENSOR);
std::string out_arg_name;
op->op_info()->GetOutputArgname(out_name, &out_arg_name);
auto type = kernel->GetOutputDeclType(out_arg_name);
if (type->IsTensor()) {
auto tensor = scope->FindVar(out_name)->GetMutable<Tensor>();
v->SetPersistable(tensor->persistable());
} else { } else {
CHECK(false) << "unsupported var type"; CHECK(false) << "unsupported var type";
} }
......
...@@ -30,13 +30,17 @@ namespace paddle { ...@@ -30,13 +30,17 @@ namespace paddle {
namespace lite { namespace lite {
/// For VarDesc transfrom /// For VarDesc transfrom
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ #define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
template <> \ template <> \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \ void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
T *any_desc) { \ T *any_desc) { \
any_desc->SetName(cpp_desc.Name()); \ any_desc->SetName(cpp_desc.Name()); \
any_desc->SetType(cpp_desc.GetType()); \ any_desc->SetType(cpp_desc.GetType()); \
any_desc->SetPersistable(cpp_desc.Persistable()); \ any_desc->SetPersistable(cpp_desc.Persistable()); \
if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \
any_desc->SetShape(cpp_desc.GetShape()); \
any_desc->SetDataType(cpp_desc.GetDataType()); \
} \
} }
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
...@@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc, ...@@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
cpp_desc->SetName(any_desc.Name()); cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable()); cpp_desc->SetPersistable(any_desc.Persistable());
cpp_desc->SetDataType(any_desc.GetDataType()); if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}
} }
#endif #endif
...@@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>( ...@@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
cpp_desc->SetName(any_desc.Name()); cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable()); cpp_desc->SetPersistable(any_desc.Persistable());
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
/* if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") {
cpp_desc->SetDataType(any_desc.GetDataType());
cpp_desc->SetShape(any_desc.GetShape());
}*/
} }
/// For OpDesc transform /// For OpDesc transform
......
...@@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) { ...@@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) {
desc->SetName("X"); desc->SetName("X");
desc->SetPersistable(true); desc->SetPersistable(true);
desc->SetType(VarDescAPI::Type::LOD_TENSOR); desc->SetType(VarDescAPI::Type::LOD_TENSOR);
desc->SetShape({1, 3, 224, 224});
desc->SetDataType(VarDescAPI::VarDataType::FP32);
} }
template <typename VarDescType> template <typename VarDescType>
...@@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) { ...@@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) {
desc->SetName("Y"); desc->SetName("Y");
desc->SetPersistable(false); desc->SetPersistable(false);
desc->SetType(VarDescAPI::Type::SELECTED_ROWS); desc->SetType(VarDescAPI::Type::SELECTED_ROWS);
desc->SetShape({1, 3, 224, 224});
desc->SetDataType(VarDescAPI::VarDataType::FP32);
} }
template <typename VarDescType> template <typename VarDescType>
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "lite/model_parser/desc_apis.h" #include "lite/model_parser/desc_apis.h"
namespace paddle { namespace paddle {
...@@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI { ...@@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI {
void SetDataType(Type data_type) { data_type_ = data_type; } void SetDataType(Type data_type) { data_type_ = data_type; }
void SetShape(const std::vector<int64_t> &dims) { shape_ = dims; }
std::vector<int64_t> GetShape() const { return shape_; }
private: private:
std::string name_; std::string name_;
Type type_; Type type_;
Type data_type_; Type data_type_;
bool persistable_; bool persistable_;
std::vector<int64_t> shape_;
}; };
} // namespace cpp } // namespace cpp
......
...@@ -76,6 +76,10 @@ class VarDescAPI { ...@@ -76,6 +76,10 @@ class VarDescAPI {
virtual bool Persistable() const = 0; virtual bool Persistable() const = 0;
// Set var to be persistable or not // Set var to be persistable or not
virtual void SetPersistable(bool persistable) = 0; virtual void SetPersistable(bool persistable) = 0;
// Get var's shape
virtual std::vector<int64_t> GetShape() const = 0;
// Set var's shape
virtual void SetShape(const std::vector<int64_t>& dims) = 0;
}; };
/* /*
......
...@@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() { ...@@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() {
return builder; return builder;
} }
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) {
/* using data_type_builder_t = EnumBuilder<proto::VarDataType>;
auto data_type_builder =
desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<data_type_builder_t>("data_type");
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::VarDataType::type__: \
data_type_builder->set(proto::VarDataType::type__); \
break
switch (data_type) {
// Only support primary data type now.
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var data type";
}
#undef SET_DATA_TYPE_CASE_ITEM
*/
}
// Get var's shape
std::vector<int64_t> VarDesc::GetShape() const {
using data_type_builder_t = ListBuilder<Int64Builder>;
auto out_builder = desc_->GetField<proto::TensorDesc>("tensor_desc")
.GetField<data_type_builder_t>("dims");
return RepeatedToVector<int64_t, Int64Builder>(out_builder);
}
// Set var's shape
// todo : SetDataType function is commented out temporarily
// because of Compatibility issues. The Compatibility issue
// should be fixed later and the code below should be applied
// later. @DannyIsFunny
void VarDesc::SetShape(const std::vector<int64_t>& dims) {
/* using out_builder_type = ListBuilder<Int64Builder>;
auto out_builder = desc_->GetMutableField<proto::TensorDesc>("tensor_desc")
->GetMutableField<out_builder_type>("dims");
CHECK(out_builder);
VectorToRepeated<int64_t, Int64Builder>(dims, out_builder);*/
}
} // namespace naive_buffer } // namespace naive_buffer
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/model_parser/desc_apis.h" #include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/naive_buffer_wrapper_helper.h"
#include "lite/model_parser/naive_buffer/proto/framework.nb.h" #include "lite/model_parser/naive_buffer/proto/framework.nb.h"
namespace paddle { namespace paddle {
...@@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI { ...@@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI {
void SetPersistable(bool persistable) override; void SetPersistable(bool persistable) override;
void SetDataType(VarDescAPI::VarDataType data_type);
VarDescAPI::VarDataType GetDataType() const; VarDescAPI::VarDataType GetDataType() const;
// Get var's shape
std::vector<int64_t> GetShape() const;
// Set var's shape
void SetShape(const std::vector<int64_t> &dims);
private: private:
const proto::VarType &GetVarType() const; const proto::VarType &GetVarType() const;
proto::VarType *GetMutableVarType(); proto::VarType *GetMutableVarType();
......
...@@ -130,8 +130,27 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const { ...@@ -130,8 +130,27 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
return res; return res;
} }
void VarDesc::SetDataType(proto::VarType::Type data_type) { void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) {
mutable_tensor_desc()->set_data_type(data_type); #define SET_DATA_TYPE_CASE_ITEM(type__) \
case VarDescAPI::Type::type__: \
mutable_tensor_desc()->set_data_type(framework::proto::VarType::type__); \
break;
switch (data_type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var type: " << static_cast<int>(data_type);
}
#undef SET_DATA_TYPE_CASE_ITEM
} }
void VarDesc::SetDataTypes( void VarDesc::SetDataTypes(
......
...@@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI { ...@@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI {
std::vector<std::vector<int64_t>> GetShapes() const; std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(framework::proto::VarType::Type data_type); void SetDataType(VarDescAPI::VarDataType data_type);
void SetDataTypes( void SetDataTypes(
const std::vector<framework::proto::VarType::Type> &multiple_data_type); const std::vector<framework::proto::VarType::Type> &multiple_data_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册