Commit 3b2fa3d5 authored by barrierye

make it work

Parent bf1007e0
@@ -17,15 +17,34 @@ import general_python_service_pb2
import general_python_service_pb2_grpc
import numpy as np
channel = grpc.insecure_channel('127.0.0.1:8080')
channel = grpc.insecure_channel('localhost:8080')
stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(channel)
# line = "i am very sad | 0"
word_ids = np.array([8, 233, 52, 601])
req = general_python_service_pb2.Request()
"""
# line = "i am very sad | 0"
word_ids = np.array([8, 233, 52, 601], dtype='int64')
# word_ids = np.array([8, 233, 52, 601])
print(word_ids)
data = np.ndarray.tobytes(word_ids)
print(data)
# xx = np.frombuffer(data)
xx = np.frombuffer(data, dtype='int64')
print (xx)
req.feed_var_names.append("words")
req.feed_insts.append(np.ndarray.tobytes(word_ids))
req.feed_insts.append(data)
"""
x = np.array(
[
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584,
0.6283, 0.4919, 0.1856, 0.0795, -0.0332
],
dtype='float')
data = np.ndarray.tobytes(x)
req.feed_var_names.append("x")
req.feed_insts.append(data)
resp = stub.inference(req)
print(resp)
for idx, name in enumerate(resp.fetch_var_names):
print('{}: {}'.format(
name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
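The request and response above carry every tensor as raw bytes. Below is a minimal round-trip sketch of that convention outside the gRPC call, using the `word_ids` values from the commented-out block; only numpy is required. The point it illustrates: `np.ndarray.tobytes()` discards dtype and shape, so the matching dtype must be supplied to `np.frombuffer()` on the receiving side.

```python
import numpy as np

# Pack: same convention as req.feed_insts above; tobytes() keeps only raw bytes.
word_ids = np.array([8, 233, 52, 601], dtype='int64')
payload = word_ids.tobytes()

# Unpack: the dtype must match the sender's, exactly as the client and Op code do.
restored = np.frombuffer(payload, dtype='int64')
assert (restored == word_ids).all()

# Without an explicit dtype, frombuffer defaults to float64 and silently
# reinterprets the int64 bytes -- the pitfall the commented-out block probes.
print(np.frombuffer(payload))
```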
@@ -16,32 +16,90 @@
from pyserver import Op
from pyserver import Channel
from pyserver import PyServer
import numpy as np
import python_service_channel_pb2
# channel data: {name(str): data(bytes)}
class ImdbOp(Op):
def preprocess(self, input_data):
x = input_data[0]['words']
feed = {"words": np.array(x)}
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='int64')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data):
data = {"resp": fetch_map["prediction"][0][0]}
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["prediction"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data
class CombineOp(Op):
def preprocess(self, input_data):
cnt = 0
for data in input_data:
cnt += data['resp']
return {"resp": cnt}
for input in input_data:
data = input[0] # batchsize=1
cnt += np.frombuffer(data.insts[0].data, dtype='float')
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp"
inst.id = 0 #TODO
data.insts.append(inst)
print(data)
return data
class UciOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='float')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data
read_channel = Channel(consumer=2)
cnn_out_channel = Channel()
bow_out_channel = Channel()
combine_out_channel = Channel()
cnn_op = UciOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
server_model="./uci_housing_model",
server_port="9393",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
bow_op = UciOp(
inputs=[read_channel],
outputs=[bow_out_channel],
server_model="./uci_housing_model",
server_port="9292",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
'''
cnn_op = ImdbOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
@@ -50,7 +108,7 @@ cnn_op = ImdbOp(
device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["prediction"])
fetch_names=["acc", "cost", "prediction"])
bow_op = ImdbOp(
inputs=[read_channel],
outputs=[bow_out_channel],
@@ -59,7 +117,8 @@ bow_op = ImdbOp(
device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292",
fetch_names=["prediction"])
fetch_names=["acc", "cost", "prediction"])
'''
combine_op = CombineOp(
inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel])
......
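The wiring above fans one read channel out to two parallel ops and merges their outputs in CombineOp. Below is a self-contained toy model of that fan-out/fan-in shape built from plain Python queues and threads rather than the PyServer Channel/Op classes; all names and the doubled put() are illustrative only (a real Channel with consumer=2 appears to hand the same batch to each consumer, whereas a plain Queue is consumed, not broadcast).

```python
import queue
import threading

read_q = queue.Queue()
out_a, out_b = queue.Queue(), queue.Queue()
combined = queue.Queue()

def op(in_q, out_q, scale):
    # stand-in for an Op loop: pull input, "predict", push downstream
    out_q.put(in_q.get() * scale)

def combine(in_qs, out_q):
    # stand-in for CombineOp: sum the parallel branches
    out_q.put(sum(q.get() for q in in_qs))

threads = [
    threading.Thread(target=op, args=(read_q, out_a, 2.0)),
    threading.Thread(target=op, args=(read_q, out_b, 3.0)),
    threading.Thread(target=combine, args=([out_a, out_b], combined)),
]
for t in threads:
    t.start()

# enqueue the request once per consumer, since a plain Queue does not broadcast
read_q.put(1.0)
read_q.put(1.0)
for t in threads:
    t.join()
print(combined.get())  # 2.0 + 3.0 = 5.0
```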
@@ -27,7 +27,7 @@ import python_service_channel_pb2
class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=None, batchsize=1):
def __init__(self, consumer=1, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize
@@ -42,10 +42,11 @@ class Channel(Queue.Queue):
def push(self, item):
with self._pushlock:
if len(self._pushbatch) == batchsize:
self._pushbatch.append(item)
if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout)
# self.put(self._pushbatch)
self._pushbatch = []
self._pushbatch.append(item)
def front(self):
if self._consumer == 1:
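The push() change above buffers items until a full batch has accumulated before putting it on the queue. A condensed, standalone Python 3 version of that batching behaviour follows; it is a sketch only, and the real Channel additionally tracks consumer counts and implements the front() side omitted here.

```python
import queue
import threading

class BatchingQueue(queue.Queue):
    """Toy stand-in for Channel: groups pushed items into fixed-size batches."""

    def __init__(self, batchsize=2, timeout=None):
        super().__init__()
        self._batchsize = batchsize
        self._timeout = timeout
        self._pushlock = threading.Lock()
        self._pushbatch = []

    def push(self, item):
        with self._pushlock:
            self._pushbatch.append(item)
            if len(self._pushbatch) == self._batchsize:
                # the whole batch travels as a single queue element
                self.put(self._pushbatch, timeout=self._timeout)
                self._pushbatch = []

q = BatchingQueue(batchsize=2)
q.push("a")
q.push("b")
print(q.get())  # ['a', 'b']
```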
@@ -111,6 +112,8 @@ class Op(object):
def midprocess(self, data):
# data = preprocess(input), which must be a dict
print('data: {}'.format(data))
print('fetch: {}'.format(self._fetch_names))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
return fetch_map
@@ -126,7 +129,10 @@ class Op(object):
input_data = []
for channel in self._inputs:
input_data.append(channel.front())
if len(input_data) > 1:
data = self.preprocess(input_data)
else:
data = self.preprocess(input_data[0])
if self.with_serving():
fetch_map = self.midprocess(data)
@@ -141,22 +147,36 @@ class Op(object):
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__()
self._in_channel = in_channel
self._out_channel = out_channel
print('succ init')
def inference(self, request, context):
print('start inference')
data = python_service_channel_pb2.ChannelData()
print('gen data: {}'.format(data))
for idx, name in enumerate(request.feed_var_names):
print('name: {}'.format(request.feed_var_names[idx]))
print('data: {}'.format(request.feed_insts[idx]))
inst = python_service_channel_pb2.Inst()
inst.data = request.feed_insts(idx)
inst.data = request.feed_insts[idx]
inst.name = name
inst.id = 0 #TODO
data.insts.append(inst)
print('push data')
self._in_channel.push(data)
print('wait for infer')
data = self._out_channel.front()
data = data[0] #TODO batchsize = 1
print('get data')
resp = general_python_service_pb2.Response()
print('gen resp')
print(data)
for inst in data.insts:
resp.fetch_data.append(inst.data)
print('append data')
resp.fetch_insts.append(inst.data)
print('append name')
resp.fetch_var_names.append(inst.name)
return resp
@@ -206,18 +226,20 @@ class PyServer(object):
def run_server(self):
for op in self._ops:
th = multiprocessing.Process(
target=self.op_start_wrapper, args=(op, ))
# th = multiprocessing.Process(target=self.op_start_wrapper, args=(op, ))
th = threading.Thread(target=self.op_start_wrapper, args=(op, ))
th.start()
self._op_threads.append(th)
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start()
try:
for th in self._op_threads:
th.join()
server.join()
except KeyboardInterrupt:
server.stop(0)
......
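run_server() now launches each op with threading.Thread instead of multiprocessing.Process, presumably so the in-process Channel queues stay shared between the op loops and the gRPC servicer. A minimal sketch of that startup shape is below; it requires grpcio, the op loop and worker count are placeholders, and the real servicer registration is the add_GeneralPythonServiceServicer_to_server call shown above.

```python
import threading
import time
from concurrent import futures

import grpc

def op_loop(name):
    # placeholder for Op.start(): pull from input channels, run, push downstream
    while True:
        time.sleep(1)

op_threads = []
for name in ["cnn_op", "bow_op", "combine_op"]:
    th = threading.Thread(target=op_loop, args=(name,), daemon=True)
    th.start()
    op_threads.append(th)

server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
# register the generated servicer(s) here before starting
server.add_insecure_port('[::]:8080')
server.start()
try:
    server.wait_for_termination()
except KeyboardInterrupt:
    server.stop(0)
```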