未验证 提交 7f87747c 编写于 作者: H huzhiqiang 提交者: GitHub

[Python opt][Usability] modify functions of python opt (#3465)

上级 b8234efb
...@@ -55,7 +55,7 @@ DEFINE_string(model_file, "", "model file path of the combined-param model"); ...@@ -55,7 +55,7 @@ DEFINE_string(model_file, "", "model file path of the combined-param model");
DEFINE_string(param_file, "", "param file path of the combined-param model"); DEFINE_string(param_file, "", "param file path of the combined-param model");
DEFINE_string( DEFINE_string(
optimize_out_type, optimize_out_type,
"protobuf", "naive_buffer",
"store type of the output optimized model. protobuf/naive_buffer"); "store type of the output optimized model. protobuf/naive_buffer");
DEFINE_bool(display_kernels, false, "Display kernel information"); DEFINE_bool(display_kernels, false, "Display kernel information");
DEFINE_bool(record_tailoring_info, DEFINE_bool(record_tailoring_info,
...@@ -207,7 +207,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) { ...@@ -207,7 +207,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
} }
std::cout << std::setiosflags(std::ios::internal); std::cout << std::setiosflags(std::ios::internal);
std::cout << std::setw(maximum_optype_length) << "OP_name"; std::cout << std::setw(maximum_optype_length) << "OP_name";
for (int i = 0; i < targets.size(); i++) { for (size_t i = 0; i < targets.size(); i++) {
std::cout << std::setw(10) << targets[i].substr(1); std::cout << std::setw(10) << targets[i].substr(1);
} }
std::cout << std::endl; std::cout << std::endl;
...@@ -215,7 +215,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) { ...@@ -215,7 +215,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) { for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) {
std::cout << std::setw(maximum_optype_length) << it->first; std::cout << std::setw(maximum_optype_length) << it->first;
auto ops_valid_places = it->second; auto ops_valid_places = it->second;
for (int i = 0; i < targets.size(); i++) { for (size_t i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(), if (std::find(ops_valid_places.begin(),
ops_valid_places.end(), ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) { targets[i]) != ops_valid_places.end()) {
...@@ -235,7 +235,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) { ...@@ -235,7 +235,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
} }
// Print OP info. // Print OP info.
auto ops_valid_places = supported_ops.at(*op); auto ops_valid_places = supported_ops.at(*op);
for (int i = 0; i < targets.size(); i++) { for (size_t i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(), if (std::find(ops_valid_places.begin(),
ops_valid_places.end(), ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) { targets[i]) != ops_valid_places.end()) {
...@@ -288,11 +288,11 @@ void ParseInputCommand() { ...@@ -288,11 +288,11 @@ void ParseInputCommand() {
auto valid_places = paddle::lite_api::ParserValidPlaces(); auto valid_places = paddle::lite_api::ParserValidPlaces();
// get valid_targets string // get valid_targets string
std::vector<TargetType> target_types = {}; std::vector<TargetType> target_types = {};
for (int i = 0; i < valid_places.size(); i++) { for (size_t i = 0; i < valid_places.size(); i++) {
target_types.push_back(valid_places[i].target); target_types.push_back(valid_places[i].target);
} }
std::string targets_str = TargetToStr(target_types[0]); std::string targets_str = TargetToStr(target_types[0]);
for (int i = 1; i < target_types.size(); i++) { for (size_t i = 1; i < target_types.size(); i++) {
targets_str = targets_str + TargetToStr(target_types[i]); targets_str = targets_str + TargetToStr(target_types[i]);
} }
...@@ -301,7 +301,7 @@ void ParseInputCommand() { ...@@ -301,7 +301,7 @@ void ParseInputCommand() {
target_types.push_back(TARGET(kUnk)); target_types.push_back(TARGET(kUnk));
std::set<std::string> valid_ops; std::set<std::string> valid_ops;
for (int i = 0; i < target_types.size(); i++) { for (size_t i = 0; i < target_types.size(); i++) {
auto ops = supported_ops_target[static_cast<int>(target_types[i])]; auto ops = supported_ops_target[static_cast<int>(target_types[i])];
valid_ops.insert(ops.begin(), ops.end()); valid_ops.insert(ops.begin(), ops.end());
} }
...@@ -318,7 +318,7 @@ void CheckIfModelSupported() { ...@@ -318,7 +318,7 @@ void CheckIfModelSupported() {
auto valid_unktype_ops = supported_ops_target[static_cast<int>(TARGET(kUnk))]; auto valid_unktype_ops = supported_ops_target[static_cast<int>(TARGET(kUnk))];
valid_ops.insert( valid_ops.insert(
valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end()); valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end());
for (int i = 0; i < valid_places.size(); i++) { for (size_t i = 0; i < valid_places.size(); i++) {
auto target = valid_places[i].target; auto target = valid_places[i].target;
auto ops = supported_ops_target[static_cast<int>(target)]; auto ops = supported_ops_target[static_cast<int>(target)];
valid_ops.insert(valid_ops.end(), ops.begin(), ops.end()); valid_ops.insert(valid_ops.end(), ops.begin(), ops.end());
...@@ -340,7 +340,7 @@ void CheckIfModelSupported() { ...@@ -340,7 +340,7 @@ void CheckIfModelSupported() {
std::set<std::string> unsupported_ops; std::set<std::string> unsupported_ops;
std::set<std::string> input_model_ops; std::set<std::string> input_model_ops;
for (int index = 0; index < cpp_prog.BlocksSize(); index++) { for (size_t index = 0; index < cpp_prog.BlocksSize(); index++) {
auto current_block = cpp_prog.GetBlock<lite::cpp::BlockDesc>(index); auto current_block = cpp_prog.GetBlock<lite::cpp::BlockDesc>(index);
for (size_t i = 0; i < current_block->OpsSize(); ++i) { for (size_t i = 0; i < current_block->OpsSize(); ++i) {
auto& op_desc = *current_block->GetOp<lite::cpp::OpDesc>(i); auto& op_desc = *current_block->GetOp<lite::cpp::OpDesc>(i);
...@@ -364,13 +364,13 @@ void CheckIfModelSupported() { ...@@ -364,13 +364,13 @@ void CheckIfModelSupported() {
unsupported_ops_str = unsupported_ops_str + ", " + *op_str; unsupported_ops_str = unsupported_ops_str + ", " + *op_str;
} }
std::vector<TargetType> targets = {}; std::vector<TargetType> targets = {};
for (int i = 0; i < valid_places.size(); i++) { for (size_t i = 0; i < valid_places.size(); i++) {
targets.push_back(valid_places[i].target); targets.push_back(valid_places[i].target);
} }
std::sort(targets.begin(), targets.end()); std::sort(targets.begin(), targets.end());
targets.erase(unique(targets.begin(), targets.end()), targets.end()); targets.erase(unique(targets.begin(), targets.end()), targets.end());
std::string targets_str = TargetToStr(targets[0]); std::string targets_str = TargetToStr(targets[0]);
for (int i = 1; i < targets.size(); i++) { for (size_t i = 1; i < targets.size(); i++) {
targets_str = targets_str + "," + TargetToStr(targets[i]); targets_str = targets_str + "," + TargetToStr(targets[i]);
} }
......
...@@ -82,27 +82,56 @@ void OptBase::SetValidPlaces(const std::string& valid_places) { ...@@ -82,27 +82,56 @@ void OptBase::SetValidPlaces(const std::string& valid_places) {
"command argument 'valid_targets'"; "command argument 'valid_targets'";
} }
void OptBase::SetOptimizeOut(const std::string& optimized_out_path) { void OptBase::SetLiteOut(const std::string& lite_out_name) {
optimize_out_path_ = optimized_out_path; lite_out_name_ = lite_out_name;
} }
void OptBase::RunOptimize(bool record_strip_info) { void OptBase::RecordModelInfo(bool record_strip_info) {
record_strip_info_ = record_strip_info;
}
void OptBase::Run() {
CheckIfModelSupported(false); CheckIfModelSupported(false);
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map); OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
opt_config_.set_valid_places(valid_places_); opt_config_.set_valid_places(valid_places_);
if (model_set_dir_ != "") { if (model_set_dir_ != "") {
RunOptimizeFromModelSet(record_strip_info); RunOptimizeFromModelSet(record_strip_info_);
} else { } else {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_); auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel( opt_predictor->SaveOptimizedModel(
optimize_out_path_, model_type_, record_strip_info); lite_out_name_, model_type_, record_strip_info_);
auto resulted_model_name = auto resulted_model_name =
record_strip_info ? "information of striped model" : "optimized model"; record_strip_info_ ? "information of striped model" : "optimized model";
std::cout << "Save the " << resulted_model_name std::cout << "Save the " << resulted_model_name
<< " into :" << optimize_out_path_ << "successfully"; << " into :" << lite_out_name_ << "successfully";
} }
} }
void OptBase::RunOptimize(const std::string& model_dir_path,
const std::string& model_path,
const std::string& param_path,
const std::string& valid_places,
const std::string& optimized_out_path) {
SetModelDir(model_dir_path);
SetModelFile(model_path);
SetParamFile(param_path);
SetValidPlaces(valid_places);
SetLiteOut(optimized_out_path);
CheckIfModelSupported(false);
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
opt_config_.set_valid_places(valid_places_);
if (model_set_dir_ != "") {
RunOptimizeFromModelSet(record_strip_info_);
} else {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel(
lite_out_name_, model_type_, record_strip_info_);
auto resulted_model_name =
record_strip_info_ ? "information of striped model" : "optimized model";
std::cout << "Save the " << resulted_model_name
<< " into :" << lite_out_name_ << "successfully";
}
}
// collect ops info of modelset // collect ops info of modelset
void CollectModelMetaInfo(const std::string& output_dir, void CollectModelMetaInfo(const std::string& output_dir,
const std::vector<std::string>& models, const std::vector<std::string>& models,
...@@ -125,7 +154,7 @@ void OptBase::SetModelSetDir(const std::string& model_set_path) { ...@@ -125,7 +154,7 @@ void OptBase::SetModelSetDir(const std::string& model_set_path) {
} }
void OptBase::RunOptimizeFromModelSet(bool record_strip_info) { void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
// 1. mkdir of outputed optimized model set. // 1. mkdir of outputed optimized model set.
lite::MkDirRecur(optimize_out_path_); lite::MkDirRecur(lite_out_name_);
auto model_dirs = lite::ListDir(model_set_dir_, true); auto model_dirs = lite::ListDir(model_set_dir_, true);
if (model_dirs.size() == 0) { if (model_dirs.size() == 0) {
LOG(FATAL) << "[" << model_set_dir_ << "] does not contain any model"; LOG(FATAL) << "[" << model_set_dir_ << "] does not contain any model";
...@@ -138,7 +167,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) { ...@@ -138,7 +167,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
std::string input_model_dir = std::string input_model_dir =
lite::Join<std::string>({model_set_dir_, name}, "/"); lite::Join<std::string>({model_set_dir_, name}, "/");
std::string output_model_dir = std::string output_model_dir =
lite::Join<std::string>({optimize_out_path_, name}, "/"); lite::Join<std::string>({lite_out_name_, name}, "/");
if (opt_config_.model_file() != "" && opt_config_.param_file() != "") { if (opt_config_.model_file() != "" && opt_config_.param_file() != "") {
auto model_file_path = auto model_file_path =
...@@ -155,7 +184,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) { ...@@ -155,7 +184,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_); auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel( opt_predictor->SaveOptimizedModel(
optimize_out_path_, model_type_, record_strip_info); lite_out_name_, model_type_, record_strip_info);
std::cout << "Optimize done. "; std::cout << "Optimize done. ";
} }
...@@ -164,46 +193,60 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) { ...@@ -164,46 +193,60 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
if (record_strip_info) { if (record_strip_info) {
// Collect all models information // Collect all models information
CollectModelMetaInfo( CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME); lite_out_name_, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(
lite_out_name_, model_dirs, lite::TAILORD_OPS_LIST_NAME);
CollectModelMetaInfo( CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_OPS_LIST_NAME); lite_out_name_, model_dirs, lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(optimize_out_path_,
model_dirs,
lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo( CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_KERNELS_LIST_NAME); lite_out_name_, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
std::cout << "Record the information of stripped models into :" std::cout << "Record the information of stripped models into :"
<< optimize_out_path_ << "successfully"; << lite_out_name_ << "successfully";
} }
} }
void OptBase::PrintHelpInfo() { void OptBase::PrintHelpInfo() {
const std::string opt_version = lite::version(); const std::string opt_version = lite::version();
const char help_info[] = const char help_info[] =
"At least one argument should be inputed. Valid arguments are listed " "------------------------------------------------------------------------"
"below:\n" "-----------------------------------------------------------\n"
" Valid arguments of Paddle-Lite opt are listed below:\n"
"------------------------------------------------------------------------"
"-----------------------------------------------------------\n"
" Arguments of help information:\n" " Arguments of help information:\n"
" `help()` Print help infomation\n" " `help()` Print help infomation\n"
" Arguments of model optimization:\n" "\n"
" Arguments of model transformation:\n"
" `set_model_dir(model_dir)`\n" " `set_model_dir(model_dir)`\n"
" `set_model_file(model_file_path)`\n" " `set_model_file(model_file_path)`\n"
" `set_param_file(param_file_path)`\n" " `set_param_file(param_file_path)`\n"
" `set_model_type(protobuf|naive_buffer)`\n" " `set_model_type(protobuf|naive_buffer)`: naive_buffer by "
" `set_optimize_out(output_optimize_model_dir)`\n" "default\n"
" `set_lite_out(output_optimize_model_dir)`\n"
" `set_valid_places(arm|opencl|x86|npu|xpu|rknpu|apu)`\n" " `set_valid_places(arm|opencl|x86|npu|xpu|rknpu|apu)`\n"
" `run_optimize(false|true)`\n" " `record_model_info(false|true)`: refer to whether to record ops "
" ` ----fasle&true refer to whether to record ops info for " "info for striping lib, false by default`\n"
"tailoring lib, false by default`\n" " `run() : start model transformation`\n"
" Arguments of model checking and ops information:\n" " eg. `opt.set_model_dir(\"./mobilenetv1\"); "
"opt.set_lite_out(\"mobilenetv1_opt\"); opt.set_valid_places(\"arm\"); "
"opt.run();`\n"
"\n"
" You can also transform model through a single input argument:\n"
" `run_optimize(model_dir, model_file_path, param_file_path, "
"model_type, valid_places, lite_out_name) `\n"
" eg. `opt.run_optimize(\"./mobilenetv1\", \"\", \"\", "
"\"naive_buffer\", \"arm\", \"mobilenetv1_opt\");`"
"\n"
" Arguments of checking model and printing ops information:\n"
" `print_all_ops()` Display all the valid operators of " " `print_all_ops()` Display all the valid operators of "
"Paddle-Lite\n" "Paddle-Lite\n"
" `print_supported_ops` Display supported operators of valid " " `print_supported_ops` Display supported operators of valid "
"places\n" "places\n"
" `check_if_model_supported()` Check if the input model is " " `check_if_model_supported()` Check if the input model is "
"supported\n"; "supported\n"
"------------------------------------------------------------------------"
std::cout << "opt version:" << opt_version << std::endl "-----------------------------------------------------------\n";
<< help_info << std::endl; std::cout << "opt version:" << opt_version << std::endl << help_info;
} }
// 2. Print supported info of inputed ops // 2. Print supported info of inputed ops
void OptBase::PrintOpsInfo(const std::set<std::string>& valid_ops) { void OptBase::PrintOpsInfo(const std::set<std::string>& valid_ops) {
......
...@@ -44,16 +44,21 @@ class LITE_API OptBase { ...@@ -44,16 +44,21 @@ class LITE_API OptBase {
public: public:
OptBase() = default; OptBase() = default;
void SetModelSetDir(const std::string &model_set_path); void SetModelSetDir(const std::string &model_set_path);
void SetModelDir(const std::string &model_path); void SetModelDir(const std::string &model_dir_path);
void SetModelFile(const std::string &model_path); void SetModelFile(const std::string &model_path);
void SetParamFile(const std::string &param_path); void SetParamFile(const std::string &param_path);
void SetValidPlaces(const std::string &valid_places); void SetValidPlaces(const std::string &valid_places);
void SetOptimizeOut(const std::string &optimized_out_path); void SetLiteOut(const std::string &lite_out_name);
void RecordModelInfo(bool record_strip_info = true);
// set optimized_model type // set optimized_model type
void SetModelType(std::string model_type); void SetModelType(std::string model_type);
// transform and save the optimized model // transform and save the optimized model
void RunOptimize(bool record_strip_info = false); void Run();
void RunOptimize(const std::string &model_dir_path = "",
const std::string &model_path = "",
const std::string &param_path = "",
const std::string &valid_places = "",
const std::string &optimized_out_path = "");
// fuctions of printing info // fuctions of printing info
// 1. help info // 1. help info
void PrintHelpInfo(); void PrintHelpInfo();
...@@ -71,12 +76,12 @@ class LITE_API OptBase { ...@@ -71,12 +76,12 @@ class LITE_API OptBase {
// valid places for the optimized_model // valid places for the optimized_model
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
// filename of the optimized_model // filename of the optimized_model
std::string optimize_out_path_; std::string lite_out_name_;
// type of the optimized_model, kNaiveBuffer default. // type of the optimized_model, kNaiveBuffer default.
LiteModelType model_type_{LiteModelType::kNaiveBuffer}; LiteModelType model_type_{LiteModelType::kNaiveBuffer};
// Dir path of a set of models, this should be combined with model // Dir path of a set of models, this should be combined with model
std::string model_set_dir_; std::string model_set_dir_;
bool record_strip_info_{false};
void RunOptimizeFromModelSet(bool record_strip_info = false); void RunOptimizeFromModelSet(bool record_strip_info = false);
}; };
......
...@@ -62,8 +62,10 @@ void BindLiteOpt(py::module *m) { ...@@ -62,8 +62,10 @@ void BindLiteOpt(py::module *m) {
.def("set_model_file", &OptBase::SetModelFile) .def("set_model_file", &OptBase::SetModelFile)
.def("set_param_file", &OptBase::SetParamFile) .def("set_param_file", &OptBase::SetParamFile)
.def("set_valid_places", &OptBase::SetValidPlaces) .def("set_valid_places", &OptBase::SetValidPlaces)
.def("set_optimize_out", &OptBase::SetOptimizeOut) .def("set_lite_out", &OptBase::SetLiteOut)
.def("set_model_type", &OptBase::SetModelType) .def("set_model_type", &OptBase::SetModelType)
.def("record_model_info", &OptBase::RecordModelInfo)
.def("run", &OptBase::Run)
.def("run_optimize", &OptBase::RunOptimize) .def("run_optimize", &OptBase::RunOptimize)
.def("help", &OptBase::PrintHelpInfo) .def("help", &OptBase::PrintHelpInfo)
.def("print_supported_ops", &OptBase::PrintSupportedOps) .def("print_supported_ops", &OptBase::PrintSupportedOps)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册