提交 3b2fa3d5 编写于 作者: B barrierye

make it work

上级 bf1007e0
...@@ -17,15 +17,34 @@ import general_python_service_pb2 ...@@ -17,15 +17,34 @@ import general_python_service_pb2
import general_python_service_pb2_grpc import general_python_service_pb2_grpc
import numpy as np import numpy as np
channel = grpc.insecure_channel('127.0.0.1:8080') channel = grpc.insecure_channel('localhost:8080')
stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(channel) stub = general_python_service_pb2_grpc.GeneralPythonServiceStub(channel)
# line = "i am very sad | 0"
word_ids = np.array([8, 233, 52, 601])
req = general_python_service_pb2.Request() req = general_python_service_pb2.Request()
"""
# line = "i am very sad | 0"
word_ids = np.array([8, 233, 52, 601], dtype='int64')
# word_ids = np.array([8, 233, 52, 601])
print(word_ids)
data = np.ndarray.tobytes(word_ids)
print(data)
# xx = np.frombuffer(data)
xx = np.frombuffer(data, dtype='int64')
print (xx)
req.feed_var_names.append("words") req.feed_var_names.append("words")
req.feed_insts.append(np.ndarray.tobytes(word_ids)) req.feed_insts.append(data)
"""
x = np.array(
[
0.0137, -0.1136, 0.2553, -0.0692, 0.0582, -0.0727, -0.1583, -0.0584,
0.6283, 0.4919, 0.1856, 0.0795, -0.0332
],
dtype='float')
data = np.ndarray.tobytes(x)
req.feed_var_names.append("x")
req.feed_insts.append(data)
resp = stub.inference(req) resp = stub.inference(req)
print(resp) for idx, name in enumerate(resp.fetch_var_names):
print('{}: {}'.format(
name, np.frombuffer(
resp.fetch_insts[idx], dtype='float')))
...@@ -16,32 +16,90 @@ ...@@ -16,32 +16,90 @@
from pyserver import Op from pyserver import Op
from pyserver import Channel from pyserver import Channel
from pyserver import PyServer from pyserver import PyServer
import numpy as np
import python_service_channel_pb2
# channel data: {name(str): data(bytes)} # channel data: {name(str): data(bytes)}
class ImdbOp(Op): class ImdbOp(Op):
def preprocess(self, input_data): def preprocess(self, input_data):
x = input_data[0]['words'] data = input_data[0] # batchsize=1
feed = {"words": np.array(x)} feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='int64')
# feed[inst.name] = np.frombuffer(inst.data)
return feed return feed
def postprocess(self, output_data): def postprocess(self, output_data):
data = {"resp": fetch_map["prediction"][0][0]} data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["prediction"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data return data
class CombineOp(Op): class CombineOp(Op):
def preprocess(self, input_data): def preprocess(self, input_data):
cnt = 0 cnt = 0
for data in input_data: for input in input_data:
cnt += data['resp'] data = input[0] # batchsize=1
return {"resp": cnt} cnt += np.frombuffer(data.insts[0].data, dtype='float')
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
inst.data = np.ndarray.tobytes(cnt)
inst.name = "resp"
inst.id = 0 #TODO
data.insts.append(inst)
print(data)
return data
class UciOp(Op):
def preprocess(self, input_data):
data = input_data[0] # batchsize=1
feed = {}
for inst in data.insts:
feed[inst.name] = np.frombuffer(inst.data, dtype='float')
# feed[inst.name] = np.frombuffer(inst.data)
return feed
def postprocess(self, output_data):
data = python_service_channel_pb2.ChannelData()
inst = python_service_channel_pb2.Inst()
pred = np.array(output_data["price"][0][0], dtype='float')
inst.data = np.ndarray.tobytes(pred)
inst.name = "prediction"
inst.id = 0 #TODO
data.insts.append(inst)
return data
read_channel = Channel(consumer=2) read_channel = Channel(consumer=2)
cnn_out_channel = Channel() cnn_out_channel = Channel()
bow_out_channel = Channel() bow_out_channel = Channel()
combine_out_channel = Channel() combine_out_channel = Channel()
cnn_op = UciOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
server_model="./uci_housing_model",
server_port="9393",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
bow_op = UciOp(
inputs=[read_channel],
outputs=[bow_out_channel],
server_model="./uci_housing_model",
server_port="9292",
device="cpu",
client_config="uci_housing_client/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["price"])
'''
cnn_op = ImdbOp( cnn_op = ImdbOp(
inputs=[read_channel], inputs=[read_channel],
outputs=[cnn_out_channel], outputs=[cnn_out_channel],
...@@ -50,7 +108,7 @@ cnn_op = ImdbOp( ...@@ -50,7 +108,7 @@ cnn_op = ImdbOp(
device="cpu", device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt", client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393", server_name="127.0.0.1:9393",
fetch_names=["prediction"]) fetch_names=["acc", "cost", "prediction"])
bow_op = ImdbOp( bow_op = ImdbOp(
inputs=[read_channel], inputs=[read_channel],
outputs=[bow_out_channel], outputs=[bow_out_channel],
...@@ -59,7 +117,8 @@ bow_op = ImdbOp( ...@@ -59,7 +117,8 @@ bow_op = ImdbOp(
device="cpu", device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt", client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292", server_name="127.0.0.1:9292",
fetch_names=["prediction"]) fetch_names=["acc", "cost", "prediction"])
'''
combine_op = CombineOp( combine_op = CombineOp(
inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel]) inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel])
......
...@@ -27,7 +27,7 @@ import python_service_channel_pb2 ...@@ -27,7 +27,7 @@ import python_service_channel_pb2
class Channel(Queue.Queue): class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=None, batchsize=1): def __init__(self, consumer=1, maxsize=-1, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize) # super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
...@@ -42,10 +42,11 @@ class Channel(Queue.Queue): ...@@ -42,10 +42,11 @@ class Channel(Queue.Queue):
def push(self, item): def push(self, item):
with self._pushlock: with self._pushlock:
if len(self._pushbatch) == batchsize: self._pushbatch.append(item)
if len(self._pushbatch) == self._batchsize:
self.put(self._pushbatch, timeout=self._timeout) self.put(self._pushbatch, timeout=self._timeout)
# self.put(self._pushbatch)
self._pushbatch = [] self._pushbatch = []
self._pushbatch.append(item)
def front(self): def front(self):
if self._consumer == 1: if self._consumer == 1:
...@@ -111,6 +112,8 @@ class Op(object): ...@@ -111,6 +112,8 @@ class Op(object):
def midprocess(self, data): def midprocess(self, data):
# data = preprocess(input), which must be a dict # data = preprocess(input), which must be a dict
print('data: {}'.format(data))
print('fetch: {}'.format(self._fetch_names))
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names) fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
return fetch_map return fetch_map
...@@ -126,7 +129,10 @@ class Op(object): ...@@ -126,7 +129,10 @@ class Op(object):
input_data = [] input_data = []
for channel in self._inputs: for channel in self._inputs:
input_data.append(channel.front()) input_data.append(channel.front())
data = self.preprocess(input_data) if len(input_data) > 1:
data = self.preprocess(input_data)
else:
data = self.preprocess(input_data[0])
if self.with_serving(): if self.with_serving():
fetch_map = self.midprocess(data) fetch_map = self.midprocess(data)
...@@ -141,22 +147,36 @@ class Op(object): ...@@ -141,22 +147,36 @@ class Op(object):
class GeneralPythonService( class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService): general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, in_channel, out_channel): def __init__(self, in_channel, out_channel):
super(GeneralPythonService, self).__init__()
self._in_channel = in_channel self._in_channel = in_channel
self._out_channel = out_channel self._out_channel = out_channel
print('succ init')
def inference(self, request, context): def inference(self, request, context):
print('start inferce')
data = python_service_channel_pb2.ChannelData() data = python_service_channel_pb2.ChannelData()
print('gen data: {}'.format(data))
for idx, name in enumerate(request.feed_var_names): for idx, name in enumerate(request.feed_var_names):
print('name: {}'.format(request.feed_var_names[idx]))
print('data: {}'.format(request.feed_insts[idx]))
inst = python_service_channel_pb2.Inst() inst = python_service_channel_pb2.Inst()
inst.data = request.feed_insts(idx) inst.data = request.feed_insts[idx]
inst.name = name inst.name = name
inst.id = 0 #TODO inst.id = 0 #TODO
data.insts.append(inst) data.insts.append(inst)
print('push data')
self._in_channel.push(data) self._in_channel.push(data)
print('wait for infer')
data = self._out_channel.front() data = self._out_channel.front()
data = data[0] #TODO batchsize = 1
print('get data')
resp = general_python_service_pb2.Response() resp = general_python_service_pb2.Response()
print('gen resp')
print(data)
for inst in data.insts: for inst in data.insts:
resp.fetch_data.append(inst.data) print('append data')
resp.fetch_insts.append(inst.data)
print('append name')
resp.fetch_var_names.append(inst.name) resp.fetch_var_names.append(inst.name)
return resp return resp
...@@ -206,18 +226,20 @@ class PyServer(object): ...@@ -206,18 +226,20 @@ class PyServer(object):
def run_server(self): def run_server(self):
for op in self._ops: for op in self._ops:
th = multiprocessing.Process( # th = multiprocessing.Process(target=self.op_start_wrapper, args=(op, ))
target=self.op_start_wrapper, args=(op, )) th = threading.Thread(target=self.op_start_wrapper, args=(op, ))
th.start() th.start()
self._op_threads.append(th) self._op_threads.append(th)
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server( general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server) GeneralPythonService(self._in_channel, self._out_channel), server)
server.add_insecure_port('[::]:{}'.format(self._port))
server.start() server.start()
try: try:
for th in self._op_threads: for th in self._op_threads:
th.join() th.join()
server.join()
except KeyboardInterrupt: except KeyboardInterrupt:
server.stop(0) server.stop(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册