未验证 提交 f7ee7b16 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #907 from TeslaZhao/develop

Predicting on multi-devices in local predictor
# IMDB model ensemble examples
## Get models
```
sh get_data.sh
```
## Start servers
```
python -m paddle_serving_server.serve --model imdb_cnn_model --port 9292 &> cnn.log &
python -m paddle_serving_server.serve --model imdb_bow_model --port 9393 &> bow.log &
python test_pipeline_server.py &>pipeline.log &
```
## Start clients
```
python test_pipeline_client.py
```
......@@ -17,6 +17,7 @@ except ImportError:
from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import sys
_LOGGER = logging.getLogger()
......@@ -31,11 +32,18 @@ class UciOp(Op):
log_id, input_dict))
x_value = input_dict["x"]
proc_dict = {}
if isinstance(x_value, (str, unicode)):
input_dict["x"] = np.array(
[float(x.strip())
for x in x_value.split(self.separator)]).reshape(1, 13)
_LOGGER.error("input_dict:{}".format(input_dict))
if sys.version_info.major == 2:
if isinstance(x_value, (str, unicode)):
input_dict["x"] = np.array(
[float(x.strip())
for x in x_value.split(self.separator)]).reshape(1, 13)
_LOGGER.error("input_dict:{}".format(input_dict))
else:
if isinstance(x_value, str):
input_dict["x"] = np.array(
[float(x.strip())
for x in x_value.split(self.separator)]).reshape(1, 13)
_LOGGER.error("input_dict:{}".format(input_dict))
return input_dict, False, None, ""
......
......@@ -312,7 +312,7 @@ class OpAnalyst(object):
# reduce op times
op_times = {
op_name: sum(step_times.values())
op_name: sum(list(step_times.values()))
for op_name, step_times in op_times.items()
}
......
......@@ -429,9 +429,12 @@ class ProcessChannel(object):
self._cv.wait()
if self._stop.value == 1:
raise ChannelStopError()
_LOGGER.debug(
self._log("(data_id={} log_id={}) Op({}) Got data".format(
resp.values()[0].id, resp.values()[0].log_id, op_name)))
if resp is not None:
list_values = list(resp.values())
_LOGGER.debug(
self._log("(data_id={} log_id={}) Op({}) Got data".format(
list_values[0].id, list_values[0].log_id, op_name)))
return resp
elif op_name is None:
_LOGGER.critical(
......@@ -458,11 +461,12 @@ class ProcessChannel(object):
try:
channeldata = self._que.get(timeout=0)
self._output_buf.append(channeldata)
list_values = list(channeldata.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Pop ready item into output_buffer".
format(channeldata.values()[0].id,
channeldata.values()[0].log_id, op_name)))
format(list_values[0].id, list_values[0].log_id,
op_name)))
break
except Queue.Empty:
if timeout is not None:
......@@ -513,10 +517,12 @@ class ProcessChannel(object):
self._cv.notify_all()
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Got data from output_buffer".
format(resp.values()[0].id, resp.values()[0].log_id, op_name)))
if resp is not None:
list_values = list(resp.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Got data from output_buffer".
format(list_values[0].id, list_values[0].log_id, op_name)))
return resp
def stop(self):
......@@ -726,9 +732,11 @@ class ThreadChannel(Queue.PriorityQueue):
self._cv.wait()
if self._stop:
raise ChannelStopError()
_LOGGER.debug(
self._log("(data_id={} log_id={}) Op({}) Got data".format(
resp.values()[0].id, resp.values()[0].log_id, op_name)))
if resp is not None:
list_values = list(resp.values())
_LOGGER.debug(
self._log("(data_id={} log_id={}) Op({}) Got data".format(
list_values[0].id, list_values[0].log_id, op_name)))
return resp
elif op_name is None:
_LOGGER.critical(
......@@ -755,11 +763,12 @@ class ThreadChannel(Queue.PriorityQueue):
try:
channeldata = self.get(timeout=0)
self._output_buf.append(channeldata)
list_values = list(channeldata.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Pop ready item into output_buffer".
format(channeldata.values()[0].id,
channeldata.values()[0].log_id, op_name)))
format(list_values[0].id, list_values[0].log_id,
op_name)))
break
except Queue.Empty:
if timeout is not None:
......@@ -810,10 +819,12 @@ class ThreadChannel(Queue.PriorityQueue):
self._cv.notify_all()
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Got data from output_buffer".
format(resp.values()[0].id, resp.values()[0].log_id, op_name)))
if resp is not None:
list_values = list(resp.values())
_LOGGER.debug(
self._log(
"(data_id={} log_id={}) Op({}) Got data from output_buffer".
format(list_values[0].id, list_values[0].log_id, op_name)))
return resp
def stop(self):
......
......@@ -105,18 +105,35 @@ class LocalServiceHandler(object):
def get_port_list(self):
return self._port_list
def get_client(self):
def get_client(self, concurrency_idx):
"""
Function get_client is only used for local predictor case, creates one
LocalPredictor object, and initializes the paddle predictor by function
load_model_config.
load_model_config.The concurrency_idx is used to select running devices.
Args:
None
concurrency_idx: process/thread index
Returns:
_local_predictor_client
"""
#checking the legality of concurrency_idx.
device_num = len(self._devices)
if device_num <= 0:
_LOGGER.error("device_num must be not greater than 0. devices({})".
format(self._devices))
raise ValueError("The number of self._devices error")
if concurrency_idx < 0:
_LOGGER.error("concurrency_idx({}) must be one positive number".
format(concurrency_idx))
concurrency_idx = 0
elif concurrency_idx >= device_num:
concurrency_idx = concurrency_idx % device_num
_LOGGER.info("GET_CLIENT : concurrency_idx={}, device_num={}".format(
concurrency_idx, device_num))
from paddle_serving_app.local_predict import LocalPredictor
if self._local_predictor_client is None:
self._local_predictor_client = LocalPredictor()
......@@ -126,7 +143,7 @@ class LocalServiceHandler(object):
self._local_predictor_client.load_model_config(
model_path=self._model_config,
use_gpu=use_gpu,
gpu_id=self._devices[0],
gpu_id=self._devices[concurrency_idx],
use_profile=self._use_profile,
thread_num=self._thread_num,
mem_optim=self._mem_optim,
......
......@@ -55,7 +55,7 @@ class Op(object):
client_type=None,
concurrency=None,
timeout=None,
retry=None,
retry=0,
batch_size=None,
auto_batching_timeout=None,
local_service_handler=None):
......@@ -574,7 +574,7 @@ class Op(object):
#Init cuda env in main thread
if self.client_type == "local_predictor":
_LOGGER.info("Init cuda env in main thread")
self.local_predictor = self._local_service_handler.get_client()
self.local_predictor = self._local_service_handler.get_client(0)
threads = []
for concurrency_idx in range(self.concurrency):
......@@ -679,7 +679,7 @@ class Op(object):
err_channeldata_dict = collections.OrderedDict()
### if (batch_num == 1 && skip == True) ,then skip the process stage.
is_skip_process = False
data_ids = preped_data_dict.keys()
data_ids = list(preped_data_dict.keys())
if len(data_ids) == 1 and skip_process_dict.get(data_ids[0]) == True:
is_skip_process = True
_LOGGER.info("(data_id={} log_id={}) skip process stage".format(
......@@ -1034,7 +1034,8 @@ class Op(object):
_LOGGER.info("Init cuda env in process {}".format(
concurrency_idx))
self.local_predictor = self.service_handler.get_client()
self.local_predictor = self.service_handler.get_client(
concurrency_idx)
# check all ops initialized successfully.
profiler = self._initialize(is_thread_op, concurrency_idx)
......
......@@ -53,10 +53,10 @@ class PipelineClient(object):
if logid is None:
req.logid = 0
else:
if six.PY2:
if sys.version_info.major == 2:
req.logid = long(logid)
elif six.PY3:
req.logid = int(log_id)
elif sys.version_info.major == 3:
req.logid = int(logid)
feed_dict.pop("logid")
clientip = feed_dict.get("clientip")
......@@ -71,10 +71,15 @@ class PipelineClient(object):
np.set_printoptions(threshold=sys.maxsize)
for key, value in feed_dict.items():
req.key.append(key)
if (sys.version_info.major == 2 and isinstance(value,
(str, unicode)) or
((sys.version_info.major == 3) and isinstance(value, str))):
req.value.append(value)
continue
if isinstance(value, np.ndarray):
req.value.append(value.__repr__())
elif isinstance(value, (str, unicode)):
req.value.append(value)
elif isinstance(value, list):
req.value.append(np.array(value).__repr__())
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册