From 36508885e8eac7e675e23de64f3f4360c72c5624 Mon Sep 17 00:00:00 2001 From: TeslaZhao Date: Mon, 7 Dec 2020 17:24:12 +0800 Subject: [PATCH] Predicting on multi-devices in local predictor --- python/pipeline/local_service_handler.py | 23 +++++++++++++++++++---- python/pipeline/operator.py | 5 +++-- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/pipeline/local_service_handler.py b/python/pipeline/local_service_handler.py index 26450e28..4682cd65 100644 --- a/python/pipeline/local_service_handler.py +++ b/python/pipeline/local_service_handler.py @@ -105,18 +105,33 @@ class LocalServiceHandler(object): def get_port_list(self): return self._port_list - def get_client(self): + def get_client(self, concurrency_idx): """ Function get_client is only used for local predictor case, creates one LocalPredictor object, and initializes the paddle predictor by function - load_model_config. + load_model_config.The concurrency_idx is used to select running devices. Args: - None + concurrency_idx: process/thread index Returns: _local_predictor_client """ + + #checking the legality of concurrency_idx. + device_num = len(self._devices) + if device_num <= 0: + _LOGGER.error("device_num must be not greater than 0. devices({})". + format(self._devices)) + raise ValueError("The number of self._devices error") + + if concurrency_idx <= 0: + _LOGGER.error("concurrency_idx({}) must be one positive number". + format(concurrency_idx)) + concurrency_idx = 0 + elif concurrency_idx >= device_num: + concurrency_idx = concurrency_idx % device_num + from paddle_serving_app.local_predict import LocalPredictor if self._local_predictor_client is None: self._local_predictor_client = LocalPredictor() @@ -126,7 +141,7 @@ class LocalServiceHandler(object): self._local_predictor_client.load_model_config( model_path=self._model_config, use_gpu=use_gpu, - gpu_id=self._devices[0], + gpu_id=self._devices[concurrency_idx], use_profile=self._use_profile, thread_num=self._thread_num, mem_optim=self._mem_optim, diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 92e0c0c6..45763da9 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -574,7 +574,7 @@ class Op(object): #Init cuda env in main thread if self.client_type == "local_predictor": _LOGGER.info("Init cuda env in main thread") - self.local_predictor = self._local_service_handler.get_client() + self.local_predictor = self._local_service_handler.get_client(0) threads = [] for concurrency_idx in range(self.concurrency): @@ -1034,7 +1034,8 @@ class Op(object): _LOGGER.info("Init cuda env in process {}".format( concurrency_idx)) - self.local_predictor = self.service_handler.get_client() + self.local_predictor = self.service_handler.get_client( + concurrency_idx) # check all ops initialized successfully. profiler = self._initialize(is_thread_op, concurrency_idx) -- GitLab