提交 46d7a8da 编写于 作者: M MRXLT

add ir optim for cpu

上级 b93ffa5a
......@@ -43,6 +43,7 @@ message EngineDesc {
optional bool enable_memory_optimization = 13;
optional bool static_optimization = 14;
optional bool force_update_static_cache = 15;
optional bool enable_ir_optimization = 16;
};
// model_toolkit conf
......
......@@ -35,6 +35,7 @@ class InferEngineCreationParams {
InferEngineCreationParams() {
_path = "";
_enable_memory_optimization = false;
_enable_ir_optimization = false;
_static_optimization = false;
_force_update_static_cache = false;
}
......@@ -45,10 +46,16 @@ class InferEngineCreationParams {
_enable_memory_optimization = enable_memory_optimization;
}
void set_enable_ir_optimization(bool enable_ir_optimization) {
_enable_ir_optimization = enable_ir_optimization;
}
bool enable_memory_optimization() const {
return _enable_memory_optimization;
}
bool enable_ir_optimization() const { return _enable_ir_optimization; }
void set_static_optimization(bool static_optimization = false) {
_static_optimization = static_optimization;
}
......@@ -68,6 +75,7 @@ class InferEngineCreationParams {
<< "model_path = " << _path << ", "
<< "enable_memory_optimization = " << _enable_memory_optimization
<< ", "
<< "enable_ir_optimization = " << _enable_ir_optimization << ", "
<< "static_optimization = " << _static_optimization << ", "
<< "force_update_static_cache = " << _force_update_static_cache;
}
......@@ -75,6 +83,7 @@ class InferEngineCreationParams {
private:
std::string _path;
bool _enable_memory_optimization;
bool _enable_ir_optimization;
bool _static_optimization;
bool _force_update_static_cache;
};
......@@ -150,6 +159,11 @@ class ReloadableInferEngine : public InferEngine {
force_update_static_cache = conf.force_update_static_cache();
}
if (conf.has_enable_ir_optimization()) {
_infer_engine_params.set_enable_ir_optimization(
conf.enable_ir_optimization());
}
_infer_engine_params.set_path(_model_data_path);
if (enable_memory_optimization) {
_infer_engine_params.set_enable_memory_optimization(true);
......
......@@ -194,6 +194,12 @@ class FluidCpuAnalysisDirCore : public FluidFamilyCore {
analysis_config.EnableMemoryOptim();
}
if (params.enable_ir_optimization()) {
analysis_config.SwitchIrOptim(true);
} else {
analysis_config.SwitchIrOptim(false);
}
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
......
......@@ -198,6 +198,12 @@ class FluidGpuAnalysisDirCore : public FluidFamilyCore {
analysis_config.EnableMemoryOptim();
}
if (params.enable_ir_optimization()) {
analysis_config.SwitchIrOptim(true);
} else {
analysis_config.SwitchIrOptim(false);
}
AutoLock lock(GlobalPaddleCreateMutex::instance());
_core =
paddle::CreatePaddlePredictor<paddle::AnalysisConfig>(analysis_config);
......
......@@ -127,6 +127,7 @@ class Server(object):
self.model_toolkit_conf = None
self.resource_conf = None
self.memory_optimization = False
self.ir_optimization = False
self.model_conf = None
self.workflow_fn = "workflow.prototxt"
self.resource_fn = "resource.prototxt"
......@@ -175,6 +176,9 @@ class Server(object):
def set_memory_optimize(self, flag=False):
self.memory_optimization = flag
def set_ir_optimize(self, flag=False):
self.ir_optimization = flag
def check_local_bin(self):
if "SERVING_BIN" in os.environ:
self.use_local_bin = True
......@@ -195,6 +199,7 @@ class Server(object):
engine.enable_batch_align = 0
engine.model_data_path = model_config_path
engine.enable_memory_optimization = self.memory_optimization
engine.enable_ir_optimization = self.ir_optimization
engine.static_optimization = False
engine.force_update_static_cache = False
......
......@@ -41,6 +41,8 @@ def parse_args(): # pylint: disable=doc-string-missing
"--device", type=str, default="cpu", help="Type of device")
parser.add_argument(
"--mem_optim", type=bool, default=False, help="Memory optimize")
parser.add_argument(
"--ir_optim", type=bool, default=False, help="Graph optimize")
parser.add_argument(
"--max_body_size",
type=int,
......@@ -57,6 +59,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
workdir = args.workdir
device = args.device
mem_optim = args.mem_optim
ir_optim = args.ir_optim
max_body_size = args.max_body_size
if model == "":
......@@ -78,6 +81,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.set_max_body_size(max_body_size)
server.set_port(port)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册