提交 5404c2ee 编写于 作者: T TianXiaogang 提交者: Yan Chunwei

Bug fix for model save and load (#1992)

* fix: fix model parser and save bug

* style: delete debug code

* fix: fix light_predictor program run model with subblock bug
上级 2ad127dd
...@@ -22,26 +22,25 @@ void LightPredictor::Build(const std::string& model_dir, ...@@ -22,26 +22,25 @@ void LightPredictor::Build(const std::string& model_dir,
const std::string& param_buffer, const std::string& param_buffer,
lite_api::LiteModelType model_type, lite_api::LiteModelType model_type,
bool model_from_memory) { bool model_from_memory) {
cpp::ProgramDesc desc;
switch (model_type) { switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_dir, "", "", scope_.get(), &desc); LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_);
break; break;
#endif #endif
case lite_api::LiteModelType::kNaiveBuffer: { case lite_api::LiteModelType::kNaiveBuffer: {
if (model_from_memory) { if (model_from_memory) {
LoadModelNaiveFromMemory( LoadModelNaiveFromMemory(
model_buffer, param_buffer, scope_.get(), &desc); model_buffer, param_buffer, scope_.get(), &cpp_program_desc_);
} else { } else {
LoadModelNaive(model_dir, scope_.get(), &desc); LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_);
} }
break; break;
} }
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
} }
BuildRuntimeProgram(desc); BuildRuntimeProgram(cpp_program_desc_);
} }
Tensor* LightPredictor::GetInput(size_t offset) { Tensor* LightPredictor::GetInput(size_t offset) {
...@@ -84,9 +83,11 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { ...@@ -84,9 +83,11 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
}); });
CHECK(it != kernels.end()); CHECK(it != kernels.end());
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope()); CHECK(program.exec_scope());
program_->set_exec_scope(program.exec_scope()); program_->set_exec_scope(program.exec_scope());
} }
......
...@@ -74,6 +74,7 @@ class LITE_API LightPredictor { ...@@ -74,6 +74,7 @@ class LITE_API LightPredictor {
private: private:
std::shared_ptr<Scope> scope_; std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_; std::unique_ptr<RuntimeProgram> program_;
cpp::ProgramDesc cpp_program_desc_;
}; };
} // namespace lite } // namespace lite
......
...@@ -600,10 +600,6 @@ void act_log<float>(const float* din, float* dout, int size, int threads) { ...@@ -600,10 +600,6 @@ void act_log<float>(const float* din, float* dout, int size, int threads) {
int remain = size - threads * nums_per_thread; int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2; int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2); int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
LOG(INFO) << "nums_per_thread" << nums_per_thread;
LOG(INFO) << "remain" << remain;
LOG(INFO) << "neon_loop_cnt_dim4" << neon_loop_cnt_dim4;
LOG(INFO) << "neon_loop_remian_dim4" << neon_loop_remain_dim4;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for #pragma omp parallel for
......
...@@ -140,7 +140,7 @@ void Program::Build(const cpp::ProgramDesc& prog) { ...@@ -140,7 +140,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
auto op = LiteOpRegistry::Global().Create(op_type); auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type; CHECK(op) << "no Op found for " << op_type;
if (op_type == "while") { if (op_type == "while") {
auto sub_block_idx = op_desc.GetAttr<int16_t>("sub_block"); auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
auto sub_block = auto sub_block =
const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>( const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
sub_block_idx); sub_block_idx);
......
...@@ -116,10 +116,8 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { ...@@ -116,10 +116,8 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
name, any_desc.template GetAttr<std::vector<int64_t>>(name)); name, any_desc.template GetAttr<std::vector<int64_t>>(name));
break; break;
case AttrType::BLOCK: { case AttrType::BLOCK: {
LOG(INFO) << "loading block " << name;
auto i = any_desc.template GetAttr<int16_t>(name); auto i = any_desc.template GetAttr<int16_t>(name);
LOG(INFO) << i; cpp_desc->SetAttr<int32_t>(name, i);
cpp_desc->SetAttr<int16_t>(name, i);
// naive_buffer::BlockDesc* sub_block = any_desc.template // naive_buffer::BlockDesc* sub_block = any_desc.template
// GetAttr<naive_buffer::BlockDesc*>(name); // GetAttr<naive_buffer::BlockDesc*>(name);
// LOG(INFO) << sub_block->OpsSize(); // LOG(INFO) << sub_block->OpsSize();
...@@ -152,6 +150,8 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { ...@@ -152,6 +150,8 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
IMPL_ONE(FLOATS, std::vector<float>); IMPL_ONE(FLOATS, std::vector<float>);
IMPL_ONE(INTS, std::vector<int>); IMPL_ONE(INTS, std::vector<int>);
IMPL_ONE(BOOLEAN, bool); IMPL_ONE(BOOLEAN, bool);
IMPL_ONE(LONG, int64_t);
IMPL_ONE(LONGS, std::vector<int64_t>);
default: default:
LOG(FATAL) << "Unsupported attr type found: " << static_cast<int>(type); LOG(FATAL) << "Unsupported attr type found: " << static_cast<int>(type);
} }
......
...@@ -28,7 +28,6 @@ namespace cpp { ...@@ -28,7 +28,6 @@ namespace cpp {
} }
SET_ATTR_IMPL(int32_t, INT); SET_ATTR_IMPL(int32_t, INT);
SET_ATTR_IMPL(int16_t, INT);
SET_ATTR_IMPL(float, FLOAT); SET_ATTR_IMPL(float, FLOAT);
SET_ATTR_IMPL(std::string, STRING); SET_ATTR_IMPL(std::string, STRING);
SET_ATTR_IMPL(bool, BOOLEAN); SET_ATTR_IMPL(bool, BOOLEAN);
...@@ -108,7 +107,6 @@ bool OpDesc::HasOutput(const std::string& param) const { ...@@ -108,7 +107,6 @@ bool OpDesc::HasOutput(const std::string& param) const {
} }
GET_IMPL_ONE(float, FLOAT); GET_IMPL_ONE(float, FLOAT);
GET_IMPL_ONE(int16_t, INT);
GET_IMPL_ONE(std::string, STRING); GET_IMPL_ONE(std::string, STRING);
GET_IMPL_ONE(int64_t, LONG); GET_IMPL_ONE(int64_t, LONG);
GET_IMPL_ONE(bool, BOOLEAN); GET_IMPL_ONE(bool, BOOLEAN);
......
...@@ -54,6 +54,7 @@ SET_ATTR_IMPL(int, INT, Int32, i); ...@@ -54,6 +54,7 @@ SET_ATTR_IMPL(int, INT, Int32, i);
SET_ATTR_IMPL(float, FLOAT, Float32, f); SET_ATTR_IMPL(float, FLOAT, Float32, f);
SET_ATTR_IMPL(bool, BOOLEAN, Bool, b); SET_ATTR_IMPL(bool, BOOLEAN, Bool, b);
SET_ATTR_IMPL(std::string, STRING, String, s); SET_ATTR_IMPL(std::string, STRING, String, s);
SET_ATTR_IMPL(int64_t, LONG, Int64, l);
#undef SET_ATTR_IMPL #undef SET_ATTR_IMPL
#define SET_ATTRS_IMPL(T, ty__, bd__, pb_f__) \ #define SET_ATTRS_IMPL(T, ty__, bd__, pb_f__) \
...@@ -77,6 +78,7 @@ SET_ATTR_IMPL(std::string, STRING, String, s); ...@@ -77,6 +78,7 @@ SET_ATTR_IMPL(std::string, STRING, String, s);
SET_ATTRS_IMPL(int, INTS, Int32, ints); SET_ATTRS_IMPL(int, INTS, Int32, ints);
SET_ATTRS_IMPL(float, FLOATS, Float32, floats); SET_ATTRS_IMPL(float, FLOATS, Float32, floats);
SET_ATTRS_IMPL(std::string, STRINGS, String, strings); SET_ATTRS_IMPL(std::string, STRINGS, String, strings);
SET_ATTRS_IMPL(int64_t, LONGS, Int64, longs);
#undef SET_ATTRS_IMPL #undef SET_ATTRS_IMPL
const proto::OpDesc::Attr& GetFindAttr(const proto::OpDesc& desc, const proto::OpDesc::Attr& GetFindAttr(const proto::OpDesc& desc,
......
...@@ -46,6 +46,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) { ...@@ -46,6 +46,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) {
SET_IMPL_ONE(int, INT, i); SET_IMPL_ONE(int, INT, i);
SET_IMPL_ONE(float, FLOAT, f); SET_IMPL_ONE(float, FLOAT, f);
SET_IMPL_ONE(bool, BOOLEAN, b); SET_IMPL_ONE(bool, BOOLEAN, b);
SET_IMPL_ONE(int64_t, LONG, l);
template <> template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name, void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
...@@ -88,6 +89,16 @@ void OpDesc::SetAttr<std::vector<std::string>>( ...@@ -88,6 +89,16 @@ void OpDesc::SetAttr<std::vector<std::string>>(
} }
} }
template <>
void OpDesc::SetAttr<std::vector<int64_t>>(const std::string &name,
const std::vector<int64_t> &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework::proto::LONGS);
it->clear_longs();
for (auto &i : v) {
it->add_longs(i);
}
}
google::protobuf::internal::RepeatedPtrIterator< google::protobuf::internal::RepeatedPtrIterator<
const framework::proto::OpDesc_Attr> const framework::proto::OpDesc_Attr>
GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) { GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册