提交 a5e93766 编写于 作者: S sangoly 提交者: huzhiqiang

[Model Optimize Tool] enhance model optimize tool to supported specific...

[Model Optimize Tool] enhance model optimize tool to supported specific tailoring test=develop (#2480)
上级 c62fd634
...@@ -24,13 +24,6 @@ ...@@ -24,13 +24,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
".tailored_ops_source_list";
static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
".tailored_kernels_source_list";
static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
void Predictor::SaveModel(const std::string &dir, void Predictor::SaveModel(const std::string &dir,
lite_api::LiteModelType model_type, lite_api::LiteModelType model_type,
bool record_info) { bool record_info) {
......
...@@ -29,6 +29,13 @@ ...@@ -29,6 +29,13 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
".tailored_ops_source_list";
static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
".tailored_kernels_source_list";
static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
/* /*
* Predictor for inference, input a model, it will optimize and execute it. * Predictor for inference, input a model, it will optimize and execute it.
*/ */
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
// model_optimize_tool's compiling period // model_optimize_tool's compiling period
#include "all_kernel_faked.cc" // NOLINT #include "all_kernel_faked.cc" // NOLINT
#include "kernel_src_map.h" // NOLINT #include "kernel_src_map.h" // NOLINT
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h" #include "lite/api/paddle_use_passes.h"
...@@ -31,6 +32,18 @@ DEFINE_string(model_dir, ...@@ -31,6 +32,18 @@ DEFINE_string(model_dir,
"", "",
"path of the model. This option will be ignored if model_file " "path of the model. This option will be ignored if model_file "
"and param_file are exist"); "and param_file are exist");
DEFINE_string(model_filename,
"",
"model topo filename of the model in models set. This option"
" will be used to specific tailoring");
DEFINE_string(param_filename,
"",
"model param filename of the model in models set. This option"
" will be used to specific tailoring");
DEFINE_string(model_set_dir,
"",
"path of the models set. This option will be used to specific"
" tailoring");
DEFINE_string(model_file, "", "model file path of the combined-param model"); DEFINE_string(model_file, "", "model file path of the combined-param model");
DEFINE_string(param_file, "", "param file path of the combined-param model"); DEFINE_string(param_file, "", "param file path of the combined-param model");
DEFINE_string( DEFINE_string(
...@@ -58,24 +71,9 @@ void DisplayKernels() { ...@@ -58,24 +71,9 @@ void DisplayKernels() {
LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString(); LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString();
} }
void Main() { std::vector<Place> ParserValidPlaces() {
if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
LOG(WARNING)
<< "Load combined-param model. Option model_dir will be ignored";
}
if (FLAGS_display_kernels) {
DisplayKernels();
exit(0);
}
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_model_file(FLAGS_model_file);
config.set_param_file(FLAGS_param_file);
std::vector<Place> valid_places; std::vector<Place> valid_places;
auto target_reprs = lite::Split(FLAGS_valid_targets, " "); auto target_reprs = lite::Split(FLAGS_valid_targets, ",");
for (auto& target_repr : target_reprs) { for (auto& target_repr : target_reprs) {
if (target_repr == "arm") { if (target_repr == "arm") {
valid_places.emplace_back(TARGET(kARM)); valid_places.emplace_back(TARGET(kARM));
...@@ -109,26 +107,130 @@ void Main() { ...@@ -109,26 +107,130 @@ void Main() {
valid_places.insert(valid_places.begin(), valid_places.insert(valid_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)}); Place{TARGET(kARM), PRECISION(kInt8)});
} }
return valid_places;
}
void RunOptimize(const std::string& model_dir,
const std::string& model_file,
const std::string& param_file,
const std::string& optimize_out,
const std::string& optimize_out_type,
const std::vector<Place>& valid_places,
bool record_tailoring_info) {
if (!model_file.empty() && !param_file.empty()) {
LOG(WARNING)
<< "Load combined-param model. Option model_dir will be ignored";
}
lite_api::CxxConfig config;
config.set_model_dir(model_dir);
config.set_model_file(model_file);
config.set_param_file(param_file);
config.set_valid_places(valid_places); config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(config); auto predictor = lite_api::CreatePaddlePredictor(config);
LiteModelType model_type; LiteModelType model_type;
if (FLAGS_optimize_out_type == "protobuf") { if (optimize_out_type == "protobuf") {
model_type = LiteModelType::kProtobuf; model_type = LiteModelType::kProtobuf;
} else if (FLAGS_optimize_out_type == "naive_buffer") { } else if (optimize_out_type == "naive_buffer") {
model_type = LiteModelType::kNaiveBuffer; model_type = LiteModelType::kNaiveBuffer;
} else { } else {
LOG(FATAL) << "Unsupported Model type :" << FLAGS_optimize_out_type; LOG(FATAL) << "Unsupported Model type :" << optimize_out_type;
} }
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
predictor->SaveOptimizedModel( predictor->SaveOptimizedModel(
FLAGS_optimize_out, model_type, FLAGS_record_tailoring_info); optimize_out, model_type, record_tailoring_info);
if (FLAGS_record_tailoring_info) { if (record_tailoring_info) {
LOG(INFO) << "Record the information of tailored model into :" LOG(INFO) << "Record the information of tailored model into :"
<< FLAGS_optimize_out; << optimize_out;
}
}
void CollectModelMetaInfo(const std::string& output_dir,
const std::vector<std::string>& models,
const std::string& filename) {
std::set<std::string> total;
for (const auto& name : models) {
std::string model_path =
lite::Join<std::string>({output_dir, name, filename}, "/");
auto lines = lite::ReadLines(model_path);
total.insert(lines.begin(), lines.end());
}
std::string output_path =
lite::Join<std::string>({output_dir, filename}, "/");
lite::WriteLines(std::vector<std::string>(total.begin(), total.end()),
output_path);
}
void Main() {
if (FLAGS_display_kernels) {
DisplayKernels();
exit(0);
} }
auto valid_places = ParserValidPlaces();
if (FLAGS_model_set_dir == "") {
RunOptimize(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_param_file,
FLAGS_optimize_out,
FLAGS_optimize_out_type,
valid_places,
FLAGS_record_tailoring_info);
return;
}
if (!FLAGS_record_tailoring_info) {
LOG(WARNING) << "--model_set_dir option only be used with "
"--record_tailoring_info=true together";
return;
}
auto model_dirs = lite::ListDir(FLAGS_model_set_dir, true);
if (model_dirs.size() == 0) {
LOG(FATAL) << "[" << FLAGS_model_set_dir << "] does not contain any model";
}
// Optimize models in FLAGS_model_set_dir
for (const auto& name : model_dirs) {
std::string input_model_dir =
lite::Join<std::string>({FLAGS_model_set_dir, name}, "/");
std::string output_model_dir =
lite::Join<std::string>({FLAGS_optimize_out, name}, "/");
std::string model_file = "";
std::string param_file = "";
if (FLAGS_model_filename != "" && FLAGS_param_filename != "") {
model_file =
lite::Join<std::string>({input_model_dir, FLAGS_model_filename}, "/");
param_file =
lite::Join<std::string>({input_model_dir, FLAGS_param_filename}, "/");
}
LOG(INFO) << "Start optimize model: " << input_model_dir;
RunOptimize(input_model_dir,
model_file,
param_file,
output_model_dir,
FLAGS_optimize_out_type,
valid_places,
FLAGS_record_tailoring_info);
LOG(INFO) << "Optimize done. ";
}
// Collect all models information
CollectModelMetaInfo(
FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(
FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_LIST_NAME);
CollectModelMetaInfo(FLAGS_optimize_out,
model_dirs,
lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(
FLAGS_optimize_out, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
} }
} // namespace lite_api } // namespace lite_api
......
...@@ -14,9 +14,12 @@ ...@@ -14,9 +14,12 @@
#pragma once #pragma once
#include <dirent.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h>
#include <fstream> #include <fstream>
#include <string> #include <string>
#include <vector>
#include "lite/utils/cp_logging.h" #include "lite/utils/cp_logging.h"
#include "lite/utils/string.h" #include "lite/utils/string.h"
...@@ -46,11 +49,68 @@ static void MkDirRecur(const std::string& path) { ...@@ -46,11 +49,68 @@ static void MkDirRecur(const std::string& path) {
// read buffer from file // read buffer from file
static std::string ReadFile(const std::string& filename) { static std::string ReadFile(const std::string& filename) {
std::ifstream ifile(filename.c_str()); std::ifstream ifile(filename.c_str());
if (!ifile.is_open()) {
LOG(FATAL) << "Open file: [" << filename << "] failed.";
}
std::ostringstream buf; std::ostringstream buf;
char ch; char ch;
while (buf && ifile.get(ch)) buf.put(ch); while (buf && ifile.get(ch)) buf.put(ch);
ifile.close();
return buf.str(); return buf.str();
} }
// read lines from file
static std::vector<std::string> ReadLines(const std::string& filename) {
std::ifstream ifile(filename.c_str());
if (!ifile.is_open()) {
LOG(FATAL) << "Open file: [" << filename << "] failed.";
}
std::vector<std::string> res;
std::string tmp;
while (getline(ifile, tmp)) res.push_back(tmp);
ifile.close();
return res;
}
static void WriteLines(const std::vector<std::string>& lines,
const std::string& filename) {
std::ofstream ofile(filename.c_str());
if (!ofile.is_open()) {
LOG(FATAL) << "Open file: [" << filename << "] failed.";
}
for (const auto& line : lines) {
ofile << line << "\n";
}
ofile.close();
}
static bool IsDir(const std::string& path) {
DIR* dir_fd = opendir(path.c_str());
if (dir_fd == nullptr) return false;
closedir(dir_fd);
return true;
}
static std::vector<std::string> ListDir(const std::string& path,
bool only_dir = false) {
if (!IsDir(path)) {
LOG(FATAL) << "[" << path << "] is not a valid dir path.";
}
std::vector<std::string> paths;
DIR* parent_dir_fd = opendir(path.c_str());
dirent* dp;
while ((dp = readdir(parent_dir_fd)) != nullptr) {
// Exclude '.', '..' and hidden dir
std::string name(dp->d_name);
if (name == "." || name == ".." || name[0] == '.') continue;
if (IsDir(Join<std::string>({path, name}, "/"))) {
paths.push_back(name);
}
}
closedir(parent_dir_fd);
return paths;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册