提交 af2a2e79 编写于 作者: H huzhiqiang 提交者: Yan Chunwei

add the method of loading model from naive buffer for LightPredictor (#1918) (#1937)

上级 72c919dc
...@@ -18,18 +18,26 @@ namespace paddle { ...@@ -18,18 +18,26 @@ namespace paddle {
namespace lite { namespace lite {
void LightPredictor::Build(const std::string& model_dir, void LightPredictor::Build(const std::string& model_dir,
lite_api::LiteModelType model_type) { const std::string& model_buffer,
const std::string& param_buffer,
lite_api::LiteModelType model_type,
bool model_from_memory) {
cpp::ProgramDesc desc; cpp::ProgramDesc desc;
LOG(INFO) << "Load model from " << model_dir;
switch (model_type) { switch (model_type) {
#ifndef LITE_ON_TINY_PUBLISH #ifndef LITE_ON_TINY_PUBLISH
case lite_api::LiteModelType::kProtobuf: case lite_api::LiteModelType::kProtobuf:
LoadModelPb(model_dir, "", "", scope_.get(), &desc); LoadModelPb(model_dir, "", "", scope_.get(), &desc);
break; break;
#endif #endif
case lite_api::LiteModelType::kNaiveBuffer: case lite_api::LiteModelType::kNaiveBuffer: {
if (model_from_memory) {
LoadModelNaiveFromMemory(
model_buffer, param_buffer, scope_.get(), &desc);
} else {
LoadModelNaive(model_dir, scope_.get(), &desc); LoadModelNaive(model_dir, scope_.get(), &desc);
}
break; break;
}
default: default:
LOG(FATAL) << "Unknown model type"; LOG(FATAL) << "Unknown model type";
} }
...@@ -83,11 +91,5 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) { ...@@ -83,11 +91,5 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
program_->set_exec_scope(program.exec_scope()); program_->set_exec_scope(program.exec_scope());
} }
LightPredictor::LightPredictor(const std::string& model_dir,
lite_api::LiteModelType model_type) {
scope_ = std::make_shared<Scope>();
Build(model_dir, model_type);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -38,9 +38,15 @@ namespace lite { ...@@ -38,9 +38,15 @@ namespace lite {
*/ */
class LITE_API LightPredictor { class LITE_API LightPredictor {
public: public:
explicit LightPredictor( LightPredictor(
const std::string& model_dir, const std::string& model_dir,
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf); const std::string& model_buffer = "",
const std::string& param_buffer = "",
bool model_from_memory = false,
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf) {
scope_ = std::make_shared<Scope>();
Build(model_dir, model_buffer, param_buffer, model_type, model_from_memory);
}
void Run() { program_->Run(); } void Run() { program_->Run(); }
...@@ -58,7 +64,11 @@ class LITE_API LightPredictor { ...@@ -58,7 +64,11 @@ class LITE_API LightPredictor {
private: private:
void Build( void Build(
const std::string& model_dir, const std::string& model_dir,
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf); const std::string& model_buffer,
const std::string& param_buffer,
lite_api::LiteModelType model_type = lite_api::LiteModelType::kProtobuf,
bool model_from_memory = false);
void BuildRuntimeProgram(const cpp::ProgramDesc& prog); void BuildRuntimeProgram(const cpp::ProgramDesc& prog);
private: private:
......
...@@ -45,6 +45,9 @@ void LightPredictorImpl::Init(const MobileConfig& config) { ...@@ -45,6 +45,9 @@ void LightPredictorImpl::Init(const MobileConfig& config) {
lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads()); lite::DeviceInfo::Global().SetRunMode(config.power_mode(), config.threads());
#endif #endif
raw_predictor_.reset(new lite::LightPredictor(config.model_dir(), raw_predictor_.reset(new lite::LightPredictor(config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.model_from_memory(),
LiteModelType::kNaiveBuffer)); LiteModelType::kNaiveBuffer));
} }
......
...@@ -28,7 +28,46 @@ TEST(LightAPI, load) { ...@@ -28,7 +28,46 @@ TEST(LightAPI, load) {
if (FLAGS_optimized_model.empty()) { if (FLAGS_optimized_model.empty()) {
FLAGS_optimized_model = "lite_naive_model"; FLAGS_optimized_model = "lite_naive_model";
} }
LightPredictor predictor(FLAGS_optimized_model); LightPredictor predictor(FLAGS_optimized_model, "", "");
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
const auto* output = predictor.GetOutput(0);
const float* raw_output = output->data<float>();
for (int i = 0; i < 10; i++) {
LOG(INFO) << "out " << raw_output[i];
}
}
TEST(LightAPI, loadNaiveBuffer) {
if (FLAGS_optimized_model.empty()) {
FLAGS_optimized_model = "lite_naive_model";
}
auto model_path = std::string(FLAGS_optimized_model) + "/__model__.nb";
auto params_path = std::string(FLAGS_optimized_model) + "/param.nb";
std::string model_buffer = lite::ReadFile(model_path);
size_t size_model = model_buffer.length();
std::string params_buffer = lite::ReadFile(params_path);
size_t size_params = params_buffer.length();
LOG(INFO) << "sizeModel: " << size_model;
LOG(INFO) << "sizeParams: " << size_params;
lite_api::MobileConfig config;
config.set_model_buffer(
model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
LightPredictor predictor(config.model_dir(),
config.model_buffer(),
config.param_buffer(),
config.model_from_memory(),
lite_api::LiteModelType::kNaiveBuffer);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<int64_t>({100, 100}))); input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));
......
...@@ -129,6 +129,9 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -129,6 +129,9 @@ class LITE_API CxxConfig : public ConfigBase {
class LITE_API MobileConfig : public ConfigBase { class LITE_API MobileConfig : public ConfigBase {
PowerMode mode_{LITE_POWER_HIGH}; PowerMode mode_{LITE_POWER_HIGH};
int threads_{1}; int threads_{1};
std::string model_buffer_;
std::string param_buffer_;
bool model_from_memory_{false};
public: public:
MobileConfig(Place preferred_place = Place(TARGET(kARM), MobileConfig(Place preferred_place = Place(TARGET(kARM),
...@@ -139,9 +142,20 @@ class LITE_API MobileConfig : public ConfigBase { ...@@ -139,9 +142,20 @@ class LITE_API MobileConfig : public ConfigBase {
: mode_(mode), threads_(threads) {} : mode_(mode), threads_(threads) {}
void set_power_mode(PowerMode mode) { mode_ = mode; } void set_power_mode(PowerMode mode) { mode_ = mode; }
void set_threads(int threads) { threads_ = threads; } void set_threads(int threads) { threads_ = threads; }
void set_model_buffer(const char* model_buffer,
size_t model_buffer_size,
const char* param_buffer,
size_t param_buffer_size) {
model_buffer_ = std::string(model_buffer, model_buffer + model_buffer_size);
param_buffer_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
PowerMode power_mode() const { return mode_; } PowerMode power_mode() const { return mode_; }
int threads() const { return threads_; } int threads() const { return threads_; }
bool model_from_memory() const { return model_from_memory_; }
const std::string& model_buffer() const { return model_buffer_; }
const std::string& param_buffer() const { return param_buffer_; }
}; };
template <typename ConfigT> template <typename ConfigT>
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h" #include "lite/api/paddle_use_passes.h"
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"
DEFINE_string(model_dir, "", ""); DEFINE_string(model_dir, "", "");
namespace paddle { namespace paddle {
...@@ -58,6 +58,7 @@ TEST(CxxApi, run) { ...@@ -58,6 +58,7 @@ TEST(CxxApi, run) {
LiteModelType::kNaiveBuffer); LiteModelType::kNaiveBuffer);
} }
// Demo1 for Mobile Devices :Load model from file and run
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(LightApi, run) { TEST(LightApi, run) {
lite_api::MobileConfig config; lite_api::MobileConfig config;
...@@ -82,6 +83,39 @@ TEST(LightApi, run) { ...@@ -82,6 +83,39 @@ TEST(LightApi, run) {
EXPECT_NEAR(out[0], 50.2132, 1e-3); EXPECT_NEAR(out[0], 50.2132, 1e-3);
EXPECT_NEAR(out[1], -28.8729, 1e-3); EXPECT_NEAR(out[1], -28.8729, 1e-3);
} }
// Demo2 for Loading model from memory
TEST(MobileConfig, LoadfromMemory) {
// Get naive buffer
auto model_path = std::string(FLAGS_model_dir) + ".opt2.naive/__model__.nb";
auto params_path = std::string(FLAGS_model_dir) + ".opt2.naive/param.nb";
std::string model_buffer = lite::ReadFile(model_path);
size_t size_model = model_buffer.length();
std::string params_buffer = lite::ReadFile(params_path);
size_t size_params = params_buffer.length();
// set model buffer and run model
lite_api::MobileConfig config;
config.set_model_buffer(
model_buffer.c_str(), size_model, params_buffer.c_str(), size_params);
auto predictor = lite_api::CreatePaddlePredictor(config);
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(std::vector<int64_t>({100, 100}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor->Run();
const auto output = predictor->GetOutput(0);
const float* raw_output = output->data<float>();
for (int i = 0; i < 10; i++) {
LOG(INFO) << "out " << raw_output[i];
}
}
#endif #endif
} // namespace lite_api } // namespace lite_api
......
...@@ -661,9 +661,14 @@ void LoadParamNaive(const std::string &path, ...@@ -661,9 +661,14 @@ void LoadParamNaive(const std::string &path,
void LoadCombinedParamsNaive(const std::string &path, void LoadCombinedParamsNaive(const std::string &path,
lite::Scope *scope, lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog) { const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
naive_buffer::BinaryTable table; naive_buffer::BinaryTable table;
if (params_from_memory) {
table.LoadFromMemory(path.c_str(), path.length());
} else {
table.LoadFromFile(path); table.LoadFromFile(path);
}
naive_buffer::proto::CombinedParamsDesc pt_desc(&table); naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
pt_desc.Load(); pt_desc.Load();
naive_buffer::CombinedParamsDesc desc(&pt_desc); naive_buffer::CombinedParamsDesc desc(&pt_desc);
...@@ -710,7 +715,7 @@ void LoadModelNaive(const std::string &model_dir, ...@@ -710,7 +715,7 @@ void LoadModelNaive(const std::string &model_dir,
// NOTE: Only main block be used now. // NOTE: Only main block be used now.
if (combined) { if (combined) {
const std::string combined_params_path = model_dir + "/param.nb"; const std::string combined_params_path = model_dir + "/param.nb";
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog); LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, false);
} else { } else {
auto &prog = *cpp_prog; auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
...@@ -750,5 +755,40 @@ void LoadModelNaive(const std::string &model_dir, ...@@ -750,5 +755,40 @@ void LoadModelNaive(const std::string &model_dir,
VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully"; VLOG(4) << "Load naive buffer model in '" << model_dir << "' successfully";
} }
void LoadModelNaiveFromMemory(const std::string &model_buffer,
const std::string &param_buffer,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// Load model
std::string prog_path = model_buffer;
naive_buffer::BinaryTable table;
table.LoadFromMemory(prog_path.c_str(), prog_path.length());
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
nb_proto_prog.Load();
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
// Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(nb_prog, cpp_prog);
// Load Params
// NOTE: Only main block be used now.
// only combined Params are supported in Loading Model from memory
std::string combined_params_path = param_buffer;
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog, true);
#ifdef LITE_WITH_NPU
LOG(FATAL) << "load from memory is not supported by NPU";
#endif
VLOG(4) << "Load model from naive buffer memory successfully";
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -94,14 +94,15 @@ void LoadParamNaive(const std::string& path, ...@@ -94,14 +94,15 @@ void LoadParamNaive(const std::string& path,
lite::Scope* scope, lite::Scope* scope,
const std::string& name); const std::string& name);
void LoadCombinedParamsNaive(const std::string& path,
lite::Scope* scope,
const cpp::ProgramDesc& cpp_prog);
void LoadModelNaive(const std::string& model_dir, void LoadModelNaive(const std::string& model_dir,
lite::Scope* scope, lite::Scope* scope,
cpp::ProgramDesc* prog, cpp::ProgramDesc* prog,
bool combined = true); bool combined = true);
void LoadModelNaiveFromMemory(const std::string& model_buffer,
const std::string& param_buffer,
lite::Scope* scope,
cpp::ProgramDesc* cpp_prog);
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -121,12 +121,17 @@ TEST(ModelParser, SaveModelNaive) { ...@@ -121,12 +121,17 @@ TEST(ModelParser, SaveModelNaive) {
SaveModelNaive(save_pb_model_path, scope, prog); SaveModelNaive(save_pb_model_path, scope, prog);
} }
TEST(ModelParser, LoadModelNaive) { TEST(ModelParser, LoadModelNaiveFromMemory) {
CHECK(!FLAGS_model_dir.empty()); CHECK(!FLAGS_model_dir.empty());
cpp::ProgramDesc prog; cpp::ProgramDesc prog;
Scope scope; Scope scope;
const std::string model_path = FLAGS_model_dir + ".saved.naive";
LoadModelNaive(model_path, &scope, &prog); auto model_path = std::string(FLAGS_model_dir) + ".saved.naive/__model__.nb";
auto params_path = std::string(FLAGS_model_dir) + ".saved.naive/param.nb";
std::string model_buffer = lite::ReadFile(model_path);
std::string params_buffer = lite::ReadFile(params_path);
LoadModelNaiveFromMemory(model_buffer, params_buffer, &scope, &prog);
} }
} // namespace lite } // namespace lite
......
...@@ -66,6 +66,14 @@ void BinaryTable::LoadFromFile(const std::string &filename) { ...@@ -66,6 +66,14 @@ void BinaryTable::LoadFromFile(const std::string &filename) {
is_mutable_mode_ = false; is_mutable_mode_ = false;
} }
void BinaryTable::LoadFromMemory(const char *buffer, size_t buffer_size) {
// get buffer
bytes_.resize(buffer_size);
memcpy(reinterpret_cast<char *>(&bytes_[0]), buffer, buffer_size);
// Set readonly.
is_mutable_mode_ = false;
}
void StringBuilder::Save() { void StringBuilder::Save() {
// memory format: [size][string data] // memory format: [size][string data]
uint64_t mem_size = sizeof(uint64_t) + data_.size(); uint64_t mem_size = sizeof(uint64_t) + data_.size();
......
...@@ -63,6 +63,7 @@ struct BinaryTable { ...@@ -63,6 +63,7 @@ struct BinaryTable {
void SaveToFile(const std::string& filename) const; void SaveToFile(const std::string& filename) const;
void LoadFromFile(const std::string& filename); void LoadFromFile(const std::string& filename);
void LoadFromMemory(const char* buffer, size_t buffer_size);
}; };
/* /*
......
...@@ -43,5 +43,14 @@ static void MkDirRecur(const std::string& path) { ...@@ -43,5 +43,14 @@ static void MkDirRecur(const std::string& path) {
#endif #endif
} }
// read buffer from file
static std::string ReadFile(const std::string& filename) {
std::ifstream ifile(filename.c_str());
std::ostringstream buf;
char ch;
while (buf && ifile.get(ch)) buf.put(ch);
return buf.str();
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册