Commit 3c525182 authored by zhangjun

add local predict support

Parent 3378a130
@@ -57,6 +57,8 @@ class LocalPredictor(object):
mem_optim=True,
ir_optim=False,
use_trt=False,
use_lite=False,
use_xpu=False,
use_feed_fetch_ops=False):
"""
Load model config and set the engine config for the paddle predictor
@@ -70,6 +72,8 @@ class LocalPredictor(object):
mem_optim: memory optimization, True default.
            ir_optim: enable computation graph optimization, False default.
            use_trt: use nvidia TensorRT optimization, False default.
            use_lite: use Paddle-Lite engine, False default.
            use_xpu: run predict on Baidu Kunlun, False default.
use_feed_fetch_ops: use feed/fetch ops, False default.
"""
client_config = "{}/serving_server_conf.prototxt".format(model_path)
@@ -80,9 +84,9 @@ class LocalPredictor(object):
config = AnalysisConfig(model_path)
logger.info("load_model_config params: model_path:{}, use_gpu:{},\
gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
use_trt:{}, use_feed_fetch_ops:{}".format(
            use_trt:{}, use_lite:{}, use_xpu:{}, use_feed_fetch_ops:{}".format(
model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
ir_optim, use_trt, use_feed_fetch_ops))
ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
@@ -119,6 +123,17 @@ class LocalPredictor(object):
use_static=False,
use_calib_mode=False)
        if use_lite:
            config.enable_lite_engine(
                precision_mode=PrecisionType.Float32,
                zero_copy=True,
                passes_filter=[],
                ops_filter=[])
if use_xpu:
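            # assumed semantics: enable_xpu()'s argument is the XPU L3 workspace
            # size in bytes, so this reserves roughly 100 MB on the Kunlun card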
config.enable_xpu(100 * 1024 * 1024)
self.predictor = create_paddle_predictor(config)
def predict(self, feed=None, fetch=None, batch=False, log_id=0):
......
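For orientation, here is a minimal usage sketch of the two new flags as seen from user code. The `LocalPredictor` import path follows this repository's `paddle_serving_app` package, but the model directory, feed/fetch names, and input shape are illustrative assumptions, not part of this commit:

```python
# Sketch only. Assumes a local model directory "uci_housing_model" whose
# serving config declares feed var "x" and fetch var "price".
import numpy as np
from paddle_serving_app.local_predict import LocalPredictor

predictor = LocalPredictor()
predictor.load_model_config(
    model_path="uci_housing_model",
    use_gpu=False,
    use_lite=True,  # run inference through the Paddle-Lite engine
    use_xpu=True)   # place the Lite engine on a Baidu Kunlun (XPU) card

# predict() takes a dict keyed by the model's feed var names; the
# signature (feed, fetch, batch, log_id) matches the diff above.
fetch_map = predictor.predict(
    feed={"x": np.random.rand(1, 13).astype("float32")},
    fetch=["price"])
print(fetch_map)
```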
@@ -44,6 +44,8 @@ class LocalServiceHandler(object):
ir_optim=False,
available_port_generator=None,
use_trt=False,
use_lite=False,
use_xpu=False,
use_profile=False):
"""
        Initialization of LocalServiceHandler
@@ -60,6 +62,8 @@ class LocalServiceHandler(object):
            ir_optim: use computation graph optimization, False default.
available_port_generator: generate available ports
            use_trt: use nvidia TensorRT engine, False default.
use_lite: use Paddle-Lite engine, False default.
use_xpu: run predict on Baidu Kunlun, False default.
use_profile: use profiling, False default.
Returns:
@@ -74,10 +78,16 @@ class LocalServiceHandler(object):
if devices == "":
# cpu
devices = [-1]
self._device_type = "cpu"
self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in cpu device. Port({})"
.format(model_config, self._port_list))
if use_lite:
self._device_type = "arm"
self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in arm device. Port({})"
.format(model_config, self._port_list))
else:
self._device_type = "cpu"
self._port_list.append(available_port_generator.next())
_LOGGER.info("Model({}) will be launch in cpu device. Port({})"
.format(model_config, self._port_list))
else:
# gpu
self._device_type = "gpu"
@@ -96,6 +106,8 @@ class LocalServiceHandler(object):
self._rpc_service_list = []
self._server_pros = []
self._use_trt = use_trt
self._use_lite = use_lite
self._use_xpu = use_xpu
self._use_profile = use_profile
self.fetch_names_ = fetch_names
@@ -138,8 +150,11 @@ class LocalServiceHandler(object):
if self._local_predictor_client is None:
self._local_predictor_client = LocalPredictor()
use_gpu = False
use_lite = False
if self._device_type == "gpu":
use_gpu = True
elif self._device_type == "arm":
use_lite = True
self._local_predictor_client.load_model_config(
model_path=self._model_config,
use_gpu=use_gpu,
@@ -148,7 +163,9 @@ class LocalServiceHandler(object):
thread_num=self._thread_num,
mem_optim=self._mem_optim,
ir_optim=self._ir_optim,
use_trt=self._use_trt)
use_trt=self._use_trt,
use_lite=use_lite,
use_xpu=self._use_xpu)
return self._local_predictor_client
def get_client_config(self):
@@ -185,7 +202,7 @@ class LocalServiceHandler(object):
server = Server()
else:
#gpu
#gpu or arm
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
......
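A matching sketch for the pipeline side: with an empty `devices` string and `use_lite=True`, the handler takes the new "arm" branch and forwards `use_lite`/`use_xpu` into `LocalPredictor.load_model_config` as wired up above. The import path and the `get_client()` accessor name are assumptions about the surrounding module; the model directory is illustrative:

```python
# Sketch only; constructor arguments mirror the parameters in the diff.
from paddle_serving_server.pipeline.local_service_handler import (
    LocalServiceHandler)  # assumed import path

handler = LocalServiceHandler(
    model_config="uci_housing_model",  # illustrative model directory
    devices="",                        # empty string => non-GPU branch
    use_lite=True,                     # device type is labeled "arm"
    use_xpu=True)                      # forwarded to load_model_config

# Returns a LocalPredictor already loaded with use_lite/use_xpu set.
client = handler.get_client()
```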