From 4045bdeee67386a840b5bc7fcd284e713a6a3597 Mon Sep 17 00:00:00 2001 From: TeslaZhao Date: Wed, 26 May 2021 19:56:19 +0800 Subject: [PATCH] python pipeline add mkldnn --- python/examples/pipeline/ocr/benchmark.sh | 8 +-- python/examples/pipeline/ocr/config.yml | 28 ++++++++-- .../pipeline/simple_web_service/config.yml | 15 ++++-- python/paddle_serving_app/local_predict.py | 51 ++++++++++++++----- python/pipeline/local_service_handler.py | 38 +++++++++++--- python/pipeline/operator.py | 50 +++++++++++++++--- python/pipeline/pipeline_server.py | 8 +++ 7 files changed, 159 insertions(+), 39 deletions(-) diff --git a/python/examples/pipeline/ocr/benchmark.sh b/python/examples/pipeline/ocr/benchmark.sh index bf9ac2b0..e9f3b9eb 100644 --- a/python/examples/pipeline/ocr/benchmark.sh +++ b/python/examples/pipeline/ocr/benchmark.sh @@ -1,5 +1,5 @@ export FLAGS_profile_pipeline=1 -alias python3="python3.6" +alias python3="python3.7" modelname="ocr" # HTTP @@ -11,11 +11,11 @@ rm -rf profile_log_$modelname echo "Starting HTTP Clients..." # Start a client in each thread, tesing the case of multiple threads. -for thread_num in 1 2 4 8 12 16 +for thread_num in 1 2 4 6 8 12 16 do for batch_size in 1 do - echo '----$modelname thread num: $thread_num batch size: $batch_size mode:http ----' >>profile_log_$modelname + echo "----$modelname thread num: $thread_num batch size: $batch_size mode:http ----" >>profile_log_$modelname # Start one web service, If you start the service yourself, you can ignore it here. #python3 web_service.py >web.log 2>&1 & #sleep 3 @@ -51,7 +51,7 @@ sleep 3 # Create yaml,If you already have the config.yaml, ignore it. #python3 benchmark.py yaml local_predictor 1 gpu -rm -rf profile_log_$modelname +#rm -rf profile_log_$modelname # Start a client in each thread, tesing the case of multiple threads. 
for thread_num in 1 2 4 6 8 12 16 diff --git a/python/examples/pipeline/ocr/config.yml b/python/examples/pipeline/ocr/config.yml index 4f725e18..58e3ed54 100644 --- a/python/examples/pipeline/ocr/config.yml +++ b/python/examples/pipeline/ocr/config.yml @@ -6,7 +6,7 @@ http_port: 9999 #worker_num, 最大并发数。当build_dag_each_worker=True时, 框架会创建worker_num个进程,每个进程内构建grpcSever和DAG ##当build_dag_each_worker=False时,框架会设置主线程grpc线程池的max_workers=worker_num -worker_num: 5 +worker_num: 20 #build_dag_each_worker, False,框架在进程内创建一条DAG;True,框架会每个进程内创建多个独立的DAG build_dag_each_worker: false @@ -26,7 +26,7 @@ dag: op: det: #并发数,is_thread_op=True时,为线程并发;否则为进程并发 - concurrency: 2 + concurrency: 6 #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 local_service_conf: @@ -40,10 +40,19 @@ op: fetch_list: ["concat_1.tmp_0"] #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 - devices: "0" + devices: "" + + #use_mkldnn + #use_mkldnn: True + + #thread_num + thread_num: 2 + + #ir_optim + ir_optim: True rec: #并发数,is_thread_op=True时,为线程并发;否则为进程并发 - concurrency: 2 + concurrency: 3 #超时时间, 单位ms timeout: -1 @@ -64,4 +73,13 @@ op: fetch_list: ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 - devices: "0" + devices: "" + + #use_mkldnn + #use_mkldnn: True + + #thread_num + thread_num: 2 + + #ir_optim + ir_optim: True diff --git a/python/examples/pipeline/simple_web_service/config.yml b/python/examples/pipeline/simple_web_service/config.yml index 43455bcb..12fae64e 100644 --- a/python/examples/pipeline/simple_web_service/config.yml +++ b/python/examples/pipeline/simple_web_service/config.yml @@ -9,10 +9,14 @@ http_port: 18082 dag: #op资源类型, True, 为线程模型;False,为进程模型 is_thread_op: False + + #tracer + tracer: + interval_s: 10 op: uci: #并发数,is_thread_op=True时,为线程并发;否则为进程并发 - concurrency: 2 + concurrency: 1 #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 local_service_conf: @@ -35,7 +39,10 @@ op: #precsion, 
预测精度,降低预测精度可提升预测速度 #GPU 支持: "fp32"(default), "fp16", "int8"; #CPU 支持: "fp32"(default), "fp16", "bf16"(mkldnn); 不支持: "int8" - precision: "FP16" + precision: "fp32" + + #ir_optim开关, 默认False + ir_optim: True - #ir_optim开关 - ir_optim: False + #use_mkldnn开关, 默认False, use_mkldnn与ir_optim同时打开才有性能提升 + use_mkldnn: True diff --git a/python/paddle_serving_app/local_predict.py b/python/paddle_serving_app/local_predict.py index 945e891a..382d2317 100644 --- a/python/paddle_serving_app/local_predict.py +++ b/python/paddle_serving_app/local_predict.py @@ -64,6 +64,10 @@ class LocalPredictor(object): use_xpu=False, precision="fp32", use_calib=False, + use_mkldnn=False, + mkldnn_cache_capacity=0, + mkldnn_op_list=None, + mkldnn_bf16_op_list=None, use_feed_fetch_ops=False): """ Load model configs and create the paddle predictor by Paddle Inference API. @@ -73,7 +77,7 @@ class LocalPredictor(object): use_gpu: calculating with gpu, False default. gpu_id: gpu id, 0 default. use_profile: use predictor profiles, False default. - thread_num: thread nums, default 1. + thread_num: thread nums of cpu math library, default 1. mem_optim: memory optimization, True default. ir_optim: open calculation chart optimization, False default. use_trt: use nvidia TensorRT optimization, False default @@ -81,6 +85,10 @@ class LocalPredictor(object): use_xpu: run predict on Baidu Kunlun, False default precision: precision mode, "fp32" default use_calib: use TensorRT calibration, False default + use_mkldnn: use MKLDNN, False default. + mkldnn_cache_capacity: cache capacity for input shapes, 0 default. + mkldnn_op_list: op list accelerated using MKLDNN, None default. + mkldnn_bf16_op_list: op list accelerated using MKLDNN bf16, None default. use_feed_fetch_ops: use feed/fetch ops, False default. 
""" client_config = "{}/serving_server_conf.prototxt".format(model_path) @@ -96,13 +104,15 @@ class LocalPredictor(object): config = paddle_infer.Config(model_path) logger.info( - "LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\ - gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\ - use_trt:{}, use_lite:{}, use_xpu: {}, precision: {}, use_calib: {},\ - use_feed_fetch_ops:{}" - .format(model_path, use_gpu, gpu_id, use_profile, thread_num, - mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision, - use_calib, use_feed_fetch_ops)) + "LocalPredictor load_model_config params: model_path:{}, use_gpu:{}, " + "gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{}, " + "use_trt:{}, use_lite:{}, use_xpu:{}, precision:{}, use_calib:{}, " + "use_mkldnn:{}, mkldnn_cache_capacity:{}, mkldnn_op_list:{}, " + "mkldnn_bf16_op_list:{}, use_feed_fetch_ops:{}, ".format( + model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim, + ir_optim, use_trt, use_lite, use_xpu, precision, use_calib, + use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list, + mkldnn_bf16_op_list, use_feed_fetch_ops)) self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] @@ -118,21 +128,35 @@ class LocalPredictor(object): self.fetch_names_to_idx_[var.alias_name] = i self.fetch_names_to_type_[var.alias_name] = var.fetch_type + # set precision of inference. precision_type = paddle_infer.PrecisionType.Float32 if precision is not None and precision.lower() in precision_map: precision_type = precision_map[precision.lower()] else: logger.warning("precision error!!! Please check precision:{}". 
format(precision)) + # set profile if use_profile: config.enable_profile() + # set memory optimization if mem_optim: config.enable_memory_optim() + # set ir optimization config.switch_ir_optim(ir_optim) - config.set_cpu_math_library_num_threads(thread_num) + # use feed & fetch ops config.switch_use_feed_fetch_ops(use_feed_fetch_ops) + # pass optim config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + # set cpu & mkldnn + config.set_cpu_math_library_num_threads(thread_num) + if use_mkldnn: + config.enable_mkldnn() + if mkldnn_cache_capacity > 0: + config.set_mkldnn_cache_capacity(mkldnn_cache_capacity) + if mkldnn_op_list is not None: + config.set_mkldnn_op(mkldnn_op_list) + # set gpu if not use_gpu: config.disable_gpu() else: @@ -145,18 +169,18 @@ class LocalPredictor(object): min_subgraph_size=3, use_static=False, use_calib_mode=False) - + # set lite if use_lite: config.enable_lite_engine( precision_mode=precision_type, zero_copy=True, passes_filter=[], ops_filter=[]) - + # set xpu if use_xpu: # 2MB l3 cache config.enable_xpu(8 * 1024 * 1024) - + # set cpu low precision if not use_gpu and not use_lite: if precision_type == paddle_infer.PrecisionType.Int8: logger.warning( @@ -165,6 +189,9 @@ class LocalPredictor(object): #config.enable_quantizer() if precision is not None and precision.lower() == "bf16": config.enable_mkldnn_bfloat16() + if mkldnn_bf16_op_list is not None: + config.set_bfloat16_op(mkldnn_bf16_op_list) + self.predictor = paddle_infer.create_predictor(config) def predict(self, feed=None, fetch=None, batch=False, log_id=0): diff --git a/python/pipeline/local_service_handler.py b/python/pipeline/local_service_handler.py index a15a3eeb..d04b9654 100644 --- a/python/pipeline/local_service_handler.py +++ b/python/pipeline/local_service_handler.py @@ -45,7 +45,11 @@ class LocalServiceHandler(object): ir_optim=False, available_port_generator=None, use_profile=False, - precision="fp32"): + precision="fp32", +
use_mkldnn=False, + mkldnn_cache_capacity=0, + mkldnn_op_list=None, + mkldnn_bf16_op_list=None): """ Initialization of localservicehandler @@ -64,6 +68,10 @@ class LocalServiceHandler(object): available_port_generator: generate available ports use_profile: use profiling, False default. precision: inference precesion, e.g. "fp32", "fp16", "int8" + use_mkldnn: use mkldnn, default False. + mkldnn_cache_capacity: cache capacity of mkldnn, 0 means no limit. + mkldnn_op_list: OP list optimized by mkldnn, None default. + mkldnn_bf16_op_list: OP list optimized by mkldnn bf16, None default. Returns: None @@ -78,6 +86,10 @@ class LocalServiceHandler(object): self._use_trt = False self._use_lite = False self._use_xpu = False + self._use_mkldnn = False + self._mkldnn_cache_capacity = 0 + self._mkldnn_op_list = None + self._mkldnn_bf16_op_list = None if device_type == -1: # device_type is not set, determined by `devices`, @@ -140,16 +152,24 @@ class LocalServiceHandler(object): self._use_profile = use_profile self._fetch_names = fetch_names self._precision = precision + self._use_mkldnn = use_mkldnn + self._mkldnn_cache_capacity = mkldnn_cache_capacity + self._mkldnn_op_list = mkldnn_op_list + self._mkldnn_bf16_op_list = mkldnn_bf16_op_list _LOGGER.info( "Models({}) will be launched by device {}. use_gpu:{}, " "use_trt:{}, use_lite:{}, use_xpu:{}, device_type:{}, devices:{}, " "mem_optim:{}, ir_optim:{}, use_profile:{}, thread_num:{}, " - "client_type:{}, fetch_names:{} precision:{}".format( + "client_type:{}, fetch_names:{}, precision:{}, use_mkldnn:{}, " + "mkldnn_cache_capacity:{}, mkldnn_op_list:{}, " + "mkldnn_bf16_op_list:{}".format( model_config, self._device_name, self._use_gpu, self._use_trt, - self._use_lite, self._use_xpu, device_type, self._devices, self. 
- _mem_optim, self._ir_optim, self._use_profile, self._thread_num, - self._client_type, self._fetch_names, self._precision)) + self._use_lite, self._use_xpu, device_type, self._devices, + self._mem_optim, self._ir_optim, self._use_profile, + self._thread_num, self._client_type, self._fetch_names, + self._precision, self._use_mkldnn, self._mkldnn_cache_capacity, + self._mkldnn_op_list, self._mkldnn_bf16_op_list)) def get_fetch_list(self): return self._fetch_names @@ -189,7 +209,7 @@ class LocalServiceHandler(object): from paddle_serving_app.local_predict import LocalPredictor if self._local_predictor_client is None: self._local_predictor_client = LocalPredictor() - + # load model config and init predictor self._local_predictor_client.load_model_config( model_path=self._model_config, use_gpu=self._use_gpu, @@ -201,7 +221,11 @@ class LocalServiceHandler(object): use_trt=self._use_trt, use_lite=self._use_lite, use_xpu=self._use_xpu, - precision=self._precision) + precision=self._precision, + use_mkldnn=self._use_mkldnn, + mkldnn_cache_capacity=self._mkldnn_cache_capacity, + mkldnn_op_list=self._mkldnn_op_list, + mkldnn_bf16_op_list=self._mkldnn_bf16_op_list) return self._local_predictor_client def get_client_config(self): diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 821daee6..eab2c3a5 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -139,6 +139,11 @@ class Op(object): self.mem_optim = False self.ir_optim = False self.precision = "fp32" + self.use_mkldnn = False + self.mkldnn_cache_capacity = 0 + self.mkldnn_op_list = None + self.mkldnn_bf16_op_list = None + if self._server_endpoints is None: server_endpoints = conf.get("server_endpoints", []) if len(server_endpoints) != 0: @@ -161,6 +166,14 @@ class Op(object): self.ir_optim = local_service_conf.get("ir_optim") self._fetch_names = local_service_conf.get("fetch_list") self.precision = local_service_conf.get("precision") + self.use_mkldnn = 
local_service_conf.get("use_mkldnn") + self.mkldnn_cache_capacity = local_service_conf.get( + "mkldnn_cache_capacity") + self.mkldnn_op_list = local_service_conf.get( + "mkldnn_op_list") + self.mkldnn_bf16_op_list = local_service_conf.get( + "mkldnn_bf16_op_list") + if self.model_config is None: self.with_serving = False else: @@ -176,7 +189,12 @@ class Op(object): devices=self.devices, mem_optim=self.mem_optim, ir_optim=self.ir_optim, - precision=self.precision) + precision=self.precision, + use_mkldnn=self.use_mkldnn, + mkldnn_cache_capacity=self. + mkldnn_cache_capacity, + mkldnn_op_list=self.mkldnn_op_list, + mkldnn_bf16_op_list=self.mkldnn_bf16_op_list) service_handler.prepare_server() # get fetch_list serivce_ports = service_handler.get_port_list() self._server_endpoints = [ @@ -199,7 +217,12 @@ class Op(object): fetch_names=self._fetch_names, mem_optim=self.mem_optim, ir_optim=self.ir_optim, - precision=self.precision) + precision=self.precision, + use_mkldnn=self.use_mkldnn, + mkldnn_cache_capacity=self. 
+ mkldnn_cache_capacity, + mkldnn_op_list=self.mkldnn_op_list, + mkldnn_bf16_op_list=self.mkldnn_bf16_op_list) if self._client_config is None: self._client_config = service_handler.get_client_config( ) @@ -564,7 +587,9 @@ class Op(object): self._get_output_channels(), False, trace_buffer, self.model_config, self.workdir, self.thread_num, self.device_type, self.devices, self.mem_optim, - self.ir_optim, self.precision)) + self.ir_optim, self.precision, self.use_mkldnn, + self.mkldnn_cache_capacity, self.mkldnn_op_list, + self.mkldnn_bf16_op_list)) p.daemon = True p.start() process.append(p) @@ -598,7 +623,9 @@ class Op(object): self._get_output_channels(), True, trace_buffer, self.model_config, self.workdir, self.thread_num, self.device_type, self.devices, self.mem_optim, - self.ir_optim, self.precision)) + self.ir_optim, self.precision, self.use_mkldnn, + self.mkldnn_cache_capacity, self.mkldnn_op_list, + self.mkldnn_bf16_op_list)) # When a process exits, it attempts to terminate # all of its daemonic child processes. t.daemon = True @@ -1068,7 +1095,8 @@ class Op(object): def _run(self, concurrency_idx, input_channel, output_channels, is_thread_op, trace_buffer, model_config, workdir, thread_num, - device_type, devices, mem_optim, ir_optim, precision): + device_type, devices, mem_optim, ir_optim, precision, use_mkldnn, + mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list): """ _run() is the entry function of OP process / thread model.When client type is local_predictor in process mode, the CUDA environment needs to @@ -1090,7 +1118,11 @@ class Op(object): devices: gpu id list[gpu], "" default[cpu] mem_optim: use memory/graphics memory optimization, True default. ir_optim: use calculation chart optimization, False default. - precision: inference precision, e.g. "fp32", "fp16", "int8" + precision: inference precision, e.g. "fp32", "fp16", "int8", "bf16" + use_mkldnn: use mkldnn, default False. 
+ mkldnn_cache_capacity: cache capacity of mkldnn, 0 means no limit. + mkldnn_op_list: OP list optimized by mkldnn, None default. + mkldnn_bf16_op_list: OP list optimized by mkldnn bf16, None default. Returns: None @@ -1110,7 +1142,11 @@ class Op(object): devices=devices, mem_optim=mem_optim, ir_optim=ir_optim, - precision=precision) + precision=precision, + use_mkldnn=use_mkldnn, + mkldnn_cache_capacity=mkldnn_cache_capacity, + mkldnn_op_list=mkldnn_op_list, + mkldnn_bf16_op_list=mkldnn_bf16_op_list) _LOGGER.info("Init cuda env in process {}".format( concurrency_idx)) diff --git a/python/pipeline/pipeline_server.py b/python/pipeline/pipeline_server.py index 0afa3872..5fcc3187 100644 --- a/python/pipeline/pipeline_server.py +++ b/python/pipeline/pipeline_server.py @@ -239,6 +239,8 @@ class PipelineServer(object): "ir_optim": False, "precision": "fp32", "use_calib": False, + "use_mkldnn": False, + "mkldnn_cache_capacity": 0, }, } for op in self._used_op: @@ -397,6 +399,8 @@ class ServerYamlConfChecker(object): "ir_optim": False, "precision": "fp32", "use_calib": False, + "use_mkldnn": False, + "mkldnn_cache_capacity": 0, } conf_type = { "model_config": str, @@ -408,6 +412,10 @@ class ServerYamlConfChecker(object): "ir_optim": bool, "precision": str, "use_calib": bool, + "use_mkldnn": bool, + "mkldnn_cache_capacity": int, + "mkldnn_op_list": list, + "mkldnn_bf16_op_list": list, } conf_qualification = {"thread_num": (">=", 1), } ServerYamlConfChecker.check_conf(conf, default_conf, conf_type, -- GitLab