提交 36508885 编写于 作者: T TeslaZhao

Predicting on multi-devices in local predictor

上级 4d6feecc
......@@ -105,18 +105,33 @@ class LocalServiceHandler(object):
# Return the port list stored on this handler (self._port_list).
def get_port_list(self):
return self._port_list
def get_client(self):
def get_client(self, concurrency_idx):
"""
Function get_client is only used for local predictor case, creates one
LocalPredictor object, and initializes the paddle predictor by function
load_model_config.
load_model_config. The concurrency_idx is used to select the running device.
Args:
None
concurrency_idx: process/thread index
Returns:
_local_predictor_client
"""
#checking the legality of concurrency_idx.
device_num = len(self._devices)
if device_num <= 0:
_LOGGER.error("device_num must be not greater than 0. devices({})".
format(self._devices))
raise ValueError("The number of self._devices error")
if concurrency_idx <= 0:
_LOGGER.error("concurrency_idx({}) must be one positive number".
format(concurrency_idx))
concurrency_idx = 0
elif concurrency_idx >= device_num:
concurrency_idx = concurrency_idx % device_num
from paddle_serving_app.local_predict import LocalPredictor
if self._local_predictor_client is None:
self._local_predictor_client = LocalPredictor()
......@@ -126,7 +141,7 @@ class LocalServiceHandler(object):
self._local_predictor_client.load_model_config(
model_path=self._model_config,
use_gpu=use_gpu,
gpu_id=self._devices[0],
gpu_id=self._devices[concurrency_idx],
use_profile=self._use_profile,
thread_num=self._thread_num,
mem_optim=self._mem_optim,
......
......@@ -574,7 +574,7 @@ class Op(object):
#Init cuda env in main thread
if self.client_type == "local_predictor":
_LOGGER.info("Init cuda env in main thread")
self.local_predictor = self._local_service_handler.get_client()
self.local_predictor = self._local_service_handler.get_client(0)
threads = []
for concurrency_idx in range(self.concurrency):
......@@ -1034,7 +1034,8 @@ class Op(object):
_LOGGER.info("Init cuda env in process {}".format(
concurrency_idx))
self.local_predictor = self.service_handler.get_client()
self.local_predictor = self.service_handler.get_client(
concurrency_idx)
# check all ops initialized successfully.
profiler = self._initialize(is_thread_op, concurrency_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册