提交 4045bdee 编写于 作者: T TeslaZhao

python pipeline add mkldnn

上级 d96e4b24
export FLAGS_profile_pipeline=1 export FLAGS_profile_pipeline=1
alias python3="python3.6" alias python3="python3.7"
modelname="ocr" modelname="ocr"
# HTTP # HTTP
...@@ -11,11 +11,11 @@ rm -rf profile_log_$modelname ...@@ -11,11 +11,11 @@ rm -rf profile_log_$modelname
echo "Starting HTTP Clients..." echo "Starting HTTP Clients..."
# Start a client in each thread, tesing the case of multiple threads. # Start a client in each thread, tesing the case of multiple threads.
for thread_num in 1 2 4 8 12 16 for thread_num in 1 2 4 6 8 12 16
do do
for batch_size in 1 for batch_size in 1
do do
echo '----$modelname thread num: $thread_num batch size: $batch_size mode:http ----' >>profile_log_$modelname echo "----$modelname thread num: $thread_num batch size: $batch_size mode:http ----" >>profile_log_$modelname
# Start one web service, If you start the service yourself, you can ignore it here. # Start one web service, If you start the service yourself, you can ignore it here.
#python3 web_service.py >web.log 2>&1 & #python3 web_service.py >web.log 2>&1 &
#sleep 3 #sleep 3
...@@ -51,7 +51,7 @@ sleep 3 ...@@ -51,7 +51,7 @@ sleep 3
# Create yaml,If you already have the config.yaml, ignore it. # Create yaml,If you already have the config.yaml, ignore it.
#python3 benchmark.py yaml local_predictor 1 gpu #python3 benchmark.py yaml local_predictor 1 gpu
rm -rf profile_log_$modelname #rm -rf profile_log_$modelname
# Start a client in each thread, tesing the case of multiple threads. # Start a client in each thread, tesing the case of multiple threads.
for thread_num in 1 2 4 6 8 12 16 for thread_num in 1 2 4 6 8 12 16
......
...@@ -6,7 +6,7 @@ http_port: 9999 ...@@ -6,7 +6,7 @@ http_port: 9999
#worker_num, 最大并发数。当build_dag_each_worker=True时, 框架会创建worker_num个进程,每个进程内构建grpcSever和DAG #worker_num, 最大并发数。当build_dag_each_worker=True时, 框架会创建worker_num个进程,每个进程内构建grpcSever和DAG
##当build_dag_each_worker=False时,框架会设置主线程grpc线程池的max_workers=worker_num ##当build_dag_each_worker=False时,框架会设置主线程grpc线程池的max_workers=worker_num
worker_num: 5 worker_num: 20
#build_dag_each_worker, False,框架在进程内创建一条DAG;True,框架会每个进程内创建多个独立的DAG #build_dag_each_worker, False,框架在进程内创建一条DAG;True,框架会每个进程内创建多个独立的DAG
build_dag_each_worker: false build_dag_each_worker: false
...@@ -26,7 +26,7 @@ dag: ...@@ -26,7 +26,7 @@ dag:
op: op:
det: det:
#并发数,is_thread_op=True时,为线程并发;否则为进程并发 #并发数,is_thread_op=True时,为线程并发;否则为进程并发
concurrency: 2 concurrency: 6
#当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置
local_service_conf: local_service_conf:
...@@ -40,10 +40,19 @@ op: ...@@ -40,10 +40,19 @@ op:
fetch_list: ["concat_1.tmp_0"] fetch_list: ["concat_1.tmp_0"]
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
devices: "0" devices: ""
#use_mkldnn
#use_mkldnn: True
#thread_num
thread_num: 2
#ir_optim
ir_optim: True
rec: rec:
#并发数,is_thread_op=True时,为线程并发;否则为进程并发 #并发数,is_thread_op=True时,为线程并发;否则为进程并发
concurrency: 2 concurrency: 3
#超时时间, 单位ms #超时时间, 单位ms
timeout: -1 timeout: -1
...@@ -64,4 +73,13 @@ op: ...@@ -64,4 +73,13 @@ op:
fetch_list: ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"] fetch_list: ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡 #计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
devices: "0" devices: ""
#use_mkldnn
#use_mkldnn: True
#thread_num
thread_num: 2
#ir_optim
ir_optim: True
...@@ -9,10 +9,14 @@ http_port: 18082 ...@@ -9,10 +9,14 @@ http_port: 18082
dag: dag:
#op资源类型, True, 为线程模型;False,为进程模型 #op资源类型, True, 为线程模型;False,为进程模型
is_thread_op: False is_thread_op: False
#tracer
tracer:
interval_s: 10
op: op:
uci: uci:
#并发数,is_thread_op=True时,为线程并发;否则为进程并发 #并发数,is_thread_op=True时,为线程并发;否则为进程并发
concurrency: 2 concurrency: 1
#当op配置没有server_endpoints时,从local_service_conf读取本地服务配置 #当op配置没有server_endpoints时,从local_service_conf读取本地服务配置
local_service_conf: local_service_conf:
...@@ -35,7 +39,10 @@ op: ...@@ -35,7 +39,10 @@ op:
#precsion, 预测精度,降低预测精度可提升预测速度 #precsion, 预测精度,降低预测精度可提升预测速度
#GPU 支持: "fp32"(default), "fp16", "int8"; #GPU 支持: "fp32"(default), "fp16", "int8";
#CPU 支持: "fp32"(default), "fp16", "bf16"(mkldnn); 不支持: "int8" #CPU 支持: "fp32"(default), "fp16", "bf16"(mkldnn); 不支持: "int8"
precision: "FP16" precision: "fp32"
#ir_optim开关, 默认False
ir_optim: True
#ir_optim开关 #use_mkldnn开关, 默认False, use_mkldnn与ir_optim同时打开才有性能提升
ir_optim: False use_mkldnn: True
...@@ -64,6 +64,10 @@ class LocalPredictor(object): ...@@ -64,6 +64,10 @@ class LocalPredictor(object):
use_xpu=False, use_xpu=False,
precision="fp32", precision="fp32",
use_calib=False, use_calib=False,
use_mkldnn=False,
mkldnn_cache_capacity=0,
mkldnn_op_list=None,
mkldnn_bf16_op_list=None,
use_feed_fetch_ops=False): use_feed_fetch_ops=False):
""" """
Load model configs and create the paddle predictor by Paddle Inference API. Load model configs and create the paddle predictor by Paddle Inference API.
...@@ -73,7 +77,7 @@ class LocalPredictor(object): ...@@ -73,7 +77,7 @@ class LocalPredictor(object):
use_gpu: calculating with gpu, False default. use_gpu: calculating with gpu, False default.
gpu_id: gpu id, 0 default. gpu_id: gpu id, 0 default.
use_profile: use predictor profiles, False default. use_profile: use predictor profiles, False default.
thread_num: thread nums, default 1. thread_num: thread nums of cpu math library, default 1.
mem_optim: memory optimization, True default. mem_optim: memory optimization, True default.
ir_optim: open calculation chart optimization, False default. ir_optim: open calculation chart optimization, False default.
use_trt: use nvidia TensorRT optimization, False default use_trt: use nvidia TensorRT optimization, False default
...@@ -81,6 +85,10 @@ class LocalPredictor(object): ...@@ -81,6 +85,10 @@ class LocalPredictor(object):
use_xpu: run predict on Baidu Kunlun, False default use_xpu: run predict on Baidu Kunlun, False default
precision: precision mode, "fp32" default precision: precision mode, "fp32" default
use_calib: use TensorRT calibration, False default use_calib: use TensorRT calibration, False default
use_mkldnn: use MKLDNN, False default.
mkldnn_cache_capacity: cache capacity for input shapes, 0 default.
mkldnn_op_list: op list accelerated using MKLDNN, None default.
mkldnn_bf16_op_list: op list accelerated using MKLDNN bf16, None default.
use_feed_fetch_ops: use feed/fetch ops, False default. use_feed_fetch_ops: use feed/fetch ops, False default.
""" """
client_config = "{}/serving_server_conf.prototxt".format(model_path) client_config = "{}/serving_server_conf.prototxt".format(model_path)
...@@ -96,13 +104,15 @@ class LocalPredictor(object): ...@@ -96,13 +104,15 @@ class LocalPredictor(object):
config = paddle_infer.Config(model_path) config = paddle_infer.Config(model_path)
logger.info( logger.info(
"LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\ "LocalPredictor load_model_config params: model_path:{}, use_gpu:{}, "
gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\ "gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{}, "
use_trt:{}, use_lite:{}, use_xpu: {}, precision: {}, use_calib: {},\ "use_trt:{}, use_lite:{}, use_xpu:{}, precision:{}, use_calib:{}, "
use_feed_fetch_ops:{}" "use_mkldnn:{}, mkldnn_cache_capacity:{}, mkldnn_op_list:{}, "
.format(model_path, use_gpu, gpu_id, use_profile, thread_num, "mkldnn_bf16_op_list:{}, use_feed_fetch_ops:{}, ".format(
mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision, model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
use_calib, use_feed_fetch_ops)) ir_optim, use_trt, use_lite, use_xpu, precision, use_calib,
use_mkldnn, mkldnn_cache_capacity, mkldnn_op_list,
mkldnn_bf16_op_list, use_feed_fetch_ops))
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
...@@ -118,21 +128,35 @@ class LocalPredictor(object): ...@@ -118,21 +128,35 @@ class LocalPredictor(object):
self.fetch_names_to_idx_[var.alias_name] = i self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type self.fetch_names_to_type_[var.alias_name] = var.fetch_type
# set precision of inference.
precision_type = paddle_infer.PrecisionType.Float32 precision_type = paddle_infer.PrecisionType.Float32
if precision is not None and precision.lower() in precision_map: if precision is not None and precision.lower() in precision_map:
precision_type = precision_map[precision.lower()] precision_type = precision_map[precision.lower()]
else: else:
logger.warning("precision error!!! Please check precision:{}". logger.warning("precision error!!! Please check precision:{}".
format(precision)) format(precision))
# set profile
if use_profile: if use_profile:
config.enable_profile() config.enable_profile()
# set memory optimization
if mem_optim: if mem_optim:
config.enable_memory_optim() config.enable_memory_optim()
# set ir optimization, threads of cpu math library
config.switch_ir_optim(ir_optim) config.switch_ir_optim(ir_optim)
config.set_cpu_math_library_num_threads(thread_num) # use feed & fetch ops
config.switch_use_feed_fetch_ops(use_feed_fetch_ops) config.switch_use_feed_fetch_ops(use_feed_fetch_ops)
# pass optim
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
# set cpu & mkldnn
config.set_cpu_math_library_num_threads(thread_num)
if use_mkldnn:
config.enable_mkldnn()
if mkldnn_cache_capacity > 0:
config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
if mkldnn_op_list is not None:
config.set_mkldnn_op(mkldnn_op_list)
# set gpu
if not use_gpu: if not use_gpu:
config.disable_gpu() config.disable_gpu()
else: else:
...@@ -145,18 +169,18 @@ class LocalPredictor(object): ...@@ -145,18 +169,18 @@ class LocalPredictor(object):
min_subgraph_size=3, min_subgraph_size=3,
use_static=False, use_static=False,
use_calib_mode=False) use_calib_mode=False)
# set lite
if use_lite: if use_lite:
config.enable_lite_engine( config.enable_lite_engine(
precision_mode=precision_type, precision_mode=precision_type,
zero_copy=True, zero_copy=True,
passes_filter=[], passes_filter=[],
ops_filter=[]) ops_filter=[])
# set xpu
if use_xpu: if use_xpu:
# 2MB l3 cache # 2MB l3 cache
config.enable_xpu(8 * 1024 * 1024) config.enable_xpu(8 * 1024 * 1024)
# set cpu low precision
if not use_gpu and not use_lite: if not use_gpu and not use_lite:
if precision_type == paddle_infer.PrecisionType.Int8: if precision_type == paddle_infer.PrecisionType.Int8:
logger.warning( logger.warning(
...@@ -165,6 +189,9 @@ class LocalPredictor(object): ...@@ -165,6 +189,9 @@ class LocalPredictor(object):
#config.enable_quantizer() #config.enable_quantizer()
if precision is not None and precision.lower() == "bf16": if precision is not None and precision.lower() == "bf16":
config.enable_mkldnn_bfloat16() config.enable_mkldnn_bfloat16()
if mkldnn_bf16_op_list is not None:
config.set_bfloat16_op(mkldnn_bf16_op_list)
self.predictor = paddle_infer.create_predictor(config) self.predictor = paddle_infer.create_predictor(config)
def predict(self, feed=None, fetch=None, batch=False, log_id=0): def predict(self, feed=None, fetch=None, batch=False, log_id=0):
......
...@@ -45,7 +45,11 @@ class LocalServiceHandler(object): ...@@ -45,7 +45,11 @@ class LocalServiceHandler(object):
ir_optim=False, ir_optim=False,
available_port_generator=None, available_port_generator=None,
use_profile=False, use_profile=False,
precision="fp32"): precision="fp32",
use_mkldnn=False,
mkldnn_cache_capacity=0,
mkldnn_op_list=None,
mkldnn_bf16_op_list=None):
""" """
Initialization of localservicehandler Initialization of localservicehandler
...@@ -64,6 +68,10 @@ class LocalServiceHandler(object): ...@@ -64,6 +68,10 @@ class LocalServiceHandler(object):
available_port_generator: generate available ports available_port_generator: generate available ports
use_profile: use profiling, False default. use_profile: use profiling, False default.
precision: inference precesion, e.g. "fp32", "fp16", "int8" precision: inference precesion, e.g. "fp32", "fp16", "int8"
use_mkldnn: use mkldnn, default False.
mkldnn_cache_capacity: cache capacity of mkldnn, 0 means no limit.
mkldnn_op_list: OP list optimized by mkldnn, None default.
mkldnn_bf16_op_list: OP list optimized by mkldnn bf16, None default.
Returns: Returns:
None None
...@@ -78,6 +86,10 @@ class LocalServiceHandler(object): ...@@ -78,6 +86,10 @@ class LocalServiceHandler(object):
self._use_trt = False self._use_trt = False
self._use_lite = False self._use_lite = False
self._use_xpu = False self._use_xpu = False
self._use_mkldnn = False
self._mkldnn_cache_capacity = 0
self._mkldnn_op_list = None
self._mkldnn_bf16_op_list = None
if device_type == -1: if device_type == -1:
# device_type is not set, determined by `devices`, # device_type is not set, determined by `devices`,
...@@ -140,16 +152,24 @@ class LocalServiceHandler(object): ...@@ -140,16 +152,24 @@ class LocalServiceHandler(object):
self._use_profile = use_profile self._use_profile = use_profile
self._fetch_names = fetch_names self._fetch_names = fetch_names
self._precision = precision self._precision = precision
self._use_mkldnn = use_mkldnn
self._mkldnn_cache_capacity = mkldnn_cache_capacity
self._mkldnn_op_list = mkldnn_op_list
self._mkldnn_bf16_op_list = mkldnn_bf16_op_list
_LOGGER.info( _LOGGER.info(
"Models({}) will be launched by device {}. use_gpu:{}, " "Models({}) will be launched by device {}. use_gpu:{}, "
"use_trt:{}, use_lite:{}, use_xpu:{}, device_type:{}, devices:{}, " "use_trt:{}, use_lite:{}, use_xpu:{}, device_type:{}, devices:{}, "
"mem_optim:{}, ir_optim:{}, use_profile:{}, thread_num:{}, " "mem_optim:{}, ir_optim:{}, use_profile:{}, thread_num:{}, "
"client_type:{}, fetch_names:{} precision:{}".format( "client_type:{}, fetch_names:{}, precision:{}, use_mkldnn:{}, "
"mkldnn_cache_capacity:{}, mkldnn_op_list:{}, "
"mkldnn_bf16_op_list:{}".format(
model_config, self._device_name, self._use_gpu, self._use_trt, model_config, self._device_name, self._use_gpu, self._use_trt,
self._use_lite, self._use_xpu, device_type, self._devices, self. self._use_lite, self._use_xpu, device_type, self._devices,
_mem_optim, self._ir_optim, self._use_profile, self._thread_num, self._mem_optim, self._ir_optim, self._use_profile,
self._client_type, self._fetch_names, self._precision)) self._thread_num, self._client_type, self._fetch_names,
self._precision, self._use_mkldnn, self._mkldnn_cache_capacity,
self._mkldnn_op_list, self._mkldnn_bf16_op_list))
def get_fetch_list(self): def get_fetch_list(self):
return self._fetch_names return self._fetch_names
...@@ -189,7 +209,7 @@ class LocalServiceHandler(object): ...@@ -189,7 +209,7 @@ class LocalServiceHandler(object):
from paddle_serving_app.local_predict import LocalPredictor from paddle_serving_app.local_predict import LocalPredictor
if self._local_predictor_client is None: if self._local_predictor_client is None:
self._local_predictor_client = LocalPredictor() self._local_predictor_client = LocalPredictor()
# load model config and init predictor
self._local_predictor_client.load_model_config( self._local_predictor_client.load_model_config(
model_path=self._model_config, model_path=self._model_config,
use_gpu=self._use_gpu, use_gpu=self._use_gpu,
...@@ -201,7 +221,11 @@ class LocalServiceHandler(object): ...@@ -201,7 +221,11 @@ class LocalServiceHandler(object):
use_trt=self._use_trt, use_trt=self._use_trt,
use_lite=self._use_lite, use_lite=self._use_lite,
use_xpu=self._use_xpu, use_xpu=self._use_xpu,
precision=self._precision) precision=self._precision,
use_mkldnn=self._use_mkldnn,
mkldnn_cache_capacity=self._mkldnn_cache_capacity,
mkldnn_op_list=self._mkldnn_op_list,
mkldnn_bf16_op_list=self._mkldnn_bf16_op_list)
return self._local_predictor_client return self._local_predictor_client
def get_client_config(self): def get_client_config(self):
......
...@@ -139,6 +139,11 @@ class Op(object): ...@@ -139,6 +139,11 @@ class Op(object):
self.mem_optim = False self.mem_optim = False
self.ir_optim = False self.ir_optim = False
self.precision = "fp32" self.precision = "fp32"
self.use_mkldnn = False
self.mkldnn_cache_capacity = 0
self.mkldnn_op_list = None
self.mkldnn_bf16_op_list = None
if self._server_endpoints is None: if self._server_endpoints is None:
server_endpoints = conf.get("server_endpoints", []) server_endpoints = conf.get("server_endpoints", [])
if len(server_endpoints) != 0: if len(server_endpoints) != 0:
...@@ -161,6 +166,14 @@ class Op(object): ...@@ -161,6 +166,14 @@ class Op(object):
self.ir_optim = local_service_conf.get("ir_optim") self.ir_optim = local_service_conf.get("ir_optim")
self._fetch_names = local_service_conf.get("fetch_list") self._fetch_names = local_service_conf.get("fetch_list")
self.precision = local_service_conf.get("precision") self.precision = local_service_conf.get("precision")
self.use_mkldnn = local_service_conf.get("use_mkldnn")
self.mkldnn_cache_capacity = local_service_conf.get(
"mkldnn_cache_capacity")
self.mkldnn_op_list = local_service_conf.get(
"mkldnn_op_list")
self.mkldnn_bf16_op_list = local_service_conf.get(
"mkldnn_bf16_op_list")
if self.model_config is None: if self.model_config is None:
self.with_serving = False self.with_serving = False
else: else:
...@@ -176,7 +189,12 @@ class Op(object): ...@@ -176,7 +189,12 @@ class Op(object):
devices=self.devices, devices=self.devices,
mem_optim=self.mem_optim, mem_optim=self.mem_optim,
ir_optim=self.ir_optim, ir_optim=self.ir_optim,
precision=self.precision) precision=self.precision,
use_mkldnn=self.use_mkldnn,
mkldnn_cache_capacity=self.
mkldnn_cache_capacity,
mkldnn_op_list=self.mkldnn_bf16_op_list,
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list)
service_handler.prepare_server() # get fetch_list service_handler.prepare_server() # get fetch_list
serivce_ports = service_handler.get_port_list() serivce_ports = service_handler.get_port_list()
self._server_endpoints = [ self._server_endpoints = [
...@@ -199,7 +217,12 @@ class Op(object): ...@@ -199,7 +217,12 @@ class Op(object):
fetch_names=self._fetch_names, fetch_names=self._fetch_names,
mem_optim=self.mem_optim, mem_optim=self.mem_optim,
ir_optim=self.ir_optim, ir_optim=self.ir_optim,
precision=self.precision) precision=self.precision,
use_mkldnn=self.use_mkldnn,
mkldnn_cache_capacity=self.
mkldnn_cache_capacity,
mkldnn_op_list=self.mkldnn_op_list,
mkldnn_bf16_op_list=self.mkldnn_bf16_op_list)
if self._client_config is None: if self._client_config is None:
self._client_config = service_handler.get_client_config( self._client_config = service_handler.get_client_config(
) )
...@@ -564,7 +587,9 @@ class Op(object): ...@@ -564,7 +587,9 @@ class Op(object):
self._get_output_channels(), False, trace_buffer, self._get_output_channels(), False, trace_buffer,
self.model_config, self.workdir, self.thread_num, self.model_config, self.workdir, self.thread_num,
self.device_type, self.devices, self.mem_optim, self.device_type, self.devices, self.mem_optim,
self.ir_optim, self.precision)) self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list))
p.daemon = True p.daemon = True
p.start() p.start()
process.append(p) process.append(p)
...@@ -598,7 +623,9 @@ class Op(object): ...@@ -598,7 +623,9 @@ class Op(object):
self._get_output_channels(), True, trace_buffer, self._get_output_channels(), True, trace_buffer,
self.model_config, self.workdir, self.thread_num, self.model_config, self.workdir, self.thread_num,
self.device_type, self.devices, self.mem_optim, self.device_type, self.devices, self.mem_optim,
self.ir_optim, self.precision)) self.ir_optim, self.precision, self.use_mkldnn,
self.mkldnn_cache_capacity, self.mkldnn_op_list,
self.mkldnn_bf16_op_list))
# When a process exits, it attempts to terminate # When a process exits, it attempts to terminate
# all of its daemonic child processes. # all of its daemonic child processes.
t.daemon = True t.daemon = True
...@@ -1068,7 +1095,8 @@ class Op(object): ...@@ -1068,7 +1095,8 @@ class Op(object):
def _run(self, concurrency_idx, input_channel, output_channels, def _run(self, concurrency_idx, input_channel, output_channels,
is_thread_op, trace_buffer, model_config, workdir, thread_num, is_thread_op, trace_buffer, model_config, workdir, thread_num,
device_type, devices, mem_optim, ir_optim, precision): device_type, devices, mem_optim, ir_optim, precision, use_mkldnn,
mkldnn_cache_capacity, mkldnn_op_list, mkldnn_bf16_op_list):
""" """
_run() is the entry function of OP process / thread model.When client _run() is the entry function of OP process / thread model.When client
type is local_predictor in process mode, the CUDA environment needs to type is local_predictor in process mode, the CUDA environment needs to
...@@ -1090,7 +1118,11 @@ class Op(object): ...@@ -1090,7 +1118,11 @@ class Op(object):
devices: gpu id list[gpu], "" default[cpu] devices: gpu id list[gpu], "" default[cpu]
mem_optim: use memory/graphics memory optimization, True default. mem_optim: use memory/graphics memory optimization, True default.
ir_optim: use calculation chart optimization, False default. ir_optim: use calculation chart optimization, False default.
precision: inference precision, e.g. "fp32", "fp16", "int8" precision: inference precision, e.g. "fp32", "fp16", "int8", "bf16"
use_mkldnn: use mkldnn, default False.
mkldnn_cache_capacity: cache capacity of mkldnn, 0 means no limit.
mkldnn_op_list: OP list optimized by mkldnn, None default.
mkldnn_bf16_op_list: OP list optimized by mkldnn bf16, None default.
Returns: Returns:
None None
...@@ -1110,7 +1142,11 @@ class Op(object): ...@@ -1110,7 +1142,11 @@ class Op(object):
devices=devices, devices=devices,
mem_optim=mem_optim, mem_optim=mem_optim,
ir_optim=ir_optim, ir_optim=ir_optim,
precision=precision) precision=precision,
use_mkldnn=use_mkldnn,
mkldnn_cache_capacity=mkldnn_cache_capacity,
mkldnn_op_list=mkldnn_op_list,
mkldnn_bf16_op_list=mkldnn_bf16_op_list)
_LOGGER.info("Init cuda env in process {}".format( _LOGGER.info("Init cuda env in process {}".format(
concurrency_idx)) concurrency_idx))
......
...@@ -239,6 +239,8 @@ class PipelineServer(object): ...@@ -239,6 +239,8 @@ class PipelineServer(object):
"ir_optim": False, "ir_optim": False,
"precision": "fp32", "precision": "fp32",
"use_calib": False, "use_calib": False,
"use_mkldnn": False,
"mkldnn_cache_capacity": 0,
}, },
} }
for op in self._used_op: for op in self._used_op:
...@@ -397,6 +399,8 @@ class ServerYamlConfChecker(object): ...@@ -397,6 +399,8 @@ class ServerYamlConfChecker(object):
"ir_optim": False, "ir_optim": False,
"precision": "fp32", "precision": "fp32",
"use_calib": False, "use_calib": False,
"use_mkldnn": False,
"mkldnn_cache_capacity": 0,
} }
conf_type = { conf_type = {
"model_config": str, "model_config": str,
...@@ -408,6 +412,10 @@ class ServerYamlConfChecker(object): ...@@ -408,6 +412,10 @@ class ServerYamlConfChecker(object):
"ir_optim": bool, "ir_optim": bool,
"precision": str, "precision": str,
"use_calib": bool, "use_calib": bool,
"use_mkldnn": bool,
"mkldnn_cache_capacity": int,
"mkldnn_op_list": list,
"mkldnn_bf16_op_list": list,
} }
conf_qualification = {"thread_num": (">=", 1), } conf_qualification = {"thread_num": (">=", 1), }
ServerYamlConfChecker.check_conf(conf, default_conf, conf_type, ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册