From 5404c2ee651af2cac3dbb6b7f53e978be19fea4d Mon Sep 17 00:00:00 2001 From: TianXiaogang Date: Thu, 19 Sep 2019 13:30:27 +0800 Subject: [PATCH] Bug fix for model save and load (#1992) * fix: fix model parser and save bug * style: delete debug code * fix: fix light_predictor program run model with subblock bug --- lite/api/light_api.cc | 11 ++++++----- lite/api/light_api.h | 1 + lite/backends/arm/math/activation.cc | 4 ---- lite/core/program.cc | 2 +- lite/model_parser/compatible_pb.cc | 6 +++--- lite/model_parser/cpp/op_desc.cc | 2 -- lite/model_parser/naive_buffer/op_desc.cc | 2 ++ lite/model_parser/pb/op_desc.cc | 11 +++++++++++ 8 files changed, 24 insertions(+), 15 deletions(-) diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index 98b79e58aa..2d75a1ba82 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -22,26 +22,25 @@ void LightPredictor::Build(const std::string& model_dir, const std::string& param_buffer, lite_api::LiteModelType model_type, bool model_from_memory) { - cpp::ProgramDesc desc; switch (model_type) { #ifndef LITE_ON_TINY_PUBLISH case lite_api::LiteModelType::kProtobuf: - LoadModelPb(model_dir, "", "", scope_.get(), &desc); + LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_); break; #endif case lite_api::LiteModelType::kNaiveBuffer: { if (model_from_memory) { LoadModelNaiveFromMemory( - model_buffer, param_buffer, scope_.get(), &desc); + model_buffer, param_buffer, scope_.get(), &cpp_program_desc_); } else { - LoadModelNaive(model_dir, scope_.get(), &desc); + LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_); } break; } default: LOG(FATAL) << "Unknown model type"; } - BuildRuntimeProgram(desc); + BuildRuntimeProgram(cpp_program_desc_); } Tensor* LightPredictor::GetInput(size_t offset) { @@ -84,9 +83,11 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { }); CHECK(it != kernels.end()); (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); + insts.emplace_back(op, std::move(*it)); } program_.reset(new RuntimeProgram(std::move(insts))); + CHECK(program.exec_scope()); program_->set_exec_scope(program.exec_scope()); } diff --git a/lite/api/light_api.h b/lite/api/light_api.h index 2415401744..0d5c7006c8 100644 --- a/lite/api/light_api.h +++ b/lite/api/light_api.h @@ -74,6 +74,7 @@ class LITE_API LightPredictor { private: std::shared_ptr scope_; std::unique_ptr program_; + cpp::ProgramDesc cpp_program_desc_; }; } // namespace lite diff --git a/lite/backends/arm/math/activation.cc b/lite/backends/arm/math/activation.cc index c227077779..098a43d682 100644 --- a/lite/backends/arm/math/activation.cc +++ b/lite/backends/arm/math/activation.cc @@ -600,10 +600,6 @@ void act_log(const float* din, float* dout, int size, int threads) { int remain = size - threads * nums_per_thread; int neon_loop_cnt_dim4 = nums_per_thread >> 2; int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); - LOG(INFO) << "nums_per_thread" << nums_per_thread; - LOG(INFO) << "remain" << remain; - LOG(INFO) << "neon_loop_cnt_dim4" << neon_loop_cnt_dim4; - LOG(INFO) << "neon_loop_remian_dim4" << neon_loop_remain_dim4; float32x4_t vzero = vdupq_n_f32(0.f); #pragma omp parallel for diff --git a/lite/core/program.cc b/lite/core/program.cc index 179cdf909a..2a1f748cb5 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -140,7 +140,7 @@ void Program::Build(const cpp::ProgramDesc& prog) { auto op = LiteOpRegistry::Global().Create(op_type); CHECK(op) << "no Op found for " << op_type; if (op_type == "while") { - auto sub_block_idx = op_desc.GetAttr("sub_block"); + auto sub_block_idx = op_desc.GetAttr("sub_block"); auto sub_block = const_cast(prog).GetBlock( sub_block_idx); diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc index 09604b014a..2df4a92270 100644 --- a/lite/model_parser/compatible_pb.cc +++ b/lite/model_parser/compatible_pb.cc @@ -116,10 +116,8 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { name, any_desc.template GetAttr>(name)); break; case AttrType::BLOCK: { - LOG(INFO) << "loading block " << name; auto i = any_desc.template GetAttr(name); - LOG(INFO) << i; - cpp_desc->SetAttr(name, i); + cpp_desc->SetAttr(name, i); // naive_buffer::BlockDesc* sub_block = any_desc.template // GetAttr(name); // LOG(INFO) << sub_block->OpsSize(); @@ -152,6 +150,8 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { IMPL_ONE(FLOATS, std::vector); IMPL_ONE(INTS, std::vector); IMPL_ONE(BOOLEAN, bool); + IMPL_ONE(LONG, int64_t); + IMPL_ONE(LONGS, std::vector); default: LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); } diff --git a/lite/model_parser/cpp/op_desc.cc b/lite/model_parser/cpp/op_desc.cc index 4c99fdfb3d..f4be0106fc 100644 --- a/lite/model_parser/cpp/op_desc.cc +++ b/lite/model_parser/cpp/op_desc.cc @@ -28,7 +28,6 @@ namespace cpp { } SET_ATTR_IMPL(int32_t, INT); -SET_ATTR_IMPL(int16_t, INT); SET_ATTR_IMPL(float, FLOAT); SET_ATTR_IMPL(std::string, STRING); SET_ATTR_IMPL(bool, BOOLEAN); @@ -108,7 +107,6 @@ bool OpDesc::HasOutput(const std::string& param) const { } GET_IMPL_ONE(float, FLOAT); -GET_IMPL_ONE(int16_t, INT); GET_IMPL_ONE(std::string, STRING); GET_IMPL_ONE(int64_t, LONG); GET_IMPL_ONE(bool, BOOLEAN); diff --git a/lite/model_parser/naive_buffer/op_desc.cc b/lite/model_parser/naive_buffer/op_desc.cc index 8d36a4ad3d..8a2ad55807 100644 --- a/lite/model_parser/naive_buffer/op_desc.cc +++ b/lite/model_parser/naive_buffer/op_desc.cc @@ -54,6 +54,7 @@ SET_ATTR_IMPL(int, INT, Int32, i); SET_ATTR_IMPL(float, FLOAT, Float32, f); SET_ATTR_IMPL(bool, BOOLEAN, Bool, b); SET_ATTR_IMPL(std::string, STRING, String, s); +SET_ATTR_IMPL(int64_t, LONG, Int64, l); #undef SET_ATTR_IMPL #define SET_ATTRS_IMPL(T, ty__, bd__, pb_f__) \ @@ -77,6 +78,7 @@ SET_ATTR_IMPL(std::string, STRING, String, s); SET_ATTRS_IMPL(int, INTS, Int32, ints); SET_ATTRS_IMPL(float, FLOATS, Float32, floats); SET_ATTRS_IMPL(std::string, STRINGS, String, strings); +SET_ATTRS_IMPL(int64_t, LONGS, Int64, longs); #undef SET_ATTRS_IMPL const proto::OpDesc::Attr& GetFindAttr(const proto::OpDesc& desc, diff --git a/lite/model_parser/pb/op_desc.cc b/lite/model_parser/pb/op_desc.cc index 34b83d55b5..37ed07a2c5 100644 --- a/lite/model_parser/pb/op_desc.cc +++ b/lite/model_parser/pb/op_desc.cc @@ -46,6 +46,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) { SET_IMPL_ONE(int, INT, i); SET_IMPL_ONE(float, FLOAT, f); SET_IMPL_ONE(bool, BOOLEAN, b); +SET_IMPL_ONE(int64_t, LONG, l); template <> void OpDesc::SetAttr>(const std::string &name, @@ -88,6 +89,16 @@ void OpDesc::SetAttr>( } } +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(desc_, name); + it->set_type(framework::proto::LONGS); + it->clear_longs(); + for (auto &i : v) { + it->add_longs(i); + } +} google::protobuf::internal::RepeatedPtrIterator< const framework::proto::OpDesc_Attr> GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) { -- GitLab