提交 af15850c 编写于 作者: B barrierye

support different client requests

上级 a8635745
......@@ -18,9 +18,15 @@ from pyserver import Channel
from pyserver import PyServer
import numpy as np
import python_service_channel_pb2
import logging
# Configure the root logger for this example script: each record shows a
# minute-resolution timestamp, the level name left-aligned in 8 columns, the
# emitting file and line number, and the message. Records below INFO are dropped.
logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M',
level=logging.INFO)
# channel data: {name(str): data(bytes)}
"""
class ImdbOp(Op):
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
......@@ -28,36 +34,44 @@ class ImdbOp(Op):
pred = np.array(output_data["prediction"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
inst.id = 0 #TODO
data.insts.append(inst)
return data
"""
class CombineOp(Op):
    """Op that folds the results of several upstream channels into one value."""

    def preprocess(self, input_data):
        """Sum the upstream predictions and pack the total as one "resp" inst.

        Args:
            input_data: iterable of upstream batches; only element 0 of each
                batch is read (batchsize=1), and only its first inst.

        Returns:
            A ``ChannelData`` whose id is the id shared by all upstream
            batches and whose single inst, named "resp", holds the bytes of
            the element-wise sum of the upstream float buffers.

        Raises:
            Exception: if the upstream batches do not all carry the same id.
        """
        data_id = None
        cnt = 0
        # `upstream_batch` rather than `input`, so the builtin is not shadowed.
        for upstream_batch in input_data:
            upstream = upstream_batch[0]  # batchsize=1
            cnt += np.frombuffer(upstream.insts[0].data, dtype='float')
            if data_id is None:
                data_id = upstream.id
            # All inputs being combined must belong to the same request.
            if data_id != upstream.id:
                raise Exception("id not match: {} vs {}".format(data_id,
                                                                upstream.id))
        data = python_service_channel_pb2.ChannelData()
        inst = python_service_channel_pb2.Inst()
        inst.data = np.ndarray.tobytes(cnt)
        inst.name = "resp"
        inst.id = 0  #TODO
        data.insts.append(inst)
        data.id = data_id
        # Diagnostic dump goes through logging (configured at the top of this
        # script) instead of an unconditional print to stdout.
        logging.debug(data)
        return data
class UciOp(Op):
    """Op whose postprocess step wraps the fetched "price" as channel data."""

    def postprocess(self, output_data):
        """Pack ``output_data["price"][0][0]`` into a one-inst ``ChannelData``.

        The result carries the first id recorded by ``get_data_ids()`` and a
        single inst named "prediction" holding the price as float bytes.
        """
        recorded_ids = self.get_data_ids()
        packed = python_service_channel_pb2.ChannelData()
        prediction = python_service_channel_pb2.Inst()

        price = np.array(output_data["price"][0][0], dtype='float')
        prediction.data = price.tobytes()
        prediction.name = "prediction"
        prediction.id = 0  #TODO

        packed.insts.append(prediction)
        packed.id = recorded_ids[0]
        return packed
......@@ -121,5 +135,5 @@ pyserver.add_channel(combine_out_channel)
pyserver.add_op(cnn_op)
pyserver.add_op(bow_op)
pyserver.add_op(combine_op)
# Prepare the server exactly once. The superseded worker_num=1 call is dropped
# so port 8080 is not prepared twice; worker_num=2 is presumably what lets
# requests from different clients be handled at the same time (TODO confirm
# against PyServer.prepare_server).
pyserver.prepare_server(port=8080, worker_num=2)
pyserver.run_server()
......@@ -24,6 +24,8 @@ import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
import python_service_channel_pb2
import logging
import time
class Channel(Queue.Queue):
......@@ -39,10 +41,12 @@ class Channel(Queue.Queue):
self._pushbatch = []
self._frontbatch = None
self._count = 0
self._order = 0
def push(self, item):
with self._pushlock:
self._pushbatch.append(item)
self._order += 1
if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout)
# self.put(self._pushbatch)
......@@ -87,6 +91,7 @@ class Op(object):
self._server_model = server_model
self._server_port = server_port
self._device = device
self._data_ids = []
def set_client(self, client_config, server_name, fetch_names):
self._client = Client()
......@@ -113,12 +118,22 @@ class Op(object):
raise TypeError('channels must be list type')
self._outputs = channels
def get_data_ids(self):
return self._data_ids
def clear_data_ids(self):
self._data_ids = []
def append_id_to_data_ids(self, data_id):
self._data_ids.append(data_id)
def preprocess(self, input_data):
if len(input_data) != 1:
raise Exception(
'this Op has multiple previous channels. Please override this method'
)
feed_batch = []
self.clear_data_ids()
for data in input_data:
if len(data.insts) != self._batch_size:
raise Exception('len(data_insts) != self._batch_size')
......@@ -126,12 +141,13 @@ class Op(object):
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype=self._in_dtype)
feed_batch.append(feed)
self.append_id_to_data_ids(data.id)
return feed_batch
def midprocess(self, data):
# data = preprocess(input), which must be a dict
print('data: {}'.format(data))
print('fetch: {}'.format(self._fetch_names))
logging.debug('data: {}'.format(data))
logging.debug('fetch: {}'.format(self._fetch_names))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
return fetch_map
......@@ -168,36 +184,80 @@ class GeneralPythonService(
super(GeneralPythonService, self).__init__()
self._in_channel = in_channel
self._out_channel = out_channel
print('succ init')
def inference(self, request, context):
print('start inferce')
self._lock = threading.Lock()
self._globel_resp_dict = {}
self._id_counter = 0
self._recive_func = threading.Thread(
target=GeneralPythonService._recive_out_channel_func, args=(self, ))
self._recive_func.start()
logging.debug('succ init')
def _recive_out_channel_func(self):
while True:
data = self._out_channel.front()
data_id = None
for d in data:
if data_id is None:
data_id = d.id
if data_id != d.id:
raise Exception("id not match: {} vs {}".format(data_id,
d.id))
with self._lock:
self._globel_resp_dict[data_id] = data
#TODO wake up inference
def _get_next_id(self):
with self._lock:
self._id_counter += 1
return self._id_counter - 1
def _get_data_in_globel_resp_dict(self, data_id):
if data_id in self._globel_resp_dict:
with self._lock:
return self._globel_resp_dict.pop(data_id)
return None
def _pack_data_for_infer(self, request):
    """Convert a client request into a ``ChannelData`` tagged with a new id.

    This method only builds the message. It must not push to the input
    channel or read the output channel: ``inference`` does the push, and
    the receiver thread is the reader of the output channel, so doing
    either here would submit the request twice and return a response
    batch in place of the packed request.

    Args:
        request: object exposing parallel ``feed_var_names`` and
            ``feed_insts`` sequences.

    Returns:
        Tuple ``(data, data_id)``: the packed ``ChannelData`` and the id
        assigned to it by ``_get_next_id``.
    """
    logging.debug('start inferce')
    data = python_service_channel_pb2.ChannelData()
    data_id = self._get_next_id()
    data.id = data_id
    for idx, name in enumerate(request.feed_var_names):
        feed_bytes = request.feed_insts[idx]
        logging.debug('name: %s', name)
        logging.debug('data: %s', feed_bytes)
        inst = python_service_channel_pb2.Inst()
        inst.data = feed_bytes
        inst.name = name
        inst.id = 0  #TODO
        data.insts.append(inst)
    return data, data_id
def _pack_data_for_resp(self, data):
    """Build the service ``Response`` from a finished output batch.

    Args:
        data: output batch; only its first element is used.

    Returns:
        A ``Response`` whose ``fetch_insts`` / ``fetch_var_names`` hold the
        data and name of every inst of that first element, in order.
    """
    channel_data = data[0]  #TODO batchsize = 1
    # Each step is reported once via logging (the duplicate prints to
    # stdout are gone).
    logging.debug('get data')
    resp = general_python_service_pb2.Response()
    logging.debug('gen resp')
    logging.debug(channel_data)
    for inst in channel_data.insts:
        logging.debug('append data')
        resp.fetch_insts.append(inst.data)
        logging.debug('append name')
        resp.fetch_var_names.append(inst.name)
    return resp
def inference(self, request, context):
data, data_id = self._pack_data_for_infer(request)
logging.debug('push data')
self._in_channel.push(data)
logging.debug('wait for infer')
resp_data = None
while True:
resp_data = self._get_data_in_globel_resp_dict(data_id)
if resp_data is not None:
break
time.sleep(0.05) #TODO: wake up by _recive_out_channel_func
resp = self._pack_data_for_resp(resp_data)
return resp
class PyServer(object):
def __init__(self):
......@@ -216,7 +276,7 @@ class PyServer(object):
self._ops.append(op)
def gen_desc(self):
print('here will generate desc for paas')
logging.info('here will generate desc for paas')
pass
def prepare_server(self, port, worker_num):
......@@ -273,6 +333,6 @@ class PyServer(object):
else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
print(cmd)
logging.info(cmd)
return
os.system(cmd)
......@@ -14,11 +14,13 @@
syntax = "proto2";
// One message passed through a channel between ops.
// Defined exactly once: the earlier one-line definition of the same message
// is removed, since a duplicate message name does not compile.
message ChannelData {
  // Named payloads carried by this message.
  repeated Inst insts = 1;
  // Request id; the Python service assigns it from an incrementing counter
  // and uses it to match a result to the request that is waiting for it.
  required int32 id = 2;
  // Defaults to "channel"; purpose not visible here - TODO document.
  optional string type = 3 [ default = "channel" ];
}
// One named payload inside a ChannelData message.
message Inst {
// Raw payload bytes (the Python code fills this from numpy tobytes() output
// or from a request's feed_insts entry).
required bytes data = 1;
// Name of the payload, e.g. a feed variable name, "prediction" or "resp".
required string name = 2;
// The Python code currently always sets this to 0 (marked TODO there).
required int32 id = 3;
// Defaults to "channel"; purpose not visible here - TODO document.
optional string type = 4 [ default = "channel" ];
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册