未验证 提交 10f3dfc3 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][Internal] Add set_passes_internal inference for CxxConfig (#3614)

上级 ac3bdf1d
...@@ -35,7 +35,7 @@ namespace lite { ...@@ -35,7 +35,7 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config; config_ = config;
auto places = config.valid_places(); auto places = config.valid_places();
std::vector<std::string> passes{}; std::vector<std::string> passes = config.get_passes_internal();
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
// if kCUDA is included in valid places, it should be initialized first, // if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step. // otherwise skip this step.
......
...@@ -40,6 +40,11 @@ void OptBase::SetModelType(std::string optimize_out_type) { ...@@ -40,6 +40,11 @@ void OptBase::SetModelType(std::string optimize_out_type) {
} }
} }
void OptBase::SetPassesInternal(
const std::vector<std::string>& passes_internal) {
opt_config_.set_passes_internal(passes_internal);
}
void OptBase::SetValidPlaces(const std::string& valid_places) { void OptBase::SetValidPlaces(const std::string& valid_places) {
valid_places_.clear(); valid_places_.clear();
auto target_reprs = lite::Split(valid_places, ","); auto target_reprs = lite::Split(valid_places, ",");
...@@ -110,11 +115,13 @@ void OptBase::Run() { ...@@ -110,11 +115,13 @@ void OptBase::Run() {
void OptBase::RunOptimize(const std::string& model_dir_path, void OptBase::RunOptimize(const std::string& model_dir_path,
const std::string& model_path, const std::string& model_path,
const std::string& param_path, const std::string& param_path,
const std::string& model_type,
const std::string& valid_places, const std::string& valid_places,
const std::string& optimized_out_path) { const std::string& optimized_out_path) {
SetModelDir(model_dir_path); SetModelDir(model_dir_path);
SetModelFile(model_path); SetModelFile(model_path);
SetParamFile(param_path); SetParamFile(param_path);
SetModelType(model_type);
SetValidPlaces(valid_places); SetValidPlaces(valid_places);
SetOptimizeOut(optimized_out_path); SetOptimizeOut(optimized_out_path);
CheckIfModelSupported(false); CheckIfModelSupported(false);
......
...@@ -51,12 +51,16 @@ class LITE_API OptBase { ...@@ -51,12 +51,16 @@ class LITE_API OptBase {
void SetOptimizeOut(const std::string &lite_out_name); void SetOptimizeOut(const std::string &lite_out_name);
void RecordModelInfo(bool record_strip_info = true); void RecordModelInfo(bool record_strip_info = true);
// set optimized_model type // set optimized_model type
void SetModelType(std::string model_type); void SetModelType(std::string model_type = "naive_buffer");
// internal inference for developer, not recommanded.
// choose methods of model optimizing.
void SetPassesInternal(const std::vector<std::string> &passes_internal = {});
// transform and save the optimized model // transform and save the optimized model
void Run(); void Run();
void RunOptimize(const std::string &model_dir_path = "", void RunOptimize(const std::string &model_dir_path = "",
const std::string &model_path = "", const std::string &model_path = "",
const std::string &param_path = "", const std::string &param_path = "",
const std::string &model_type = "",
const std::string &valid_places = "", const std::string &valid_places = "",
const std::string &optimized_out_path = ""); const std::string &optimized_out_path = "");
// fuctions of printing info // fuctions of printing info
......
...@@ -146,6 +146,7 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -146,6 +146,7 @@ class LITE_API CxxConfig : public ConfigBase {
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
std::string model_file_; std::string model_file_;
std::string param_file_; std::string param_file_;
std::vector<std::string> passes_internal_{};
bool model_from_memory_{false}; bool model_from_memory_{false};
#ifdef LITE_WITH_X86 #ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1; int x86_math_library_math_threads_ = 1;
...@@ -174,7 +175,16 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -174,7 +175,16 @@ class LITE_API CxxConfig : public ConfigBase {
param_file_ = std::string(param_buffer, param_buffer + param_buffer_size); param_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true; model_from_memory_ = true;
} }
// internal inference to choose passes for model optimizing,
// it's designed for internal developer and not recommanded
// for comman users.
void set_passes_internal(
const std::vector<std::string>& passes_internal = {}) {
passes_internal_ = passes_internal;
}
const std::vector<std::string>& get_passes_internal() const {
return passes_internal_;
}
const std::vector<Place>& valid_places() const { return valid_places_; } const std::vector<Place>& valid_places() const { return valid_places_; }
std::string model_file() const { return model_file_; } std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; } std::string param_file() const { return param_file_; }
......
...@@ -65,6 +65,7 @@ void BindLiteOpt(py::module *m) { ...@@ -65,6 +65,7 @@ void BindLiteOpt(py::module *m) {
.def("set_optimize_out", &OptBase::SetOptimizeOut) .def("set_optimize_out", &OptBase::SetOptimizeOut)
.def("set_model_type", &OptBase::SetModelType) .def("set_model_type", &OptBase::SetModelType)
.def("record_model_info", &OptBase::RecordModelInfo) .def("record_model_info", &OptBase::RecordModelInfo)
.def("set_passes_internal", &OptBase::SetPassesInternal)
.def("run", &OptBase::Run) .def("run", &OptBase::Run)
.def("run_optimize", &OptBase::RunOptimize) .def("run_optimize", &OptBase::RunOptimize)
.def("help", &OptBase::PrintHelpInfo) .def("help", &OptBase::PrintHelpInfo)
...@@ -124,6 +125,7 @@ void BindLiteCxxConfig(py::module *m) { ...@@ -124,6 +125,7 @@ void BindLiteCxxConfig(py::module *m) {
.def("param_file", &CxxConfig::param_file) .def("param_file", &CxxConfig::param_file)
.def("set_valid_places", &CxxConfig::set_valid_places) .def("set_valid_places", &CxxConfig::set_valid_places)
.def("set_model_buffer", &CxxConfig::set_model_buffer) .def("set_model_buffer", &CxxConfig::set_model_buffer)
.def("set_passes_internal", &CxxConfig::set_passes_internal)
.def("model_from_memory", &CxxConfig::model_from_memory); .def("model_from_memory", &CxxConfig::model_from_memory);
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
cxx_config.def("set_threads", &CxxConfig::set_threads) cxx_config.def("set_threads", &CxxConfig::set_threads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册