Commit 5404c2ee authored by TianXiaogang, committed by Yan Chunwei

Bug fix for model save and load (#1992)

* fix: fix model parser and save bug

* style: delete debug code

* fix: fix light_predictor bug when running a model with sub-blocks
Parent 2ad127dd
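Note on the first hunks below: LightPredictor::Build used to parse the model into a stack-local cpp::ProgramDesc, while models with sub-blocks (e.g. a `while` op, see the program.cc hunk further down) reference blocks inside that desc after Build() returns. The fix keeps the parsed desc in the new cpp_program_desc_ member. A minimal sketch of the lifetime issue, using hypothetical stand-in types rather than Paddle-Lite classes:

#include <iostream>
#include <vector>

// Hypothetical stand-ins for cpp::ProgramDesc / cpp::BlockDesc, for illustration only.
struct BlockDesc { int idx; };
struct ProgramDesc { std::vector<BlockDesc> blocks; };

// A while-like op that keeps a raw pointer into the ProgramDesc it was built from.
struct WhileOp { const BlockDesc* sub_block{nullptr}; };

struct Predictor {
  ProgramDesc desc_;  // member: outlives Build(), so the op's pointer stays valid
  WhileOp op_;

  void BuildWithLocal() {
    ProgramDesc local{{{0}, {1}}};
    op_.sub_block = &local.blocks[1];  // dangles as soon as this function returns
  }
  void BuildWithMember() {
    desc_ = ProgramDesc{{{0}, {1}}};
    op_.sub_block = &desc_.blocks[1];  // valid for the predictor's lifetime
  }
};

int main() {
  Predictor p;
  p.BuildWithMember();
  std::cout << p.op_.sub_block->idx << std::endl;  // prints 1
  // Calling p.BuildWithLocal() and then reading p.op_.sub_block would be undefined behavior.
  return 0;
}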
......@@ -22,26 +22,25 @@ void LightPredictor::Build(const std::string& model_dir,
const std::string& param_buffer,
lite_api::LiteModelType model_type,
bool model_from_memory) {
- cpp::ProgramDesc desc;
switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf:
- LoadModelPb(model_dir, "", "", scope_.get(), &desc);
+ LoadModelPb(model_dir, "", "", scope_.get(), &cpp_program_desc_);
break;
#endif
case lite_api::LiteModelType::kNaiveBuffer: {
if (model_from_memory) {
LoadModelNaiveFromMemory(
- model_buffer, param_buffer, scope_.get(), &desc);
+ model_buffer, param_buffer, scope_.get(), &cpp_program_desc_);
} else {
- LoadModelNaive(model_dir, scope_.get(), &desc);
+ LoadModelNaive(model_dir, scope_.get(), &cpp_program_desc_);
}
break;
}
default:
LOG(FATAL) << "Unknown model type";
}
- BuildRuntimeProgram(desc);
+ BuildRuntimeProgram(cpp_program_desc_);
}
Tensor* LightPredictor::GetInput(size_t offset) {
......@@ -84,9 +83,11 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
});
CHECK(it != kernels.end());
(*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target()));
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
+ CHECK(program.exec_scope());
+ program_->set_exec_scope(program.exec_scope());
}
......
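The BuildRuntimeProgram hunk above also hands the exec scope of the locally built Program over to the freshly created RuntimeProgram (CHECK(program.exec_scope()); program_->set_exec_scope(...)), so kernels do not run against an unset scope. A minimal sketch of that hand-off, with hypothetical stand-in types rather than the real Scope/RuntimeProgram classes:

#include <cassert>
#include <memory>

// Hypothetical stand-ins, for illustration only.
struct Scope {};

class RuntimeProgram {
 public:
  void set_exec_scope(Scope* s) { exec_scope_ = s; }
  void Run() { assert(exec_scope_ != nullptr && "kernels need a valid exec scope"); }

 private:
  Scope* exec_scope_{nullptr};
};

int main() {
  auto scope = std::make_shared<Scope>();
  RuntimeProgram program;
  program.set_exec_scope(scope.get());  // analogous to program_->set_exec_scope(...)
  program.Run();                        // would assert if the scope had never been set
  return 0;
}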
......@@ -74,6 +74,7 @@ class LITE_API LightPredictor {
private:
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
+ cpp::ProgramDesc cpp_program_desc_;
};
} // namespace lite
......
......@@ -600,10 +600,6 @@ void act_log<float>(const float* din, float* dout, int size, int threads) {
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
LOG(INFO) << "nums_per_thread" << nums_per_thread;
LOG(INFO) << "remain" << remain;
LOG(INFO) << "neon_loop_cnt_dim4" << neon_loop_cnt_dim4;
LOG(INFO) << "neon_loop_remian_dim4" << neon_loop_remain_dim4;
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
......
......@@ -140,7 +140,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
if (op_type == "while") {
- auto sub_block_idx = op_desc.GetAttr<int16_t>("sub_block");
+ auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
auto sub_block =
const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
sub_block_idx);
......
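The hunk above reads the `while` op's `sub_block` attribute as int32_t instead of int16_t, matching the width it is written with elsewhere in this commit (SetAttr<int32_t> in compatible_pb.cc below). A minimal sketch of the failure mode when reader and writer disagree on the integer width, using std::variant as a hypothetical stand-in for the attribute storage:

#include <cstdint>
#include <iostream>
#include <variant>

// Hypothetical stand-in for a typed attribute slot (not Paddle-Lite's real storage).
using Attr = std::variant<int16_t, int32_t, float>;

int main() {
  Attr sub_block = int32_t{3};  // the writer stores the block index as int32_t
  if (auto* v = std::get_if<int16_t>(&sub_block)) {
    std::cout << "read as int16_t: " << *v << std::endl;
  } else {
    std::cout << "reading as int16_t fails: the slot holds an int32_t" << std::endl;
  }
  std::cout << "read as int32_t: " << std::get<int32_t>(sub_block) << std::endl;  // 3
  return 0;
}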
......@@ -116,10 +116,8 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
name, any_desc.template GetAttr<std::vector<int64_t>>(name));
break;
case AttrType::BLOCK: {
LOG(INFO) << "loading block " << name;
auto i = any_desc.template GetAttr<int16_t>(name);
- LOG(INFO) << i;
- cpp_desc->SetAttr<int16_t>(name, i);
+ cpp_desc->SetAttr<int32_t>(name, i);
// naive_buffer::BlockDesc* sub_block = any_desc.template
// GetAttr<naive_buffer::BlockDesc*>(name);
// LOG(INFO) << sub_block->OpsSize();
......@@ -152,6 +150,8 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
IMPL_ONE(FLOATS, std::vector<float>);
IMPL_ONE(INTS, std::vector<int>);
IMPL_ONE(BOOLEAN, bool);
+ IMPL_ONE(LONG, int64_t);
+ IMPL_ONE(LONGS, std::vector<int64_t>);
default:
LOG(FATAL) << "Unsupported attr type found: " << static_cast<int>(type);
}
......
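The two compatible_pb.cc hunks above extend the attribute conversion to LONG and LONGS, so int64 attributes no longer fall through to the "Unsupported attr type" fatal default. A minimal stand-alone sketch of the idea (the IMPL_ONE macro and the real OpDesc types are not reproduced here; these are hypothetical stand-ins):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

// Hypothetical stand-ins: a source and a destination attribute holder.
enum class AttrType { LONG, LONGS };

struct SrcDesc {
  int64_t long_attr = 42;
  std::vector<int64_t> longs_attr = {1, 2, 3};
};
struct DstDesc {
  int64_t long_attr = 0;
  std::vector<int64_t> longs_attr;
};

void CopyAttr(const SrcDesc& src, DstDesc* dst, AttrType type) {
  switch (type) {
    case AttrType::LONG:  dst->long_attr = src.long_attr;   break;  // newly covered type
    case AttrType::LONGS: dst->longs_attr = src.longs_attr; break;  // newly covered type
    default:
      std::cerr << "Unsupported attr type" << std::endl;  // previously hit for int64 attrs
      std::abort();
  }
}

int main() {
  SrcDesc src;
  DstDesc dst;
  CopyAttr(src, &dst, AttrType::LONG);
  CopyAttr(src, &dst, AttrType::LONGS);
  std::cout << dst.long_attr << " " << dst.longs_attr.size() << std::endl;  // 42 3
  return 0;
}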
......@@ -28,7 +28,6 @@ namespace cpp {
}
SET_ATTR_IMPL(int32_t, INT);
- SET_ATTR_IMPL(int16_t, INT);
SET_ATTR_IMPL(float, FLOAT);
SET_ATTR_IMPL(std::string, STRING);
SET_ATTR_IMPL(bool, BOOLEAN);
......@@ -108,7 +107,6 @@ bool OpDesc::HasOutput(const std::string& param) const {
}
GET_IMPL_ONE(float, FLOAT);
- GET_IMPL_ONE(int16_t, INT);
GET_IMPL_ONE(std::string, STRING);
GET_IMPL_ONE(int64_t, LONG);
GET_IMPL_ONE(bool, BOOLEAN);
......
......@@ -54,6 +54,7 @@ SET_ATTR_IMPL(int, INT, Int32, i);
SET_ATTR_IMPL(float, FLOAT, Float32, f);
SET_ATTR_IMPL(bool, BOOLEAN, Bool, b);
SET_ATTR_IMPL(std::string, STRING, String, s);
+ SET_ATTR_IMPL(int64_t, LONG, Int64, l);
#undef SET_ATTR_IMPL
#define SET_ATTRS_IMPL(T, ty__, bd__, pb_f__) \
......@@ -77,6 +78,7 @@ SET_ATTR_IMPL(std::string, STRING, String, s);
SET_ATTRS_IMPL(int, INTS, Int32, ints);
SET_ATTRS_IMPL(float, FLOATS, Float32, floats);
SET_ATTRS_IMPL(std::string, STRINGS, String, strings);
+ SET_ATTRS_IMPL(int64_t, LONGS, Int64, longs);
#undef SET_ATTRS_IMPL
const proto::OpDesc::Attr& GetFindAttr(const proto::OpDesc& desc,
......
......@@ -46,6 +46,7 @@ FindAttr(framework::proto::OpDesc *desc, const std::string &name) {
SET_IMPL_ONE(int, INT, i);
SET_IMPL_ONE(float, FLOAT, f);
SET_IMPL_ONE(bool, BOOLEAN, b);
+ SET_IMPL_ONE(int64_t, LONG, l);
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
......@@ -88,6 +89,16 @@ void OpDesc::SetAttr<std::vector<std::string>>(
}
}
+ template <>
+ void OpDesc::SetAttr<std::vector<int64_t>>(const std::string &name,
+                                            const std::vector<int64_t> &v) {
+   auto it = FindAttr(desc_, name);
+   it->set_type(framework::proto::LONGS);
+   it->clear_longs();
+   for (auto &i : v) {
+     it->add_longs(i);
+   }
+ }
google::protobuf::internal::RepeatedPtrIterator<
const framework::proto::OpDesc_Attr>
GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) {
......
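The last hunk adds a SetAttr specialization for std::vector<int64_t> that sets the attr type to LONGS, clears the repeated field, and appends each value. A minimal sketch of that clear-then-append pattern, with a hypothetical stand-in for the protobuf attr message:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::proto::OpDesc_Attr's repeated int64 field.
struct Attr {
  std::vector<int64_t> longs;
  void clear_longs() { longs.clear(); }
  void add_longs(int64_t v) { longs.push_back(v); }
};

// Mirrors the new specialization: clear the repeated field, then append each element.
void SetLongs(Attr* attr, const std::vector<int64_t>& v) {
  attr->clear_longs();
  for (auto& i : v) {
    attr->add_longs(i);
  }
}

int main() {
  Attr attr{{7}};             // pretend a stale value is already stored
  SetLongs(&attr, {1, 2, 3});
  std::cout << attr.longs.size() << std::endl;  // 3: clear_longs() avoided appending to stale data
  return 0;
}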