Commit 93950441 authored by sangoly, committed by Yan Chunwei

[Protobuf] add combined-param model save/load supported test=develop (#1876)

Parent 866f6be5
@@ -57,8 +57,11 @@ TEST(CXXApi_LightApi, optim_model) {
   });
   // On ARM devices, the preferred X86 target not works, but it can still
   // select ARM kernels.
-  cxx_api.Build(
-      FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+  cxx_api.Build(FLAGS_model_dir,
+                "",
+                "",
+                Place{TARGET(kX86), PRECISION(kFloat)},
+                valid_places);
   cxx_api.SaveModel(FLAGS_optimized_model);
 }
@@ -75,8 +78,11 @@ TEST(CXXApi_LightApi, save_and_load_model) {
   });
   // On ARM devices, the preferred X86 target not works, but it can still
   // select ARM kernels.
-  cxx_api.Build(
-      FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+  cxx_api.Build(FLAGS_model_dir,
+                "",
+                "",
+                Place{TARGET(kX86), PRECISION(kFloat)},
+                valid_places);
   auto* x = cxx_api.GetInput(0);
   SetConstInput(x);
......
@@ -33,7 +33,7 @@ void Predictor::SaveModel(const std::string &dir,
   program_->SaveOpInfosToProgram(&program_desc_);
   switch (model_type) {
     case lite_api::LiteModelType::kProtobuf:
-      SaveModelPb(dir, *program_->exec_scope(), program_desc_);
+      SaveModelPb(dir, *program_->exec_scope(), program_desc_, true);
       break;
     case lite_api::LiteModelType::kNaiveBuffer:
       SaveModelNaive(dir, *program_->exec_scope(), program_desc_);
@@ -84,16 +84,29 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
 const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
 
 void Predictor::Build(const std::string &model_path,
+                      const std::string model_file,
+                      const std::string param_file,
                       const Place &prefer_place,
                       const std::vector<Place> &valid_places,
                       const std::vector<std::string> &passes,
                       lite_api::LiteModelType model_type) {
   LOG(INFO) << "Load model from " << model_path;
   switch (model_type) {
-    case lite_api::LiteModelType::kProtobuf:
-      LoadModelPb(model_path, scope_.get(), &program_desc_);
-      break;
+    case lite_api::LiteModelType::kProtobuf: {
+      bool combined_param = false;
+      if (!model_file.empty() && !param_file.empty()) {
+        combined_param = true;
+      }
+      LoadModelPb(model_path,
+                  model_file,
+                  param_file,
+                  scope_.get(),
+                  &program_desc_,
+                  combined_param);
+    } break;
     case lite_api::LiteModelType::kNaiveBuffer:
+      CHECK(!model_path.empty())
+          << "NaiveBuffer backend only supports combined param";
       LoadModelNaive(model_path, scope_.get(), &program_desc_);
       break;
     default:
......
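For reference, here is a minimal sketch (not part of the commit) of how a caller selects the two loading paths through the new `Predictor::Build` signature; the `./mobilenet_v1` paths are placeholders:

```cpp
// Hedged usage sketch, mirroring the call sites updated in this commit.
lite::Predictor separate, combined;
std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});

// Separate-param model: leave model_file/param_file empty, so
// combined_param stays false and LoadModelPb reads __model__ plus one
// file per weight from the model directory.
separate.Build("./mobilenet_v1",
               "",
               "",
               Place{TARGET(kARM), PRECISION(kFloat)},
               valid_places);

// Combined-param model: both paths non-empty set combined_param to true,
// so LoadModelPb reads the program from the "model" file and every
// weight from the single "params" file.
combined.Build("./mobilenet_v1",
               "./mobilenet_v1/model",
               "./mobilenet_v1/params",
               Place{TARGET(kARM), PRECISION(kFloat)},
               valid_places);
```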
@@ -41,6 +41,8 @@ class LITE_API Predictor {
   // Build from a model, with places set for hardware config.
   void Build(
       const std::string& model_path,
+      const std::string model_file_path,
+      const std::string param_file_path,
       const Place& prefer_place,
       const std::vector<Place>& valid_places,
       const std::vector<std::string>& passes = {},
......
@@ -41,7 +41,7 @@ void Run(const char* model_dir, int repeat) {
   });
 
   predictor.Build(
-      model_dir, Place{TARGET(kARM), PRECISION(kInt8)}, valid_places);
+      model_dir, "", "", Place{TARGET(kARM), PRECISION(kInt8)}, valid_places);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -47,7 +47,11 @@ CxxPaddleApiImpl::CxxPaddleApiImpl() {}
 void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   auto places = config.valid_places();
   places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
-  raw_predictor_.Build(config.model_dir(), config.preferred_place(), places);
+  raw_predictor_.Build(config.model_dir(),
+                       config.model_file(),
+                       config.param_file(),
+                       config.preferred_place(),
+                       places);
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
......
@@ -45,8 +45,11 @@ TEST(CXXApi, save_model) {
   lite::Predictor predictor;
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kX86), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kCUDA), PRECISION(kFloat)},
+                  valid_places);
 
   LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
   predictor.SaveModel(FLAGS_optimized_model,
@@ -93,8 +96,11 @@ TEST(CXXApi, save_model) {
   lite::Predictor predictor;
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kARM), PRECISION(kFloat)},
+                  valid_places);
 
   LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
   predictor.SaveModel(FLAGS_optimized_model);
@@ -107,6 +113,8 @@ TEST(CXXApi, load_model_naive) {
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
   predictor.Build(FLAGS_optimized_model + ".naive",
+                  "",
+                  "",
                   Place{TARGET(kARM), PRECISION(kFloat)},
                   valid_places,
                   {},
......
@@ -31,7 +31,7 @@ void TestModel(const std::vector<Place> &valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto *input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -33,8 +33,11 @@ TEST(InceptionV4, test) {
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kARM), PRECISION(kFloat)},
+                  valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -24,7 +24,7 @@ void LightPredictor::Build(const std::string& model_dir,
   switch (model_type) {
 #ifndef LITE_ON_TINY_PUBLISH
     case lite_api::LiteModelType::kProtobuf:
-      LoadModelPb(model_dir, scope_.get(), &desc);
+      LoadModelPb(model_dir, "", "", scope_.get(), &desc);
       break;
 #endif
     case lite_api::LiteModelType::kNaiveBuffer:
......
@@ -38,6 +38,8 @@ const lite::Tensor* RunHvyModel() {
 #endif
 
   predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
                   Place{TARGET(kX86), PRECISION(kFloat)},  // origin cuda
                   valid_places);
......
@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 300, 300})));
......
@@ -36,7 +36,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(model_dir, preferred_place, valid_places);
+  predictor.Build(model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 608, 608})));
......
@@ -37,7 +37,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(model_dir, preferred_place, valid_places);
+  predictor.Build(model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -23,7 +23,12 @@
 #include "lite/utils/cp_logging.h"
 #include "lite/utils/string.h"
 
-DEFINE_string(model_dir, "", "path of the model");
+DEFINE_string(model_dir,
+              "",
+              "path of the model. This option will be ignored if model_file "
+              "and param_file exist");
+DEFINE_string(model_file, "", "model file path of the combined-param model");
+DEFINE_string(param_file, "", "param file path of the combined-param model");
 DEFINE_string(
     optimize_out_type,
     "protobuf",
@@ -39,8 +44,15 @@ namespace paddle {
 namespace lite_api {
 
 void Main() {
+  if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
+    LOG(WARNING)
+        << "Loading a combined-param model; option model_dir will be ignored";
+  }
+
   lite_api::CxxConfig config;
   config.set_model_dir(FLAGS_model_dir);
+  config.set_model_file(FLAGS_model_file);
+  config.set_param_file(FLAGS_param_file);
+
   std::vector<Place> valid_places;
   auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
......
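With these flags a combined-param model can be passed straight to the optimize tool, e.g. `./model_optimize_tool --model_file=./mobilenet_v1/model --param_file=./mobilenet_v1/params --optimize_out_type=protobuf ...` (paths hypothetical); when both flags are non-empty the warning above fires and `--model_dir` is ignored.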
@@ -39,7 +39,7 @@ TEST(model, test) {
     precision = PRECISION(kInt8);
   }
   predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kARM), precision}, valid_places);
+      FLAGS_model_dir, "", "", Place{TARGET(kARM), precision}, valid_places);
 
   int im_width = FLAGS_im_width;
   int im_height = FLAGS_im_height;
   auto* input_tensor = predictor.GetInput(0);
......
@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 48, 512})));
......
@@ -99,13 +99,19 @@ class LITE_API ConfigBase {
 class LITE_API CxxConfig : public ConfigBase {
   Place preferred_place_;
   std::vector<Place> valid_places_;
+  std::string model_file_;
+  std::string param_file_;
 
  public:
   void set_preferred_place(const Place& x) { preferred_place_ = x; }
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
+  void set_model_file(const std::string& path) { model_file_ = path; }
+  void set_param_file(const std::string& path) { param_file_ = path; }
 
   const Place& preferred_place() const { return preferred_place_; }
   const std::vector<Place>& valid_places() const { return valid_places_; }
+  std::string model_file() const { return model_file_; }
+  std::string param_file() const { return param_file_; }
 };
 
 /// MobileConfig is the config for the light weight predictor, it will skip
......
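At the public-API level the same choice is made through `CxxConfig`, since `CxxPaddleApiImpl::Init` above forwards `model_file()`/`param_file()` into `Predictor::Build`. A hedged sketch (paths hypothetical, and assuming the usual `CreatePaddlePredictor` entry point declared alongside this header):

```cpp
#include "lite/api/paddle_api.h"

using namespace paddle::lite_api;  // NOLINT

std::shared_ptr<PaddlePredictor> LoadCombinedModel() {
  CxxConfig config;
  // model_dir is effectively ignored for protobuf loading once both
  // combined-param paths are set (see the model_optimize_tool warning).
  config.set_model_dir("./mobilenet_v1");
  config.set_model_file("./mobilenet_v1/model");
  config.set_param_file("./mobilenet_v1/params");
  config.set_preferred_place(Place{TARGET(kARM), PRECISION(kFloat)});
  config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});
  return CreatePaddlePredictor<CxxConfig>(config);
}
```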
@@ -31,8 +31,11 @@ TEST(ResNet18, test) {
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kARM), PRECISION(kFloat)},
+                  valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -32,7 +32,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -33,6 +33,8 @@ TEST(ResNet50, test) {
                      Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNHWC)}});
 
   predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
                   Place{TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)},
                   valid_places);
......
@@ -31,7 +31,7 @@ void TestModel(const std::vector<Place>& valid_places,
   DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
   lite::Predictor predictor;
 
-  predictor.Build(FLAGS_model_dir, preferred_place, valid_places);
+  predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim((std::vector<DDim::value_type>({1, 3, 224, 224}))));
......
@@ -51,7 +51,7 @@ TEST(CXXApi, test_lite_googlenet) {
   // LOG(INFO)<<"FLAGS_eval_googlenet_dir:"<<FLAGS_test_lite_googlenet_dir;
   std::string model_dir = FLAGS_model_dir;
   predictor.Build(
-      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+      model_dir, "", "", Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -55,8 +55,12 @@ TEST(InceptionV4, test_inceptionv4_lite_x86) {
                                    "io_copy_kernel_pick_pass",
                                    "variable_place_inference_pass",
                                    "runtime_context_assign_pass"});
-  predictor.Build(
-      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+  predictor.Build(model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places,
+                  passes);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -54,8 +54,12 @@ TEST(Mobilenet_v1, test_mobilenetv1_lite_x86) {
                                    "io_copy_kernel_pick_pass",
                                    "variable_place_inference_pass",
                                    "runtime_context_assign_pass"});
-  predictor.Build(
-      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+  predictor.Build(model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places,
+                  passes);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
   auto* data = input_tensor->mutable_data<float>();
......
@@ -56,8 +56,12 @@ TEST(Mobilenet_v2, test_mobilenetv2_lite_x86) {
                                    "io_copy_kernel_pick_pass",
                                    "variable_place_inference_pass",
                                    "runtime_context_assign_pass"});
-  predictor.Build(
-      model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places, passes);
+  predictor.Build(model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places,
+                  passes);
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......
@@ -33,8 +33,11 @@ TEST(unet, test) {
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kARM), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kARM), PRECISION(kFloat)},
+                  valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 512, 512})));
......
@@ -44,6 +44,8 @@ TEST(fc_fuse_pass, fuse_test) {
 #endif
 
   predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
                   Place{TARGET(kX86), PRECISION(kFloat)},  // origin cuda
                   valid_places);
@@ -72,8 +74,11 @@ TEST(fc_fuse_pass, save_model_test) {
   lite::Predictor predictor;
   std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
                                    Place{TARGET(kX86), PRECISION(kFloat)}});
-  predictor.Build(
-      FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, valid_places);
+  predictor.Build(FLAGS_model_dir,
+                  "",
+                  "",
+                  Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places);
 
   LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
   predictor.SaveModel(FLAGS_optimized_model);
......
@@ -40,7 +40,7 @@ void TestModel(lite::Predictor* predictor,
                const std::vector<Place>& valid_places,
                const std::string& model_dir) {
   predictor->Build(
-      model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
+      model_dir, "", "", Place{TARGET(kARM), PRECISION(kFloat)}, valid_places);
 
   auto* input_tensor = predictor->GetInput(0);
   input_tensor->Resize(DDim(std::vector<DDim::value_type>(
......
@@ -32,7 +32,7 @@ namespace lite {
 TEST(SubgraphTest, mobilenetv2) {
   cpp::ProgramDesc program_desc;
   auto scope = std::make_shared<Scope>();
-  LoadModelPb(FLAGS_model_dir, scope.get(), &program_desc);
+  LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
   std::vector<Place> valid_places({
       Place{TARGET(kHost), PRECISION(kFloat)},
 #ifdef LITE_WITH_ARM
......
@@ -145,7 +145,10 @@ TEST(gen_code, auto_gen) {
 TEST(gen_code, optimized_program) {
   lite::Scope scope;
   cpp::ProgramDesc cpp_desc;
-  LoadModelPb(FLAGS_optimized_model, &scope, &cpp_desc);
+  std::string model_file = FLAGS_optimized_model + "/model";
+  std::string param_file = FLAGS_optimized_model + "/params";
+  LoadModelPb(
+      FLAGS_optimized_model, model_file, param_file, &scope, &cpp_desc, true);
 
   framework::proto::ProgramDesc pb_proto_desc;
   lite::pb::ProgramDesc pb_desc(&pb_proto_desc);
......
@@ -27,7 +27,9 @@ namespace gencode {
 void GenCode(const std::string& model_dir, const std::string& out_file) {
   lite::Scope scope;
   cpp::ProgramDesc cpp_desc;
-  LoadModelPb(model_dir, &scope, &cpp_desc);
+  std::string model_file = model_dir + "/model";
+  std::string param_file = model_dir + "/params";
+  LoadModelPb(model_dir, model_file, param_file, &scope, &cpp_desc, true);
 
   framework::proto::ProgramDesc pb_proto_desc;
   lite::pb::ProgramDesc pb_desc(&pb_proto_desc);
......
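Note that both gen_code call sites now hard-code `model_dir + "/model"` and `model_dir + "/params"` and pass `combined = true`: they assume the optimized model they consume was written by the combined `SaveModelPb` path, consistent with `Predictor::SaveModel` above now always saving protobuf models in combined form.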
@@ -159,15 +159,63 @@ void LoadParam(const std::string &path, Variable *out) {
   LoadLoDTensor(fin, out);
 }
 
+bool IsPersistable(const cpp::VarDesc &var) {
+  if (var.Persistable() && var.GetType() != VarDescAPI::Type::FEED_MINIBATCH &&
+      var.GetType() != VarDescAPI::Type::FETCH_LIST &&
+      var.GetType() != VarDescAPI::Type::RAW) {
+    return true;
+  }
+  return false;
+}
+
+void LoadCombinedParamsPb(const std::string &path,
+                          lite::Scope *scope,
+                          const cpp::ProgramDesc &cpp_prog) {
+  CHECK(scope);
+  auto prog = cpp_prog;
+  auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
+
+  // Get the persistable vars, sorted by name so the order matches
+  // SaveCombinedParamsPb
+  std::vector<std::string> paramlist;
+  for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
+    auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
+    if (!IsPersistable(var)) continue;
+    paramlist.push_back(var.Name());
+  }
+  std::sort(paramlist.begin(), paramlist.end());
+
+  // Load vars
+  std::ifstream file(path);
+  CHECK(file.is_open());
+  for (size_t i = 0; i < paramlist.size(); ++i) {
+    auto *var = scope->Var(paramlist[i]);
+    // Error checking
+    CHECK(static_cast<bool>(file))
+        << "There is a problem with loading model parameters";
+    LoadLoDTensor(file, var);
+  }
+  file.peek();
+  CHECK(file.eof()) << "You are not allowed to load partial data via"
+                    << " LoadCombinedParamsPb, use LoadParam instead.";
+  file.close();
+}
+
 void LoadModelPb(const std::string &model_dir,
+                 const std::string &model_file,
+                 const std::string &param_file,
                  Scope *scope,
-                 cpp::ProgramDesc *cpp_prog) {
+                 cpp::ProgramDesc *cpp_prog,
+                 bool combined) {
   CHECK(cpp_prog);
   CHECK(scope);
   cpp_prog->ClearBlocks();
 
   // Load model
-  const std::string prog_path = model_dir + "/__model__";
+  std::string prog_path = model_dir + "/__model__";
+  if (combined) {
+    prog_path = model_file;
+  }
   framework::proto::ProgramDesc pb_proto_prog = *LoadProgram(prog_path);
   pb::ProgramDesc pb_prog(&pb_proto_prog);
@@ -176,6 +224,9 @@ void LoadModelPb(const std::string &model_dir,
 
   // Load Params
   // NOTE: Only main block be used now.
+  if (combined) {
+    LoadCombinedParamsPb(param_file, scope, *cpp_prog);
+  } else {
   auto main_block = pb_proto_prog.blocks(0);
   for (auto &var : main_block.vars()) {
     if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
@@ -193,6 +244,8 @@ void LoadModelPb(const std::string &model_dir,
       CHECK(false) << "unknown weight type";
     }
   }
+  }
 
 #ifdef LITE_WITH_NPU
   for (auto &op : main_block.ops()) {
     LOG(INFO) << "op type:" << op.type();
@@ -216,14 +269,18 @@ void LoadModelPb(const std::string &model_dir,
 
 void SaveModelPb(const std::string &model_dir,
                  const Scope &exec_scope,
-                 const cpp::ProgramDesc &cpp_prog) {
+                 const cpp::ProgramDesc &cpp_prog,
+                 bool combined) {
   MkDirRecur(model_dir);
   // Save program
   framework::proto::ProgramDesc pb_proto_prog;
   pb::ProgramDesc pb_prog(&pb_proto_prog);
   TransformProgramDescCppToAny(cpp_prog, &pb_prog);
 
-  const std::string prog_path = model_dir + "/__model__";
+  std::string prog_path = model_dir + "/__model__";
+  if (combined) {
+    prog_path = model_dir + "/model";
+  }
   std::ofstream model_ostream(prog_path, std::ios_base::binary);
   CHECK(model_ostream.is_open());
   const std::string pb_str = pb_proto_prog.SerializeAsString();
@@ -232,8 +289,13 @@ void SaveModelPb(const std::string &model_dir,
 
   // Save Params
   // NOTE: Only main block be used now.
+  if (combined) {
+    const std::string combined_params_path = model_dir + "/params";
+    SaveCombinedParamsPb(combined_params_path, exec_scope, cpp_prog);
+  } else {
   for (auto &item : pb_proto_prog.blocks(0).vars()) {
-    if (item.name() == "feed" || item.name() == "fetch" || !item.persistable())
+    if (item.name() == "feed" || item.name() == "fetch" ||
+        !item.persistable())
       continue;
     const std::string path = model_dir + "/" + item.name();
     std::ofstream var_ostream(path, std::ios::binary);
@@ -241,9 +303,34 @@ void SaveModelPb(const std::string &model_dir,
     SerializeTensor(var_ostream, exec_scope, item.name());
     var_ostream.close();
   }
+  }
   VLOG(4) << "Save protobuf model in '" << model_dir << "'' successfully";
 }
 
+void SaveCombinedParamsPb(const std::string &path,
+                          const lite::Scope &exec_scope,
+                          const cpp::ProgramDesc &cpp_prog) {
+  auto prog = cpp_prog;
+  auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
+
+  // Get the persistable vars, sorted by name to fix the on-disk order
+  std::vector<std::string> paramlist;
+  for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
+    auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
+    if (!IsPersistable(var)) continue;
+    paramlist.push_back(var.Name());
+  }
+  std::sort(paramlist.begin(), paramlist.end());
+
+  // Save vars
+  std::ofstream file(path);
+  CHECK(file.is_open());
+  for (size_t i = 0; i < paramlist.size(); ++i) {
+    SerializeTensor(file, exec_scope, paramlist[i]);
+  }
+  file.close();
+}
+
 void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
   // the 1st field, uint32_t version
   constexpr uint32_t version = 0;
......
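Two details of the combined format are worth noting: the `params` file is simply the tensor serializations of every persistable variable concatenated back to back, and since the file carries no name index, `SaveCombinedParamsPb` and `LoadCombinedParamsPb` both sort the persistable variable names so that writer and reader agree on the order; the trailing `file.peek()`/`CHECK(file.eof())` rejects a params file with leftover bytes. A hedged round-trip sketch built only from the functions above (paths hypothetical):

```cpp
// Convert a separate-param model into the combined layout and reload it.
lite::Scope scope;
cpp::ProgramDesc prog;
LoadModelPb("./mobilenet_v1", "", "", &scope, &prog);  // separate layout

// Writes "./opt/model" (the program) and "./opt/params" (all persistable
// tensors, concatenated in sorted-name order).
SaveModelPb("./opt", scope, prog, true);

// Reload through the combined path; combined=true makes LoadModelPb use
// model_file/param_file instead of __model__ plus per-weight files.
lite::Scope scope2;
cpp::ProgramDesc prog2;
LoadModelPb("./opt", "./opt/model", "./opt/params", &scope2, &prog2, true);
```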
@@ -41,14 +41,26 @@ void LoadParams(const std::string& path);
 // Load a single parameter to an output tensor.
 void LoadParam(const std::string& path, Variable* out);
 
+void LoadCombinedParamsPb(const std::string& path,
+                          lite::Scope* scope,
+                          const cpp::ProgramDesc& prog);
+
 // Read a model and files of parameters in pb format.
 void LoadModelPb(const std::string& model_dir,
+                 const std::string& model_file,
+                 const std::string& param_file,
                  Scope* scope,
-                 cpp::ProgramDesc* prog);
+                 cpp::ProgramDesc* prog,
+                 bool combined = false);
 
 // Save a model and files of parameters in pb format.
 void SaveModelPb(const std::string& model_dir,
                  const Scope& scope,
-                 const cpp::ProgramDesc& prog);
+                 const cpp::ProgramDesc& prog,
+                 bool combined = false);
+
+void SaveCombinedParamsPb(const std::string& path,
+                          const lite::Scope& exec_scope,
+                          const cpp::ProgramDesc& prog);
 
 // Serialize tensors to ostream.
......
@@ -40,18 +40,38 @@ TEST(ModelParser, LoadModelPb) {
   CHECK(!FLAGS_model_dir.empty());
   cpp::ProgramDesc prog;
   Scope scope;
-  LoadModelPb(FLAGS_model_dir, &scope, &prog);
+  LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
 }
 
 TEST(ModelParser, SaveModelPb) {
   CHECK(!FLAGS_model_dir.empty());
   cpp::ProgramDesc prog;
   Scope scope;
-  LoadModelPb(FLAGS_model_dir, &scope, &prog);
+  LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
   const std::string save_pb_model_path = FLAGS_model_dir + ".saved.pb";
   SaveModelPb(save_pb_model_path, scope, prog);
 }
 
+TEST(ModelParser, SaveModelCombinedPb) {
+  CHECK(!FLAGS_model_dir.empty());
+  cpp::ProgramDesc prog;
+  Scope scope;
+  LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
+  const std::string save_pb_model_path = FLAGS_model_dir + ".saved.pb.combined";
+  SaveModelPb(save_pb_model_path, scope, prog, true);
+}
+
+TEST(ModelParser, LoadModelCombinedPb) {
+  CHECK(!FLAGS_model_dir.empty());
+  const std::string model_path = FLAGS_model_dir + ".saved.pb.combined";
+  cpp::ProgramDesc prog;
+  Scope scope;
+  std::string model_file_path = FLAGS_model_dir + ".saved.pb.combined/model";
+  std::string param_file_path = FLAGS_model_dir + ".saved.pb.combined/params";
+  LoadModelPb(
+      model_path, model_file_path, param_file_path, &scope, &prog, true);
+}
+
 TEST(ModelParser, SaveParamNaive) {
   Scope scope;
   auto* tensor = scope.Var("xxx")->GetMutable<lite::Tensor>();
@@ -94,7 +114,7 @@ TEST(ModelParser, SaveModelNaive) {
   CHECK(!FLAGS_model_dir.empty());
   cpp::ProgramDesc prog;
   Scope scope;
-  LoadModelPb(FLAGS_model_dir, &scope, &prog);
+  LoadModelPb(FLAGS_model_dir, "", "", &scope, &prog);
   const std::string save_pb_model_path = FLAGS_model_dir + ".saved.naive";
   SaveModelNaive(save_pb_model_path, scope, prog);
 }
......
@@ -61,6 +61,8 @@ void Run(DebugConfig* conf) {
   }};
 
   predictor.Build(conf->model_dir,
+                  "",
+                  "",
 #ifdef LITE_WITH_ARM
                   Place{TARGET(kARM), PRECISION(kFloat)},
 #endif
......