提交 db1599af 编写于 作者: B barrierye

add grpc

上级 ea7cab8f
......@@ -57,4 +57,5 @@ pyserver.add_cnannel(combine_out_channel)
pyserver.add_op(cnn_op)
pyserver.add_op(bow_op)
pyserver.add_op(combine_op)
pyserver.prepare_server(port=8080, worker_num=4)
pyserver.run_server()
......@@ -18,6 +18,7 @@ import queue
import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
......@@ -82,11 +83,20 @@ class Op(object):
self._client.connect([server_name])
self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_inputs(self):
return self._inputs
def set_inputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
self._inputs = channels
def get_outputs(self):
return self._outputs
def set_outputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
......@@ -114,7 +124,7 @@ class Op(object):
input_data.append(channel.front())
data = self.preprocess(input_data)
if self._client is not None:
if self.with_serving():
fetch_map = self.midprocess(data)
output_data = self.postprocess(fetch_map)
else:
......@@ -124,11 +134,25 @@ class Op(object):
channel.push(output_data)
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, channel):
self._channel = channel
def Request(self, request, context):
pass
def Response(self, request, context):
pass
class PyServer(object):
def __init__(self):
self._channels = []
self._ops = []
self._op_threads = []
self._port = None
self._worker_num = None
def add_channel(self, channel):
self._channels.append(channel)
......@@ -139,16 +163,42 @@ class PyServer(object):
def gen_desc(self):
pass
def prepare_server(self, port, worker_num):
self._port = port
self._worker_num = worker_num
self.gen_desc()
def run_server(self):
inputs = []
outputs = []
for op in self._ops:
self.prepare_server(op)
inputs += op.get_inputs()
outputs += op.get_outputs()
if op.with_serving():
self.prepare_serving(op)
th = multiprocessing.Process(target=op.start, args=(op, ))
th.start()
self._op_threads.append(th)
for th in self._op_threads:
th.join()
def prepare_server(self, op):
input_channel = []
for channel in inputs:
if channel not in outputs:
input_channel.append(channel)
if len(input_channel) != 1:
raise Exception("input_channel more than 1 or no input_channel")
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonService_to_server(
GeneralPythonService(input_channel[0]), server)
server.start()
try:
for th in self._op_threads:
th.join()
except KeyboardInterrupt:
server.stop(0)
def prepare_serving(self, op):
model_path = op._server_model
port = op._server_port
device = op._device
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册