diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc
index bacb9931abe6674e4181a74d1afaff2e6e030bb4..4647f20bbe476d8763f94f707f3d88da7c7544df 100644
--- a/lite/api/cxx_api.cc
+++ b/lite/api/cxx_api.cc
@@ -24,13 +24,6 @@
 namespace paddle {
 namespace lite {
 
-static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
-    ".tailored_ops_source_list";
-static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
-static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
-    ".tailored_kernels_source_list";
-static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
-
 void Predictor::SaveModel(const std::string &dir,
                           lite_api::LiteModelType model_type,
                           bool record_info) {
diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h
index 502ce812e1f4a7f520e89e6eaff020c5853f5308..504710d9fa29420b8762f31e0c675b59c6c626bd 100644
--- a/lite/api/cxx_api.h
+++ b/lite/api/cxx_api.h
@@ -29,6 +29,13 @@
 namespace paddle {
 namespace lite {
 
+static const char TAILORD_OPS_SOURCE_LIST_FILENAME[] =
+    ".tailored_ops_source_list";
+static const char TAILORD_OPS_LIST_NAME[] = ".tailored_ops_list";
+static const char TAILORD_KERNELS_SOURCE_LIST_FILENAME[] =
+    ".tailored_kernels_source_list";
+static const char TAILORD_KERNELS_LIST_NAME[] = ".tailored_kernels_list";
+
 /*
  * Predictor for inference, input a model, it will optimize and execute it.
  */
diff --git a/lite/api/model_optimize_tool.cc b/lite/api/model_optimize_tool.cc
index daa57cd45632764172426cc41914abc7f82bea33..1c426e8568cf71b6f48edbbeb8a93fec2e89c594 100644
--- a/lite/api/model_optimize_tool.cc
+++ b/lite/api/model_optimize_tool.cc
@@ -20,6 +20,7 @@
 // model_optimize_tool's compiling period
 #include "all_kernel_faked.cc"  // NOLINT
 #include "kernel_src_map.h"     // NOLINT
+#include "lite/api/cxx_api.h"
 #include "lite/api/paddle_api.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/api/paddle_use_passes.h"
@@ -31,6 +32,18 @@ DEFINE_string(model_dir,
               "",
               "path of the model. This option will be ignored if model_file "
               "and param_file are exist");
+DEFINE_string(model_filename,
+              "",
+              "model topology filename of the models in the model set. Only "
+              "used when tailoring a model set");
+DEFINE_string(param_filename,
+              "",
+              "model parameter filename of the models in the model set. Only "
+              "used when tailoring a model set");
+DEFINE_string(model_set_dir,
+              "",
+              "path of the model set. This option is used to tailor a set of "
+              "models at once");
 DEFINE_string(model_file, "", "model file path of the combined-param model");
 DEFINE_string(param_file, "", "param file path of the combined-param model");
 DEFINE_string(
@@ -58,24 +71,9 @@ void DisplayKernels() {
   LOG(INFO) << ::paddle::lite::KernelRegistry::Global().DebugString();
 }
 
-void Main() {
-  if (!FLAGS_model_file.empty() && !FLAGS_param_file.empty()) {
-    LOG(WARNING)
-        << "Load combined-param model. Option model_dir will be ignored";
-  }
-
-  if (FLAGS_display_kernels) {
-    DisplayKernels();
-    exit(0);
-  }
-
-  lite_api::CxxConfig config;
-  config.set_model_dir(FLAGS_model_dir);
-  config.set_model_file(FLAGS_model_file);
-  config.set_param_file(FLAGS_param_file);
-
+std::vector<Place> ParserValidPlaces() {
   std::vector<Place> valid_places;
-  auto target_reprs = lite::Split(FLAGS_valid_targets, " ");
+  auto target_reprs = lite::Split(FLAGS_valid_targets, ",");
   for (auto& target_repr : target_reprs) {
     if (target_repr == "arm") {
       valid_places.emplace_back(TARGET(kARM));
@@ -109,26 +107,130 @@ void Main() {
     valid_places.insert(valid_places.begin(),
                         Place{TARGET(kARM), PRECISION(kInt8)});
   }
+  return valid_places;
+}
+
+void RunOptimize(const std::string& model_dir,
+                 const std::string& model_file,
+                 const std::string& param_file,
+                 const std::string& optimize_out,
+                 const std::string& optimize_out_type,
+                 const std::vector<Place>& valid_places,
+                 bool record_tailoring_info) {
+  if (!model_file.empty() && !param_file.empty()) {
+    LOG(WARNING)
+        << "Load combined-param model. Option model_dir will be ignored";
+  }
+
+  lite_api::CxxConfig config;
+  config.set_model_dir(model_dir);
+  config.set_model_file(model_file);
+  config.set_param_file(param_file);
+  config.set_valid_places(valid_places);
 
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
   LiteModelType model_type;
-  if (FLAGS_optimize_out_type == "protobuf") {
+  if (optimize_out_type == "protobuf") {
     model_type = LiteModelType::kProtobuf;
-  } else if (FLAGS_optimize_out_type == "naive_buffer") {
+  } else if (optimize_out_type == "naive_buffer") {
     model_type = LiteModelType::kNaiveBuffer;
   } else {
-    LOG(FATAL) << "Unsupported Model type :" << FLAGS_optimize_out_type;
+    LOG(FATAL) << "Unsupported model type: " << optimize_out_type;
   }
 
-  OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
-
+  OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
   predictor->SaveOptimizedModel(
-      FLAGS_optimize_out, model_type, FLAGS_record_tailoring_info);
-  if (FLAGS_record_tailoring_info) {
+      optimize_out, model_type, record_tailoring_info);
+  if (record_tailoring_info) {
     LOG(INFO) << "Record the information of tailored model into :"
-              << FLAGS_optimize_out;
-  }
+              << optimize_out;
+  }
+}
+
+void CollectModelMetaInfo(const std::string& output_dir,
+                          const std::vector<std::string>& models,
+                          const std::string& filename) {
+  std::set<std::string> total;
+  for (const auto& name : models) {
+    std::string model_path =
+        lite::Join<std::string>({output_dir, name, filename}, "/");
+    auto lines = lite::ReadLines(model_path);
+    total.insert(lines.begin(), lines.end());
+  }
+  std::string output_path =
+      lite::Join<std::string>({output_dir, filename}, "/");
+  lite::WriteLines(std::vector<std::string>(total.begin(), total.end()),
+                   output_path);
+}
+
+void Main() {
+  if (FLAGS_display_kernels) {
+    DisplayKernels();
+    exit(0);
+  }
+
+  auto valid_places = ParserValidPlaces();
+  if (FLAGS_model_set_dir == "") {
+    RunOptimize(FLAGS_model_dir,
+                FLAGS_model_file,
+                FLAGS_param_file,
+                FLAGS_optimize_out,
+                FLAGS_optimize_out_type,
+                valid_places,
+                FLAGS_record_tailoring_info);
+    return;
+  }
+
+  if (!FLAGS_record_tailoring_info) {
+    LOG(WARNING) << "--model_set_dir option can only be used together with "
+                    "--record_tailoring_info=true";
+    return;
+  }
+
+  auto model_dirs = lite::ListDir(FLAGS_model_set_dir, true);
+  if (model_dirs.size() == 0) {
+    LOG(FATAL) << "[" << FLAGS_model_set_dir << "] does not contain any model";
+  }
+  // Optimize each model in FLAGS_model_set_dir
+  for (const auto& name : model_dirs) {
+    std::string input_model_dir =
+        lite::Join<std::string>({FLAGS_model_set_dir, name}, "/");
+    std::string output_model_dir =
+        lite::Join<std::string>({FLAGS_optimize_out, name}, "/");
+
+    std::string model_file = "";
+    std::string param_file = "";
+
+    if (FLAGS_model_filename != "" && FLAGS_param_filename != "") {
+      model_file =
+          lite::Join<std::string>({input_model_dir, FLAGS_model_filename}, "/");
+      param_file =
+          lite::Join<std::string>({input_model_dir, FLAGS_param_filename}, "/");
+    }
+
+    LOG(INFO) << "Start optimizing model: " << input_model_dir;
+    RunOptimize(input_model_dir,
+                model_file,
+                param_file,
+                output_model_dir,
+                FLAGS_optimize_out_type,
+                valid_places,
+                FLAGS_record_tailoring_info);
+    LOG(INFO) << "Optimize done.";
+  }
+
+  // Collect the tailoring meta info of all models in the set
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_OPS_LIST_NAME);
+  CollectModelMetaInfo(FLAGS_optimize_out,
+                       model_dirs,
+                       lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
+  CollectModelMetaInfo(
+      FLAGS_optimize_out, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
 }
 
 }  // namespace lite_api
diff --git a/lite/utils/io.h b/lite/utils/io.h
index 98a0f39b084c1ec0767299501f6f359dab2017b3..92405cae862f062090665aecc8eb7f207cf059e7 100644
--- a/lite/utils/io.h
+++ b/lite/utils/io.h
@@ -14,9 +14,12 @@
 
 #pragma once
 
+#include <dirent.h>
 #include <sys/stat.h>
+#include <sys/types.h>
 #include <fstream>
 #include <string>
+#include <vector>
 #include "lite/utils/cp_logging.h"
 #include "lite/utils/string.h"
 
@@ -46,11 +49,68 @@
 // read buffer from file
 static std::string ReadFile(const std::string& filename) {
   std::ifstream ifile(filename.c_str());
+  if (!ifile.is_open()) {
+    LOG(FATAL) << "Open file: [" << filename << "] failed.";
+  }
   std::ostringstream buf;
   char ch;
   while (buf && ifile.get(ch)) buf.put(ch);
+  ifile.close();
   return buf.str();
 }
 
+// read lines from file
+static std::vector<std::string> ReadLines(const std::string& filename) {
+  std::ifstream ifile(filename.c_str());
+  if (!ifile.is_open()) {
+    LOG(FATAL) << "Open file: [" << filename << "] failed.";
+  }
+  std::vector<std::string> res;
+  std::string tmp;
+  while (getline(ifile, tmp)) res.push_back(tmp);
+  ifile.close();
+  return res;
+}
+
+static void WriteLines(const std::vector<std::string>& lines,
+                       const std::string& filename) {
+  std::ofstream ofile(filename.c_str());
+  if (!ofile.is_open()) {
+    LOG(FATAL) << "Open file: [" << filename << "] failed.";
+  }
+  for (const auto& line : lines) {
+    ofile << line << "\n";
+  }
+  ofile.close();
+}
+
+static bool IsDir(const std::string& path) {
+  DIR* dir_fd = opendir(path.c_str());
+  if (dir_fd == nullptr) return false;
+  closedir(dir_fd);
+  return true;
+}
+
+static std::vector<std::string> ListDir(const std::string& path,
+                                        bool only_dir = false) {
+  if (!IsDir(path)) {
+    LOG(FATAL) << "[" << path << "] is not a valid dir path.";
+  }
+
+  std::vector<std::string> paths;
+  DIR* parent_dir_fd = opendir(path.c_str());
+  dirent* dp;
+  while ((dp = readdir(parent_dir_fd)) != nullptr) {
+    // Exclude '.', '..' and hidden entries
+    std::string name(dp->d_name);
+    if (name == "." || name == ".." || name[0] == '.') continue;
+    if (!only_dir || IsDir(Join<std::string>({path, name}, "/"))) {
+      paths.push_back(name);
+    }
+  }
+  closedir(parent_dir_fd);
+  return paths;
+}
+
 }  // namespace lite
 }  // namespace paddle
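
For context, here is a minimal standalone sketch of the per-model path that `RunOptimize()` wraps, using only APIs visible in this patch (`CxxConfig`, `set_valid_places`, `CreatePaddlePredictor`, `SaveOptimizedModel`). The directory names and the single `kARM` place are hypothetical placeholders, and a real build must link against the full Lite CxxApi library:

```cpp
// Hypothetical driver mirroring one RunOptimize() invocation; paths below
// are placeholders, not part of the patch.
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_ops.h"     // NOLINT
#include "lite/api/paddle_use_passes.h"  // NOLINT

int main() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("model_set/model_a");  // placeholder model dir
  config.set_valid_places({paddle::lite_api::Place{TARGET(kARM)}});

  auto predictor = paddle::lite_api::CreatePaddlePredictor(config);

  // record_info=true makes SaveOptimizedModel emit the .tailored_* list
  // files next to the optimized model; CollectModelMetaInfo() later merges
  // these lists across all models in the set.
  predictor->SaveOptimizedModel("opt_out/model_a",
                                paddle::lite_api::LiteModelType::kNaiveBuffer,
                                /*record_info=*/true);
  return 0;
}
```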
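Likewise, a small sketch of how the new `io.h` helpers compose; `ListDir(path, true)` keeps sub-directories only, which is how `Main()` discovers the models under `--model_set_dir`. The directory and file names here are made up for illustration:

```cpp
// Illustrative use of the new io.h helpers; "model_set" and the /tmp file
// are placeholders.
#include <iostream>
#include <string>
#include <vector>
#include "lite/utils/io.h"

int main() {
  // Enumerate model sub-directories; hidden entries are skipped.
  std::vector<std::string> models = paddle::lite::ListDir("model_set", true);
  for (const auto& name : models) {
    std::cout << "model dir: " << name << "\n";
  }

  // Round-trip a line-oriented meta-info file, as CollectModelMetaInfo does.
  paddle::lite::WriteLines({"conv2d", "relu"}, "/tmp/.tailored_ops_list");
  for (const auto& op : paddle::lite::ReadLines("/tmp/.tailored_ops_list")) {
    std::cout << op << "\n";
  }
  return 0;
}
```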