diff --git a/lite/core/program.cc b/lite/core/program.cc index 5ddf6c0e935a851cc0b3c3eb7554609939ef1cbf..e712ed9e0b6715f0ce2dfe93a0e65da8593ef32a 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "lite/core/program.h" +#include #include #include "lite/model_parser/cpp/block_desc.h" #include "lite/model_parser/cpp/op_desc.h" @@ -85,48 +86,54 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { auto* scope = op->scope(); auto in_names = op->op_info()->input_names(); auto out_names = op->op_info()->output_names(); - for (auto& in_name : in_names) { - auto it = origin_var_maps.find(in_name); + + std::vector var_names; + var_names.insert(var_names.end(), in_names.begin(), in_names.end()); + var_names.insert(var_names.end(), out_names.begin(), out_names.end()); + std::sort(var_names.begin(), var_names.end()); + var_names.erase(std::unique(var_names.begin(), var_names.end()), + var_names.end()); + + for (auto& var_name : var_names) { + auto it = origin_var_maps.find(var_name); if (it != origin_var_maps.end()) { auto* v = main_block.AddVar(); v->SetName((it->second).Name()); v->SetType((it->second).GetType()); v->SetPersistable((it->second).Persistable()); + if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { + v->SetShape((it->second).GetShape()); + v->SetDataType((it->second).GetDataType()); + } } else { // New created vars must be LOD_TENSOR auto* v = main_block.AddVar(); - v->SetName(in_name); + v->SetName(var_name); v->SetType(cpp::VarDesc::Type::LOD_TENSOR); std::string in_arg_name; - op->op_info()->GetInputArgname(in_name, &in_arg_name); + op->op_info()->GetInputArgname(var_name, &in_arg_name); auto type = kernel->GetInputDeclType(in_arg_name); if (type->IsTensor()) { - auto tensor = scope->FindVar(in_name)->GetMutable(); + auto tensor = scope->FindVar(var_name)->GetMutable(); v->SetPersistable(tensor->persistable()); - } else { - CHECK(false) << "unsupported var type"; - } - } - } + if ((it->second).Name() != "feed" && (it->second).Name() != "fetch") { + v->SetShape(tensor->dims().data()); + switch (tensor->precision()) { +#define SET_DATATYPE(precision__, data_type) \ + case PrecisionType::precision__: \ + v->SetDataType(data_type); \ + break - for (auto& out_name : out_names) { - auto it = origin_var_maps.find(out_name); - if (it != origin_var_maps.end()) { - auto* v = main_block.AddVar(); - v->SetName((it->second).Name()); - v->SetType((it->second).GetType()); - v->SetPersistable((it->second).Persistable()); - } else { - // New created vars must be LOD_TENSOR - auto* v = main_block.AddVar(); - v->SetName(out_name); - v->SetType(cpp::VarDesc::Type::LOD_TENSOR); - std::string out_arg_name; - op->op_info()->GetOutputArgname(out_name, &out_arg_name); - auto type = kernel->GetOutputDeclType(out_arg_name); - if (type->IsTensor()) { - auto tensor = scope->FindVar(out_name)->GetMutable(); - v->SetPersistable(tensor->persistable()); + SET_DATATYPE(kFloat, VarDescAPI::VarDataType::FP32); + SET_DATATYPE(kInt8, VarDescAPI::VarDataType::INT8); + SET_DATATYPE(kInt16, VarDescAPI::VarDataType::INT16); + SET_DATATYPE(kInt32, VarDescAPI::VarDataType::INT32); + SET_DATATYPE(kInt64, VarDescAPI::VarDataType::INT64); +#undef SET_DATATYPE + default: + LOG(FATAL) << "unknown precision type"; + } + } } else { CHECK(false) << "unsupported var type"; } diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc index d1131539bf30abba22feeba8abf009f95ab70a00..3d66a5234994036397e445744499696909a8ab3e 100644 --- a/lite/model_parser/compatible_pb.cc +++ b/lite/model_parser/compatible_pb.cc @@ -30,13 +30,17 @@ namespace paddle { namespace lite { /// For VarDesc transfrom -#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ - template <> \ - void TransformVarDescCppToAny(const cpp::VarDesc &cpp_desc, \ - T *any_desc) { \ - any_desc->SetName(cpp_desc.Name()); \ - any_desc->SetType(cpp_desc.GetType()); \ - any_desc->SetPersistable(cpp_desc.Persistable()); \ +#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ + template <> \ + void TransformVarDescCppToAny(const cpp::VarDesc &cpp_desc, \ + T *any_desc) { \ + any_desc->SetName(cpp_desc.Name()); \ + any_desc->SetType(cpp_desc.GetType()); \ + any_desc->SetPersistable(cpp_desc.Persistable()); \ + if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \ + any_desc->SetShape(cpp_desc.GetShape()); \ + any_desc->SetDataType(cpp_desc.GetDataType()); \ + } \ } #ifndef LITE_ON_TINY_PUBLISH @@ -46,7 +50,10 @@ void TransformVarDescAnyToCpp(const pb::VarDesc &any_desc, cpp_desc->SetName(any_desc.Name()); cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetPersistable(any_desc.Persistable()); - cpp_desc->SetDataType(any_desc.GetDataType()); + if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") { + cpp_desc->SetDataType(any_desc.GetDataType()); + cpp_desc->SetShape(any_desc.GetShape()); + } } #endif @@ -56,6 +63,14 @@ void TransformVarDescAnyToCpp( cpp_desc->SetName(any_desc.Name()); cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetPersistable(any_desc.Persistable()); + // todo : SetDataType function is commented out temporarily + // because of Compatibility issues. The Compatibility issue + // should be fixed later and the code below should be applied + // later. @DannyIsFunny + /* if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") { + cpp_desc->SetDataType(any_desc.GetDataType()); + cpp_desc->SetShape(any_desc.GetShape()); + }*/ } /// For OpDesc transform diff --git a/lite/model_parser/compatible_pb_test.cc b/lite/model_parser/compatible_pb_test.cc index 3d964d14d7970aec36cb2f7ee2f6c6e11043d9be..088b64bf2cd13ce0f443f962bd2cb5f709c4d4f2 100644 --- a/lite/model_parser/compatible_pb_test.cc +++ b/lite/model_parser/compatible_pb_test.cc @@ -36,6 +36,8 @@ void SetVarDesc(VarDescType* desc) { desc->SetName("X"); desc->SetPersistable(true); desc->SetType(VarDescAPI::Type::LOD_TENSOR); + desc->SetShape({1, 3, 224, 224}); + desc->SetDataType(VarDescAPI::VarDataType::FP32); } template @@ -43,6 +45,8 @@ void SetVarDesc1(VarDescType* desc) { desc->SetName("Y"); desc->SetPersistable(false); desc->SetType(VarDescAPI::Type::SELECTED_ROWS); + desc->SetShape({1, 3, 224, 224}); + desc->SetDataType(VarDescAPI::VarDataType::FP32); } template diff --git a/lite/model_parser/cpp/var_desc.h b/lite/model_parser/cpp/var_desc.h index 9232bba3e8620b2e5e769c9f7a0f50969abe8421..c56d7cce53180e0157913372f8b0da4c9cedd8c9 100644 --- a/lite/model_parser/cpp/var_desc.h +++ b/lite/model_parser/cpp/var_desc.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include "lite/model_parser/desc_apis.h" namespace paddle { @@ -46,11 +47,16 @@ class VarDesc : public VarDescAPI { void SetDataType(Type data_type) { data_type_ = data_type; } + void SetShape(const std::vector &dims) { shape_ = dims; } + + std::vector GetShape() const { return shape_; } + private: std::string name_; Type type_; Type data_type_; bool persistable_; + std::vector shape_; }; } // namespace cpp diff --git a/lite/model_parser/desc_apis.h b/lite/model_parser/desc_apis.h index 5461de54a936f395db6718e9ce6f864f970b4322..e948afa3b9602f7010d678a4e55fa96f11ef5407 100644 --- a/lite/model_parser/desc_apis.h +++ b/lite/model_parser/desc_apis.h @@ -76,6 +76,10 @@ class VarDescAPI { virtual bool Persistable() const = 0; // Set var to be persistable or not virtual void SetPersistable(bool persistable) = 0; + // Get var's shape + virtual std::vector GetShape() const = 0; + // Set var's shape + virtual void SetShape(const std::vector& dims) = 0; }; /* diff --git a/lite/model_parser/naive_buffer/var_desc.cc b/lite/model_parser/naive_buffer/var_desc.cc index 86b6dd72844c694dee1781d322491bf922f32d09..2d2fb21ba3b4669601c44e5d929ae1756e09530d 100644 --- a/lite/model_parser/naive_buffer/var_desc.cc +++ b/lite/model_parser/naive_buffer/var_desc.cc @@ -131,6 +131,57 @@ proto::VarType* VarDesc::GetMutableVarType() { return builder; } +// todo : SetDataType function is commented out temporarily +// because of Compatibility issues. The Compatibility issue +// should be fixed later and the code below should be applied +// later. @DannyIsFunny +void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) { + /* using data_type_builder_t = EnumBuilder; + auto data_type_builder = + desc_->GetMutableField("tensor_desc") + ->GetMutableField("data_type"); + #define SET_DATA_TYPE_CASE_ITEM(type__) \ + case VarDescAPI::VarDataType::type__: \ + data_type_builder->set(proto::VarDataType::type__); \ + break + + switch (data_type) { + // Only support primary data type now. + SET_DATA_TYPE_CASE_ITEM(UINT8); + SET_DATA_TYPE_CASE_ITEM(INT8); + SET_DATA_TYPE_CASE_ITEM(INT16); + SET_DATA_TYPE_CASE_ITEM(INT32); + SET_DATA_TYPE_CASE_ITEM(INT64); + SET_DATA_TYPE_CASE_ITEM(FP32); + SET_DATA_TYPE_CASE_ITEM(FP64); + default: + LOG(FATAL) << "Unknown var data type"; + } + #undef SET_DATA_TYPE_CASE_ITEM + */ +} + +// Get var's shape +std::vector VarDesc::GetShape() const { + using data_type_builder_t = ListBuilder; + auto out_builder = desc_->GetField("tensor_desc") + .GetField("dims"); + return RepeatedToVector(out_builder); +} + +// Set var's shape +// todo : SetDataType function is commented out temporarily +// because of Compatibility issues. The Compatibility issue +// should be fixed later and the code below should be applied +// later. @DannyIsFunny +void VarDesc::SetShape(const std::vector& dims) { + /* using out_builder_type = ListBuilder; + auto out_builder = desc_->GetMutableField("tensor_desc") + ->GetMutableField("dims"); + CHECK(out_builder); + VectorToRepeated(dims, out_builder);*/ +} + } // namespace naive_buffer } // namespace lite } // namespace paddle diff --git a/lite/model_parser/naive_buffer/var_desc.h b/lite/model_parser/naive_buffer/var_desc.h index b638afd79d085e64ef7f1174f0d27975b827e76a..bf0845d7464f511dfb77812612c2b99c954600da 100644 --- a/lite/model_parser/naive_buffer/var_desc.h +++ b/lite/model_parser/naive_buffer/var_desc.h @@ -18,6 +18,7 @@ #include #include #include "lite/model_parser/desc_apis.h" +#include "lite/model_parser/naive_buffer/naive_buffer_wrapper_helper.h" #include "lite/model_parser/naive_buffer/proto/framework.nb.h" namespace paddle { @@ -51,8 +52,14 @@ class VarDesc : public VarDescAPI { void SetPersistable(bool persistable) override; + void SetDataType(VarDescAPI::VarDataType data_type); VarDescAPI::VarDataType GetDataType() const; + // Get var's shape + std::vector GetShape() const; + // Set var's shape + void SetShape(const std::vector &dims); + private: const proto::VarType &GetVarType() const; proto::VarType *GetMutableVarType(); diff --git a/lite/model_parser/pb/var_desc.cc b/lite/model_parser/pb/var_desc.cc index a3f28d00b94054addd728775e9373d73f9b7b729..f849b8dd0ed103f789aec41e5c88f3e4f3cdf878 100644 --- a/lite/model_parser/pb/var_desc.cc +++ b/lite/model_parser/pb/var_desc.cc @@ -130,8 +130,27 @@ std::vector> VarDesc::GetShapes() const { return res; } -void VarDesc::SetDataType(proto::VarType::Type data_type) { - mutable_tensor_desc()->set_data_type(data_type); +void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) { +#define SET_DATA_TYPE_CASE_ITEM(type__) \ + case VarDescAPI::Type::type__: \ + mutable_tensor_desc()->set_data_type(framework::proto::VarType::type__); \ + break; + + switch (data_type) { + SET_DATA_TYPE_CASE_ITEM(BOOL); + SET_DATA_TYPE_CASE_ITEM(SIZE_T); + SET_DATA_TYPE_CASE_ITEM(UINT8); + SET_DATA_TYPE_CASE_ITEM(INT8); + SET_DATA_TYPE_CASE_ITEM(INT16); + SET_DATA_TYPE_CASE_ITEM(INT32); + SET_DATA_TYPE_CASE_ITEM(INT64); + SET_DATA_TYPE_CASE_ITEM(FP16); + SET_DATA_TYPE_CASE_ITEM(FP32); + SET_DATA_TYPE_CASE_ITEM(FP64); + default: + LOG(FATAL) << "Unknown var type: " << static_cast(data_type); + } +#undef SET_DATA_TYPE_CASE_ITEM } void VarDesc::SetDataTypes( diff --git a/lite/model_parser/pb/var_desc.h b/lite/model_parser/pb/var_desc.h index bbf78b75d3f1b1a4a6488e28380f2587ca77bbc4..eefacef4b0c90faf132b2e4ef141ac7009939db5 100644 --- a/lite/model_parser/pb/var_desc.h +++ b/lite/model_parser/pb/var_desc.h @@ -84,7 +84,7 @@ class VarDesc : public VarDescAPI { std::vector> GetShapes() const; - void SetDataType(framework::proto::VarType::Type data_type); + void SetDataType(VarDescAPI::VarDataType data_type); void SetDataTypes( const std::vector &multiple_data_type);