提交 7bca11ca 编写于 作者: M MRXLT

add ir optim for cpu

上级 082fedde
...@@ -43,6 +43,7 @@ message EngineDesc { ...@@ -43,6 +43,7 @@ message EngineDesc {
optional bool enable_memory_optimization = 13; optional bool enable_memory_optimization = 13;
optional bool static_optimization = 14; optional bool static_optimization = 14;
optional bool force_update_static_cache = 15; optional bool force_update_static_cache = 15;
optional bool enable_ir_optimization = 16;
}; };
// model_toolkit conf // model_toolkit conf
......
...@@ -35,6 +35,7 @@ class InferEngineCreationParams { ...@@ -35,6 +35,7 @@ class InferEngineCreationParams {
InferEngineCreationParams() { InferEngineCreationParams() {
_path = ""; _path = "";
_enable_memory_optimization = false; _enable_memory_optimization = false;
_enable_ir_optimization = false;
_static_optimization = false; _static_optimization = false;
_force_update_static_cache = false; _force_update_static_cache = false;
} }
...@@ -45,10 +46,16 @@ class InferEngineCreationParams { ...@@ -45,10 +46,16 @@ class InferEngineCreationParams {
_enable_memory_optimization = enable_memory_optimization; _enable_memory_optimization = enable_memory_optimization;
} }
void set_enable_ir_optimization(bool enable_ir_optimization) {
_enable_ir_optimization = enable_ir_optimization;
}
bool enable_memory_optimization() const { bool enable_memory_optimization() const {
return _enable_memory_optimization; return _enable_memory_optimization;
} }
bool enable_ir_optimization() const { return _enable_ir_optimization; }
void set_static_optimization(bool static_optimization = false) { void set_static_optimization(bool static_optimization = false) {
_static_optimization = static_optimization; _static_optimization = static_optimization;
} }
...@@ -68,6 +75,7 @@ class InferEngineCreationParams { ...@@ -68,6 +75,7 @@ class InferEngineCreationParams {
<< "model_path = " << _path << ", " << "model_path = " << _path << ", "
<< "enable_memory_optimization = " << _enable_memory_optimization << "enable_memory_optimization = " << _enable_memory_optimization
<< ", " << ", "
<< "enable_ir_optimization = " << _enable_ir_optimization << ", "
<< "static_optimization = " << _static_optimization << ", " << "static_optimization = " << _static_optimization << ", "
<< "force_update_static_cache = " << _force_update_static_cache; << "force_update_static_cache = " << _force_update_static_cache;
} }
...@@ -75,6 +83,7 @@ class InferEngineCreationParams { ...@@ -75,6 +83,7 @@ class InferEngineCreationParams {
private: private:
std::string _path; std::string _path;
bool _enable_memory_optimization; bool _enable_memory_optimization;
bool _enable_ir_optimization;
bool _static_optimization; bool _static_optimization;
bool _force_update_static_cache; bool _force_update_static_cache;
}; };
...@@ -150,6 +159,11 @@ class ReloadableInferEngine : public InferEngine { ...@@ -150,6 +159,11 @@ class ReloadableInferEngine : public InferEngine {
force_update_static_cache = conf.force_update_static_cache(); force_update_static_cache = conf.force_update_static_cache();
} }
if (conf.has_enable_ir_optimization()) {
_infer_engine_params.set_enable_ir_optimization(
conf.enable_ir_optimization());
}
_infer_engine_params.set_path(_model_data_path); _infer_engine_params.set_path(_model_data_path);
if (enable_memory_optimization) { if (enable_memory_optimization) {
_infer_engine_params.set_enable_memory_optimization(true); _infer_engine_params.set_enable_memory_optimization(true);
......
...@@ -194,6 +194,12 @@ class FluidCpuAnalysisDirCore : public FluidFamilyCore { ...@@ -194,6 +194,12 @@ class FluidCpuAnalysisDirCore : public FluidFamilyCore {
analysis_config.EnableMemoryOptim(); analysis_config.EnableMemoryOptim();
} }
if (params.enable_ir_optimization()) {
analysis_config.SwitchIrOptim(true);
} else {
analysis_config.SwitchIrOptim(false);
}
AutoLock lock(GlobalPaddleCreateMutex::instance()); AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = _core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config); paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
......
...@@ -198,6 +198,12 @@ class FluidGpuAnalysisDirCore : public FluidFamilyCore { ...@@ -198,6 +198,12 @@ class FluidGpuAnalysisDirCore : public FluidFamilyCore {
analysis_config.EnableMemoryOptim(); analysis_config.EnableMemoryOptim();
} }
if (params.enable_ir_optimization()) {
analysis_config.SwitchIrOptim(true);
} else {
analysis_config.SwitchIrOptim(false);
}
AutoLock lock(GlobalPaddleCreateMutex::instance()); AutoLock lock(GlobalPaddleCreateMutex::instance());
_core = _core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config); paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
......
...@@ -127,6 +127,7 @@ class Server(object): ...@@ -127,6 +127,7 @@ class Server(object):
self.model_toolkit_conf = None self.model_toolkit_conf = None
self.resource_conf = None self.resource_conf = None
self.memory_optimization = False self.memory_optimization = False
self.ir_optimization = False
self.model_conf = None self.model_conf = None
self.workflow_fn = "workflow.prototxt" self.workflow_fn = "workflow.prototxt"
self.resource_fn = "resource.prototxt" self.resource_fn = "resource.prototxt"
...@@ -175,6 +176,9 @@ class Server(object): ...@@ -175,6 +176,9 @@ class Server(object):
def set_memory_optimize(self, flag=False): def set_memory_optimize(self, flag=False):
self.memory_optimization = flag self.memory_optimization = flag
def set_ir_optimize(self, flag=False):
self.ir_optimization = flag
def check_local_bin(self): def check_local_bin(self):
if "SERVING_BIN" in os.environ: if "SERVING_BIN" in os.environ:
self.use_local_bin = True self.use_local_bin = True
...@@ -195,6 +199,7 @@ class Server(object): ...@@ -195,6 +199,7 @@ class Server(object):
engine.enable_batch_align = 0 engine.enable_batch_align = 0
engine.model_data_path = model_config_path engine.model_data_path = model_config_path
engine.enable_memory_optimization = self.memory_optimization engine.enable_memory_optimization = self.memory_optimization
engine.enable_ir_optimization = self.ir_optimization
engine.static_optimization = False engine.static_optimization = False
engine.force_update_static_cache = False engine.force_update_static_cache = False
...@@ -244,7 +249,7 @@ class Server(object): ...@@ -244,7 +249,7 @@ class Server(object):
workflow_oi_config_path = None workflow_oi_config_path = None
if isinstance(model_config_paths, str): if isinstance(model_config_paths, str):
# If there is only one model path, use the default infer_op. # If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find # Because there are several infer_op type, we need to find
# it from workflow_conf. # it from workflow_conf.
default_engine_names = [ default_engine_names = [
'general_infer_0', 'general_dist_kv_infer_0', 'general_infer_0', 'general_dist_kv_infer_0',
......
...@@ -41,6 +41,8 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -41,6 +41,8 @@ def parse_args(): # pylint: disable=doc-string-missing
"--device", type=str, default="cpu", help="Type of device") "--device", type=str, default="cpu", help="Type of device")
parser.add_argument( parser.add_argument(
"--mem_optim", type=bool, default=False, help="Memory optimize") "--mem_optim", type=bool, default=False, help="Memory optimize")
parser.add_argument(
"--ir_optim", type=bool, default=False, help="Graph optimize")
parser.add_argument( parser.add_argument(
"--max_body_size", "--max_body_size",
type=int, type=int,
...@@ -57,6 +59,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -57,6 +59,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
workdir = args.workdir workdir = args.workdir
device = args.device device = args.device
mem_optim = args.mem_optim mem_optim = args.mem_optim
ir_optim = args.ir_optim
max_body_size = args.max_body_size max_body_size = args.max_body_size
if model == "": if model == "":
...@@ -78,6 +81,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -78,6 +81,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num) server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim) server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_max_body_size(max_body_size) server.set_max_body_size(max_body_size)
server.set_port(port) server.set_port(port)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册