diff --git a/python/pipeline/local_service_handler.py b/python/pipeline/local_service_handler.py
index 26450e28dcbf4f99a7149d3005faf8f93abc63f8..4682cd65a6e90bb17747ecb76d119deabcefda73 100644
--- a/python/pipeline/local_service_handler.py
+++ b/python/pipeline/local_service_handler.py
@@ -105,18 +105,33 @@ class LocalServiceHandler(object):
     def get_port_list(self):
         return self._port_list
 
-    def get_client(self):
+    def get_client(self, concurrency_idx):
         """
         Function get_client is only used for local predictor case, creates one
         LocalPredictor object, and initializes the paddle predictor by function
-        load_model_config.
+        load_model_config. The concurrency_idx is used to select running devices.
 
         Args:
-            None
+            concurrency_idx: process/thread index, selects one of self._devices
 
         Returns:
             _local_predictor_client
         """
+
+        # Check the legality of concurrency_idx before using it as an index.
+        device_num = len(self._devices)
+        if device_num <= 0:
+            _LOGGER.error("device_num must be greater than 0. devices({})".
+                          format(self._devices))
+            raise ValueError("The number of self._devices must be greater than 0")
+
+        if concurrency_idx < 0:
+            _LOGGER.error("concurrency_idx({}) must be a non-negative number".
+                          format(concurrency_idx))
+            concurrency_idx = 0
+        elif concurrency_idx >= device_num:
+            concurrency_idx = concurrency_idx % device_num
+
         from paddle_serving_app.local_predict import LocalPredictor
         if self._local_predictor_client is None:
             self._local_predictor_client = LocalPredictor()
@@ -126,7 +141,7 @@ class LocalServiceHandler(object):
             self._local_predictor_client.load_model_config(
                 model_path=self._model_config,
                 use_gpu=use_gpu,
-                gpu_id=self._devices[0],
+                gpu_id=self._devices[concurrency_idx],
                 use_profile=self._use_profile,
                 thread_num=self._thread_num,
                 mem_optim=self._mem_optim,
diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py
index 92e0c0c6e0bb2e415f48729d25c2153d2026b6b2..45763da995790964bfdc52311e183058d962fc64 100644
--- a/python/pipeline/operator.py
+++ b/python/pipeline/operator.py
@@ -574,7 +574,7 @@ class Op(object):
         #Init cuda env in main thread
         if self.client_type == "local_predictor":
             _LOGGER.info("Init cuda env in main thread")
-            self.local_predictor = self._local_service_handler.get_client()
+            self.local_predictor = self._local_service_handler.get_client(0)
 
         threads = []
         for concurrency_idx in range(self.concurrency):
@@ -1034,7 +1034,8 @@ class Op(object):
             _LOGGER.info("Init cuda env in process {}".format(
                 concurrency_idx))
-            self.local_predictor = self.service_handler.get_client()
+            self.local_predictor = self.service_handler.get_client(
+                concurrency_idx)
 
             # check all ops initialized successfully.
             profiler = self._initialize(is_thread_op, concurrency_idx)
 