From b367361e8af6ab1a81e8482a12250c76028ff700 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Wed, 5 Jan 2022 07:04:43 +0100 Subject: [PATCH] Add enable_mkldnn to static/deploy/python/infer.py (#5049) * Add enable_mkldnn to infer.py --- static/deploy/python/infer.py | 37 +++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/static/deploy/python/infer.py b/static/deploy/python/infer.py index d3afb138e..2ac447d88 100644 --- a/static/deploy/python/infer.py +++ b/static/deploy/python/infer.py @@ -59,6 +59,7 @@ class Detector(object): device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU run_mode (str): mode of running(fluid/trt_fp32/trt_fp16) threshold (float): threshold to reserve the result for output. + enable_mkldnn (bool): whether use mkldnn with CPU. """ def __init__(self, @@ -67,7 +68,8 @@ class Detector(object): device='CPU', run_mode='fluid', threshold=0.5, - trt_calib_mode=False): + trt_calib_mode=False, + enable_mkldnn=False): self.config = config if self.config.use_python_inference: self.executor, self.program, self.fecth_targets = load_executor( @@ -78,7 +80,8 @@ class Detector(object): run_mode=run_mode, min_subgraph_size=self.config.min_subgraph_size, device=device, - trt_calib_mode=trt_calib_mode) + trt_calib_mode=trt_calib_mode, + enable_mkldnn=enable_mkldnn) def preprocess(self, im): preprocess_ops = [] @@ -225,14 +228,16 @@ class DetectorSOLOv2(Detector): device='CPU', run_mode='fluid', threshold=0.5, - trt_calib_mode=False): + trt_calib_mode=False, + enable_mkldnn=False): super(DetectorSOLOv2, self).__init__( config=config, model_dir=model_dir, device=device, run_mode=run_mode, threshold=threshold, - trt_calib_mode=trt_calib_mode) + trt_calib_mode=trt_calib_mode, + enable_mkldn=enable_mkldnn) def predict(self, image, @@ -385,13 +390,15 @@ def load_predictor(model_dir, batch_size=1, device='CPU', min_subgraph_size=3, - trt_calib_mode=False): + trt_calib_mode=False, + enable_mkldnn=False): """set AnalysisConfig, generate AnalysisPredictor Args: model_dir (str): root path of __model__ and __params__ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU trt_calib_mode (bool): If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True + enable_mkldnn (bool): Whether use mkldnn with CPU, default is False Returns: predictor (PaddlePredictor): AnalysisPredictor Raises: @@ -419,7 +426,10 @@ def load_predictor(model_dir, config.enable_xpu(10 * 1024 * 1024) else: config.disable_gpu() - + if enable_mkldnn: + config.set_mkldnn_cache_capacity(0) + config.enable_mkldnn() + config.pass_builder().append_pass("interpolate_mkldnn_pass") if run_mode in precision_map.keys(): config.enable_tensorrt_engine( workspace_size=1 << 10, @@ -432,7 +442,8 @@ def load_predictor(model_dir, # disable print log when predict config.disable_glog_info() # enable shared memory - config.enable_memory_optim() + if (not enable_mkldnn): + config.enable_memory_optim() # disable feed, fetch OP, needed by zero_copy_run config.switch_use_feed_fetch_ops(False) predictor = fluid.core.create_paddle_predictor(config) @@ -545,14 +556,16 @@ def main(): FLAGS.model_dir, device=FLAGS.device, run_mode=FLAGS.run_mode, - trt_calib_mode=FLAGS.trt_calib_mode) + trt_calib_mode=FLAGS.trt_calib_mode, + enable_mkldnn=FLAGS.enable_mkldnn) if config.arch == 'SOLOv2': detector = DetectorSOLOv2( config, FLAGS.model_dir, device=FLAGS.device, run_mode=FLAGS.run_mode, - trt_calib_mode=FLAGS.trt_calib_mode) + trt_calib_mode=FLAGS.trt_calib_mode, + enable_mkldnn=FLAGS.enable_mkldnn) # predict from image if FLAGS.image_file != '': predict_image(detector) @@ -618,7 +631,11 @@ if __name__ == '__main__': default=False, help="If the model is produced by TRT offline quantitative " "calibration, trt_calib_mode need to set True.") - + parser.add_argument( + "--enable_mkldnn", + type=ast.literal_eval, + default=False, + help="Whether use mkldnn with CPU.") FLAGS = parser.parse_args() print_arguments(FLAGS) if FLAGS.image_file != '' and FLAGS.video_file != '': -- GitLab