未验证 提交 1caba6ff 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][Internal] Add set_passes_internal inference for CxxConfig (#3623)

上级 8b100da1
......@@ -35,7 +35,7 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config;
auto places = config.valid_places();
std::vector<std::string> passes{};
std::vector<std::string> passes = config.get_passes_internal();
#ifdef LITE_WITH_CUDA
// if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step.
......
......@@ -40,6 +40,11 @@ void OptBase::SetModelType(std::string optimize_out_type) {
}
}
void OptBase::SetPassesInternal(
const std::vector<std::string>& passes_internal) {
opt_config_.set_passes_internal(passes_internal);
}
void OptBase::SetValidPlaces(const std::string& valid_places) {
valid_places_.clear();
auto target_reprs = lite::Split(valid_places, ",");
......@@ -110,11 +115,13 @@ void OptBase::Run() {
void OptBase::RunOptimize(const std::string& model_dir_path,
const std::string& model_path,
const std::string& param_path,
const std::string& model_type,
const std::string& valid_places,
const std::string& optimized_out_path) {
SetModelDir(model_dir_path);
SetModelFile(model_path);
SetParamFile(param_path);
SetModelType(model_type);
SetValidPlaces(valid_places);
SetOptimizeOut(optimized_out_path);
CheckIfModelSupported(false);
......
......@@ -51,12 +51,16 @@ class LITE_API OptBase {
void SetOptimizeOut(const std::string &lite_out_name);
void RecordModelInfo(bool record_strip_info = true);
// set optimized_model type
void SetModelType(std::string model_type);
void SetModelType(std::string model_type = "naive_buffer");
// internal inference for developer, not recommanded.
// choose methods of model optimizing.
void SetPassesInternal(const std::vector<std::string> &passes_internal = {});
// transform and save the optimized model
void Run();
void RunOptimize(const std::string &model_dir_path = "",
const std::string &model_path = "",
const std::string &param_path = "",
const std::string &model_type = "",
const std::string &valid_places = "",
const std::string &optimized_out_path = "");
// fuctions of printing info
......
......@@ -137,6 +137,7 @@ class LITE_API CxxConfig : public ConfigBase {
std::vector<Place> valid_places_;
std::string model_file_;
std::string param_file_;
std::vector<std::string> passes_internal_{};
bool model_from_memory_{false};
#ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1;
......@@ -165,7 +166,16 @@ class LITE_API CxxConfig : public ConfigBase {
param_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
model_from_memory_ = true;
}
// internal inference to choose passes for model optimizing,
// it's designed for internal developer and not recommanded
// for comman users.
void set_passes_internal(
const std::vector<std::string>& passes_internal = {}) {
passes_internal_ = passes_internal;
}
const std::vector<std::string>& get_passes_internal() const {
return passes_internal_;
}
const std::vector<Place>& valid_places() const { return valid_places_; }
std::string model_file() const { return model_file_; }
std::string param_file() const { return param_file_; }
......
......@@ -65,6 +65,7 @@ void BindLiteOpt(py::module *m) {
.def("set_optimize_out", &OptBase::SetOptimizeOut)
.def("set_model_type", &OptBase::SetModelType)
.def("record_model_info", &OptBase::RecordModelInfo)
.def("set_passes_internal", &OptBase::SetPassesInternal)
.def("run", &OptBase::Run)
.def("run_optimize", &OptBase::RunOptimize)
.def("help", &OptBase::PrintHelpInfo)
......@@ -124,6 +125,7 @@ void BindLiteCxxConfig(py::module *m) {
.def("param_file", &CxxConfig::param_file)
.def("set_valid_places", &CxxConfig::set_valid_places)
.def("set_model_buffer", &CxxConfig::set_model_buffer)
.def("set_passes_internal", &CxxConfig::set_passes_internal)
.def("model_from_memory", &CxxConfig::model_from_memory);
#ifdef LITE_WITH_ARM
cxx_config.def("set_threads", &CxxConfig::set_threads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册