Unverified commit b367361e authored by joanna.wozna.intel, committed by GitHub

Add enable_mkldnn to static/deploy/python/infer.py (#5049)

* Add enable_mkldnn to infer.py
Parent 0978885e
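In short, the commit threads a new enable_mkldnn flag from the command line through Detector, DetectorSOLOv2, and load_predictor, so CPU inference can run on the MKL-DNN (oneDNN) backend. A minimal sketch of the resulting Python-side call, assuming the script's usual setup (the import path, the PredictConfig helper, and the model_dir value are placeholders, not part of this diff):

    # Hypothetical usage sketch of the new keyword; assumes it is run from
    # static/deploy/python/ and that PredictConfig loads the exported config.
    from infer import Detector, PredictConfig  # assumed import path

    model_dir = './inference_model'  # placeholder export directory
    config = PredictConfig(model_dir)
    detector = Detector(
        config,
        model_dir,
        device='CPU',
        run_mode='fluid',
        enable_mkldnn=True)  # the flag this commit adds (default False)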
@@ -59,6 +59,7 @@ class Detector(object):
         device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
         run_mode (str): mode of running(fluid/trt_fp32/trt_fp16)
         threshold (float): threshold to reserve the result for output.
+        enable_mkldnn (bool): whether to use MKL-DNN with CPU.
     """

     def __init__(self,
@@ -67,7 +68,8 @@ class Detector(object):
                  device='CPU',
                  run_mode='fluid',
                  threshold=0.5,
-                 trt_calib_mode=False):
+                 trt_calib_mode=False,
+                 enable_mkldnn=False):
         self.config = config
         if self.config.use_python_inference:
             self.executor, self.program, self.fecth_targets = load_executor(
@@ -78,7 +80,8 @@ class Detector(object):
                 run_mode=run_mode,
                 min_subgraph_size=self.config.min_subgraph_size,
                 device=device,
-                trt_calib_mode=trt_calib_mode)
+                trt_calib_mode=trt_calib_mode,
+                enable_mkldnn=enable_mkldnn)

     def preprocess(self, im):
         preprocess_ops = []
@@ -225,14 +228,16 @@ class DetectorSOLOv2(Detector):
                  device='CPU',
                  run_mode='fluid',
                  threshold=0.5,
-                 trt_calib_mode=False):
+                 trt_calib_mode=False,
+                 enable_mkldnn=False):
         super(DetectorSOLOv2, self).__init__(
             config=config,
             model_dir=model_dir,
             device=device,
             run_mode=run_mode,
             threshold=threshold,
-            trt_calib_mode=trt_calib_mode)
+            trt_calib_mode=trt_calib_mode,
+            enable_mkldnn=enable_mkldnn)

     def predict(self,
                 image,
@@ -385,13 +390,15 @@ def load_predictor(model_dir,
                    batch_size=1,
                    device='CPU',
                    min_subgraph_size=3,
-                   trt_calib_mode=False):
+                   trt_calib_mode=False,
+                   enable_mkldnn=False):
     """set AnalysisConfig, generate AnalysisPredictor
     Args:
         model_dir (str): root path of __model__ and __params__
         device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
         trt_calib_mode (bool): If the model is produced by TRT offline quantitative
             calibration, trt_calib_mode needs to be set to True
+        enable_mkldnn (bool): whether to use MKL-DNN with CPU, default is False
     Returns:
         predictor (PaddlePredictor): AnalysisPredictor
     Raises:
@@ -419,7 +426,10 @@ def load_predictor(model_dir,
         config.enable_xpu(10 * 1024 * 1024)
     else:
         config.disable_gpu()
+        if enable_mkldnn:
+            config.set_mkldnn_cache_capacity(0)
+            config.enable_mkldnn()
+            config.pass_builder().append_pass("interpolate_mkldnn_pass")
     if run_mode in precision_map.keys():
         config.enable_tensorrt_engine(
             workspace_size=1 << 10,
@@ -432,6 +442,7 @@ def load_predictor(model_dir,
     # disable print log when predict
     config.disable_glog_info()
     # enable shared memory
-    config.enable_memory_optim()
+    if (not enable_mkldnn):
+        config.enable_memory_optim()
     # disable feed, fetch OP, needed by zero_copy_run
     config.switch_use_feed_fetch_ops(False)
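Taken together, the CPU branch above amounts to the following standalone predictor setup. This is a sketch, not the file itself: it assumes the fluid-era paddle.fluid.core API that the docstring's AnalysisConfig/AnalysisPredictor naming points at, and placeholder __model__/__params__ paths:

    # Sketch of the CPU + MKL-DNN predictor configuration this commit produces.
    from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

    model_dir = './inference_model'  # placeholder
    config = AnalysisConfig(model_dir + '/__model__', model_dir + '/__params__')
    config.disable_gpu()
    config.set_mkldnn_cache_capacity(0)  # cache capacity value used by the commit
    config.enable_mkldnn()
    config.pass_builder().append_pass("interpolate_mkldnn_pass")
    config.disable_glog_info()
    # mirrors the new guard: memory optimization stays off when MKL-DNN is on
    config.switch_use_feed_fetch_ops(False)
    predictor = create_paddle_predictor(config)

Skipping enable_memory_optim() under MKL-DNN mirrors the new if (not enable_mkldnn) guard above.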
@@ -545,14 +556,16 @@ def main():
         FLAGS.model_dir,
         device=FLAGS.device,
         run_mode=FLAGS.run_mode,
-        trt_calib_mode=FLAGS.trt_calib_mode)
+        trt_calib_mode=FLAGS.trt_calib_mode,
+        enable_mkldnn=FLAGS.enable_mkldnn)
     if config.arch == 'SOLOv2':
         detector = DetectorSOLOv2(
             config,
             FLAGS.model_dir,
             device=FLAGS.device,
             run_mode=FLAGS.run_mode,
-            trt_calib_mode=FLAGS.trt_calib_mode)
+            trt_calib_mode=FLAGS.trt_calib_mode,
+            enable_mkldnn=FLAGS.enable_mkldnn)
     # predict from image
     if FLAGS.image_file != '':
         predict_image(detector)
@@ -618,7 +631,11 @@ if __name__ == '__main__':
         default=False,
         help="If the model is produced by TRT offline quantitative "
         "calibration, trt_calib_mode needs to be set to True.")
+    parser.add_argument(
+        "--enable_mkldnn",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use MKL-DNN with CPU.")
     FLAGS = parser.parse_args()
     print_arguments(FLAGS)
     if FLAGS.image_file != '' and FLAGS.video_file != '':
......
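With the flag registered in the parser, a CPU run that exercises it could look like this (model and image paths are placeholders):

    python static/deploy/python/infer.py \
        --model_dir=./inference_model \
        --device=CPU \
        --enable_mkldnn=True \
        --image_file=demo.jpg

Since --enable_mkldnn is parsed with ast.literal_eval, the value must be a Python literal such as True or False.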