// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/model_parser/compatible_pb.h" #include #include #include "lite/model_parser/naive_buffer/block_desc.h" #include "lite/model_parser/naive_buffer/op_desc.h" #include "lite/model_parser/naive_buffer/program_desc.h" #include "lite/model_parser/naive_buffer/var_desc.h" #ifndef LITE_ON_TINY_PUBLISH #include "lite/model_parser/pb/block_desc.h" #include "lite/model_parser/pb/op_desc.h" #include "lite/model_parser/pb/program_desc.h" #include "lite/model_parser/pb/var_desc.h" #endif namespace paddle { namespace lite { /// For VarDesc transfrom #define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ template <> \ void TransformVarDescCppToAny(const cpp::VarDesc &cpp_desc, \ T *any_desc) { \ any_desc->SetName(cpp_desc.Name()); \ any_desc->SetType(cpp_desc.GetType()); \ any_desc->SetPersistable(cpp_desc.Persistable()); \ if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") { \ any_desc->SetShape(cpp_desc.GetShape()); \ any_desc->SetDataType(cpp_desc.GetDataType()); \ } \ } #ifndef LITE_ON_TINY_PUBLISH template <> void TransformVarDescAnyToCpp(const pb::VarDesc &any_desc, cpp::VarDesc *cpp_desc) { cpp_desc->SetName(any_desc.Name()); cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetPersistable(any_desc.Persistable()); if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") { cpp_desc->SetDataType(any_desc.GetDataType()); cpp_desc->SetShape(any_desc.GetShape()); } } #endif template <> void TransformVarDescAnyToCpp( const naive_buffer::VarDesc &any_desc, cpp::VarDesc *cpp_desc) { cpp_desc->SetName(any_desc.Name()); cpp_desc->SetType(any_desc.GetType()); cpp_desc->SetPersistable(any_desc.Persistable()); // todo : SetDataType function is commented out temporarily // because of Compatibility issues. The Compatibility issue // should be fixed later and the code below should be applied // later. @DannyIsFunny /* if (any_desc.Name() != "feed" && any_desc.Name() != "fetch") { cpp_desc->SetDataType(any_desc.GetDataType()); cpp_desc->SetShape(any_desc.GetShape()); }*/ } /// For OpDesc transform template void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { for (const std::string ¶m : any_desc.InputArgumentNames()) { cpp_desc->SetInput(param, any_desc.Input(param)); } } template void OpInputsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { for (const std::string ¶m : cpp_desc.InputArgumentNames()) { any_desc->SetInput(param, cpp_desc.Input(param)); } } template void OpOutputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { for (const std::string ¶m : any_desc.OutputArgumentNames()) { cpp_desc->SetOutput(param, any_desc.Output(param)); } } template void OpOutputsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { for (const std::string ¶m : cpp_desc.OutputArgumentNames()) { any_desc->SetOutput(param, cpp_desc.Output(param)); } } template void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { using AttrType = OpDescAPI::AttrType; auto set_attr = [&](const std::string &name, AttrType type) { switch (type) { case AttrType::INT: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::FLOAT: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::STRING: cpp_desc->SetAttr( name, any_desc.template GetAttr(name)); break; case AttrType::LONG: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::INTS: cpp_desc->SetAttr>( name, any_desc.template GetAttr>(name)); break; case AttrType::FLOATS: cpp_desc->SetAttr>( name, any_desc.template GetAttr>(name)); break; case AttrType::BOOLEAN: cpp_desc->SetAttr(name, any_desc.template GetAttr(name)); break; case AttrType::STRINGS: cpp_desc->SetAttr>( name, any_desc.template GetAttr>(name)); break; case AttrType::LONGS: cpp_desc->SetAttr>( name, any_desc.template GetAttr>(name)); break; case AttrType::BLOCK: { auto i = any_desc.template GetAttr(name); cpp_desc->SetAttr(name, i); // naive_buffer::BlockDesc* sub_block = any_desc.template // GetAttr(name); // LOG(INFO) << sub_block->OpsSize(); break; } default: LOG(FATAL) << "Unsupported attr type found " << static_cast(type); } }; for (const auto &attr_name : any_desc.AttrNames()) { auto type = any_desc.GetAttrType(attr_name); set_attr(attr_name, type); } } template void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { using AttrType = OpDescAPI::AttrType; auto set_attr = [&](const std::string &name, AttrType type) { switch (type) { #define IMPL_ONE(type__, T) \ case AttrType::type__: \ any_desc->template SetAttr(name, cpp_desc.GetAttr(name)); \ break; IMPL_ONE(INT, int32_t); IMPL_ONE(FLOAT, float); IMPL_ONE(STRING, std::string); IMPL_ONE(STRINGS, std::vector); IMPL_ONE(FLOATS, std::vector); IMPL_ONE(INTS, std::vector); IMPL_ONE(BOOLEAN, bool); IMPL_ONE(LONG, int64_t); IMPL_ONE(LONGS, std::vector); default: LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); } }; #undef IMPL_ONE for (const auto &attr_name : cpp_desc.AttrNames()) { auto type = cpp_desc.GetAttrType(attr_name); set_attr(attr_name, type); } } #define TRANS_OP_ANY_WITH_CPP_IMPL(T) \ template <> \ void TransformOpDescAnyToCpp(const T &any_desc, cpp::OpDesc *cpp_desc) { \ cpp_desc->SetType(any_desc.Type()); \ OpInputsAnyToCpp(any_desc, cpp_desc); \ OpOutputsAnyToCpp(any_desc, cpp_desc); \ OpAttrsAnyToCpp(any_desc, cpp_desc); \ } \ \ template <> \ void TransformOpDescCppToAny(const cpp::OpDesc &cpp_desc, T *any_desc) { \ any_desc->SetType(cpp_desc.Type()); \ OpInputsCppToAny(cpp_desc, any_desc); \ OpOutputsCppToAny(cpp_desc, any_desc); \ OpAttrsCppToAny(cpp_desc, any_desc); \ } /// For BlockDesc transform #define TRANS_BLOCK_ANY_WITH_CPP_IMPL(T, NT, PNT) \ template <> \ void TransformBlockDescAnyToCpp(const NT::T &any_desc, \ cpp::BlockDesc *cpp_desc) { \ NT::T desc = any_desc; \ cpp_desc->SetIdx(desc.Idx()); \ cpp_desc->SetParentIdx(desc.ParentIdx()); \ cpp_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ \ cpp_desc->ClearOps(); \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \ auto any_op_desc = NT::OpDesc(desc.GetOp(i)); \ auto *cpp_op_desc = cpp_desc->AddOp(); \ TransformOpDescAnyToCpp(any_op_desc, cpp_op_desc); \ } \ \ cpp_desc->ClearVars(); \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \ auto any_var_desc = NT::VarDesc(desc.GetVar(i)); \ auto *cpp_var_desc = cpp_desc->AddVar(); \ TransformVarDescAnyToCpp(any_var_desc, cpp_var_desc); \ } \ } \ \ template <> \ void TransformBlockDescCppToAny(const cpp::T &cpp_desc, \ NT::T *any_desc) { \ auto desc = cpp_desc; \ any_desc->SetIdx(desc.Idx()); \ any_desc->SetParentIdx(desc.ParentIdx()); \ any_desc->SetForwardBlockIdx(desc.ForwardBlockIdx()); \ \ any_desc->ClearOps(); \ for (size_t i = 0; i < desc.OpsSize(); ++i) { \ auto *cpp_op_desc = desc.GetOp(i); \ auto any_op_desc = NT::OpDesc(any_desc->AddOp()); \ TransformOpDescCppToAny(*cpp_op_desc, &any_op_desc); \ } \ \ any_desc->ClearVars(); \ for (size_t i = 0; i < desc.VarsSize(); ++i) { \ auto *cpp_var_desc = desc.GetVar(i); \ auto any_var_desc = \ NT::VarDesc(any_desc->AddVar()); \ TransformVarDescCppToAny(*cpp_var_desc, &any_var_desc); \ } \ } /// For ProgramDesc transform #define TRANS_PROGRAM_ANY_WITH_CPP_IMPL(T, NT, PNT) \ template <> \ void TransformProgramDescAnyToCpp(const NT::T &any_desc, \ cpp::ProgramDesc *cpp_desc) { \ NT::T desc = any_desc; \ if (desc.HasVersion()) { \ cpp_desc->SetVersion(desc.Version()); \ } \ \ cpp_desc->ClearBlocks(); \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ auto any_block_desc = \ NT::BlockDesc(desc.GetBlock(i)); \ auto *cpp_block_desc = cpp_desc->AddBlock(); \ TransformBlockDescAnyToCpp(any_block_desc, cpp_block_desc); \ } \ } \ \ template <> \ void TransformProgramDescCppToAny(const cpp::T &cpp_desc, \ NT::T *any_desc) { \ auto desc = cpp_desc; \ if (desc.HasVersion()) { \ any_desc->SetVersion(desc.Version()); \ } \ \ any_desc->ClearBlocks(); \ for (size_t i = 0; i < desc.BlocksSize(); ++i) { \ auto *cpp_block_desc = desc.GetBlock(i); \ auto any_block_desc = \ NT::BlockDesc(any_desc->AddBlock()); \ TransformBlockDescCppToAny(*cpp_block_desc, &any_block_desc); \ } \ } TRANS_VAR_ANY_WITH_CPP_IMPL(naive_buffer::VarDesc); TRANS_OP_ANY_WITH_CPP_IMPL(naive_buffer::OpDesc); TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, naive_buffer, naive_buffer); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, naive_buffer, naive_buffer); #ifndef LITE_ON_TINY_PUBLISH TRANS_VAR_ANY_WITH_CPP_IMPL(pb::VarDesc); TRANS_OP_ANY_WITH_CPP_IMPL(pb::OpDesc); TRANS_BLOCK_ANY_WITH_CPP_IMPL(BlockDesc, pb, framework); TRANS_PROGRAM_ANY_WITH_CPP_IMPL(ProgramDesc, pb, framework); #endif #undef TRANS_VAR_ANY_WITH_CPP_IMPL #undef TRANS_OP_ANY_WITH_CPP_IMPL #undef TRANS_BLOCK_ANY_WITH_CPP_IMPL #undef TRANS_PROGRAM_ANY_WITH_CPP_IMPL } // namespace lite } // namespace paddle