提交 36508885 编写于 作者: T TeslaZhao

Predicting on multi-devices in local predictor

上级 4d6feecc
...@@ -105,18 +105,33 @@ class LocalServiceHandler(object): ...@@ -105,18 +105,33 @@ class LocalServiceHandler(object):
def get_port_list(self): def get_port_list(self):
return self._port_list return self._port_list
def get_client(self): def get_client(self, concurrency_idx):
""" """
Function get_client is only used for local predictor case, creates one Function get_client is only used for local predictor case, creates one
LocalPredictor object, and initializes the paddle predictor by function LocalPredictor object, and initializes the paddle predictor by function
load_model_config. load_model_config. The concurrency_idx is used to select running devices.
Args: Args:
None concurrency_idx: process/thread index
Returns: Returns:
_local_predictor_client _local_predictor_client
""" """
#checking the legality of concurrency_idx.
device_num = len(self._devices)
if device_num <= 0:
        _LOGGER.error("device_num must be greater than 0. devices({})".
                      format(self._devices))
        raise ValueError("The number of self._devices must be greater than 0")
if concurrency_idx <= 0:
        _LOGGER.error("concurrency_idx({}) must be a positive number".
                      format(concurrency_idx))
concurrency_idx = 0
elif concurrency_idx >= device_num:
concurrency_idx = concurrency_idx % device_num
from paddle_serving_app.local_predict import LocalPredictor from paddle_serving_app.local_predict import LocalPredictor
if self._local_predictor_client is None: if self._local_predictor_client is None:
self._local_predictor_client = LocalPredictor() self._local_predictor_client = LocalPredictor()
...@@ -126,7 +141,7 @@ class LocalServiceHandler(object): ...@@ -126,7 +141,7 @@ class LocalServiceHandler(object):
self._local_predictor_client.load_model_config( self._local_predictor_client.load_model_config(
model_path=self._model_config, model_path=self._model_config,
use_gpu=use_gpu, use_gpu=use_gpu,
gpu_id=self._devices[0], gpu_id=self._devices[concurrency_idx],
use_profile=self._use_profile, use_profile=self._use_profile,
thread_num=self._thread_num, thread_num=self._thread_num,
mem_optim=self._mem_optim, mem_optim=self._mem_optim,
......
...@@ -574,7 +574,7 @@ class Op(object): ...@@ -574,7 +574,7 @@ class Op(object):
#Init cuda env in main thread #Init cuda env in main thread
if self.client_type == "local_predictor": if self.client_type == "local_predictor":
_LOGGER.info("Init cuda env in main thread") _LOGGER.info("Init cuda env in main thread")
self.local_predictor = self._local_service_handler.get_client() self.local_predictor = self._local_service_handler.get_client(0)
threads = [] threads = []
for concurrency_idx in range(self.concurrency): for concurrency_idx in range(self.concurrency):
...@@ -1034,7 +1034,8 @@ class Op(object): ...@@ -1034,7 +1034,8 @@ class Op(object):
_LOGGER.info("Init cuda env in process {}".format( _LOGGER.info("Init cuda env in process {}".format(
concurrency_idx)) concurrency_idx))
self.local_predictor = self.service_handler.get_client() self.local_predictor = self.service_handler.get_client(
concurrency_idx)
# check all ops initialized successfully. # check all ops initialized successfully.
profiler = self._initialize(is_thread_op, concurrency_idx) profiler = self._initialize(is_thread_op, concurrency_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册