提交 0bac3e06 编写于 作者: B barrierye

update code

上级 c9707058
......@@ -18,36 +18,50 @@ from paddle_serving_server.pyserver import Channel
from paddle_serving_server.pyserver import PyServer
class CNNOp(Op):
# channel data: {name(str): data(bytes)}
class ImdbOp(Op):
def preprocess(self, input_data):
pass
x = input_data[0]['words']
feed = {"words": np.array(x)}
return feed
def postprocess(self, output_data):
pass
data = {"resp": fetch_map["prediction"][0][0]}
return data
class CombineOp(Op):
def preprocess(self, input_data):
cnt = 0
for data in input_data:
cnt += data['resp']
return {"resp": cnt}
read_channel = Channel(consumer=2)
cnn_out_channel = Channel()
bow_out_channel = Channel()
combine_out_channel = Channel()
cnn_op = Op(inputs=[read_channel],
cnn_op = ImdbOp(
inputs=[read_channel],
outputs=[cnn_out_channel],
server_model="./imdb_cnn_model",
server_port="9393",
device="cpu",
client_config="imdb_cnn_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9393",
fetch_names=["acc", "cost", "prediction"])
bow_op = Op(inputs=[read_channel],
fetch_names=["prediction"])
bow_op = ImdbOp(
inputs=[read_channel],
outputs=[bow_out_channel],
server_model="./imdb_bow_model",
server_port="9292",
device="cpu",
client_config="imdb_bow_client_conf/serving_client_conf.prototxt",
server_name="127.0.0.1:9292",
fetch_names=["acc", "cost", "prediction"])
combine_op = Op(inputs=[cnn_out_channel, bow_out_channel],
outputs=[combine_out_channel])
fetch_names=["prediction"])
combine_op = CombineOp(
inputs=[cnn_out_channel, bow_out_channel], outputs=[combine_out_channel])
pyserver = PyServer()
pyserver.add_channel(read_channel)
......
......@@ -18,23 +18,12 @@ service GeneralPythonService {
rpc inference(Request) returns (Response) {}
}
message Tensor {
repeated bytes data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
repeated int32 lod = 7;
};
message FeedInst { repeated Tensor tensor_array = 1; };
message FetchInst { repeated Tensor tensor_array = 1; };
message Request {
repeated FeedInst insts = 1;
repeated bytes feedinsts = 1;
repeated string fetch_var_names = 2;
}
message Response { repeated FetchInst insts = 1; }
message Response {
repeated bytes fetchinsts = 1;
repeated string fetch_var_names = 2;
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import queue
import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
class Channel(queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=0, batchsize=1):
super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._batchsize = batchsize
self._consumer = consumer
self._pushlock = threading.Lock()
self._frontlock = threading.Lock()
self._pushbatch = []
self._frontbatch = None
self._count = 0
def push(self, item):
with self._pushlock:
if len(self._pushbatch) == batchsize:
self.put(self._pushbatch, timeout=self._timeout)
self._pushbatch = []
self._pushbatch.append(item)
def front(self):
if consumer == 1:
return self.get(timeout=self._timeout)
with self._frontlock:
if self._count == 0:
self._frontbatch = self.get(timeout=self._timeout)
self._count += 1
if self._count == self._consumer:
self._count = 0
return self._frontbatch
class Op(object):
def __init__(self,
inputs,
outputs,
server_model=None,
server_port=None,
device=None,
client_config=None,
server_name=None,
fetch_names=None):
self._run = False
self.set_inputs(inputs)
self.set_outputs(outputs)
if client_config is not None and \
server_name is not None and \
fetch_names is not None:
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model
self._server_port = server_port
self._device = deviceis
def set_client(self, client_config, server_name, fetch_names):
self._client = Client()
self._client.load_client_config(client_config)
self._client.connect([server_name])
self._fetch_names = fetch_names
def with_serving(self):
return self._client is not None
def get_inputs(self):
return self._inputs
def set_inputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
self._inputs = channels
def get_outputs(self):
return self._outputs
def set_outputs(self, channels):
if not isinstance(channels, list):
raise TypeError('channels must be list type')
self._outputs = channels
def preprocess(self, input_data):
return input_data
def midprocess(self, data):
# data = preprocess(input), which is a dict
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
return fetch_map
def postprocess(self, output_data):
return output_data
def stop(self):
self._run = False
def start(self):
self._run = True
while self._run:
input_data = []
for channel in self._inputs:
input_data.append(channel.front())
data = self.preprocess(input_data)
if self.with_serving():
fetch_map = self.midprocess(data)
output_data = self.postprocess(fetch_map)
else:
output_data = self.postprocess(data)
for channel in self._outputs:
channel.push(output_data)
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, channel):
self._channel = channel
def Request(self, request, context):
pass
def Response(self, request, context):
pass
class PyServer(object):
def __init__(self):
self._channels = []
self._ops = []
self._op_threads = []
self._port = None
self._worker_num = None
def add_channel(self, channel):
self._channels.append(channel)
def add_op(self, op):
slef._ops.append(op)
def gen_desc(self):
pass
def prepare_server(self, port, worker_num):
self._port = port
self._worker_num = worker_num
self.gen_desc()
def run_server(self):
inputs = []
outputs = []
for op in self._ops:
inputs += op.get_inputs()
outputs += op.get_outputs()
if op.with_serving():
self.prepare_serving(op)
th = multiprocessing.Process(target=op.start, args=(op, ))
th.start()
self._op_threads.append(th)
input_channel = []
for channel in inputs:
if channel not in outputs:
input_channel.append(channel)
if len(input_channel) != 1:
raise Exception("input_channel more than 1 or no input_channel")
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonService_to_server(
GeneralPythonService(input_channel[0]), server)
server.start()
try:
for th in self._op_threads:
th.join()
except KeyboardInterrupt:
server.stop(0)
def prepare_serving(self, op):
model_path = op._server_model
port = op._server_port
device = op._device
# run a server (not in PyServing)
if device == "cpu":
cmd = "python -m paddle_serving_server.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
os.system(cmd)
......@@ -19,6 +19,7 @@ import os
import paddle_serving_server
from paddle_serving_client import Client
from concurrent import futures
import numpy
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
......@@ -106,7 +107,7 @@ class Op(object):
return input_data
def midprocess(self, data):
# data = preprocess(input), which is a dict
# data = preprocess(input), which must be a dict
fetch_map = self._client.predict(feed=data, fetch=self._fetch_names)
return fetch_map
......@@ -136,14 +137,17 @@ class Op(object):
class GeneralPythonService(
general_python_service_pb2_grpc.GeneralPythonService):
def __init__(self, channel):
self._channel = channel
def __init__(self, in_channel, out_channel):
self._in_channel = in_channel
self._out_channel = out_channel
def Request(self, request, context):
pass
def Response(self, request, context):
pass
data_dict = {}
for idx, name in enumerate(request.fetch_var_names):
data_dict[name] = request.feedinsts[idx]
self._in_channel.push(data_dict)
resp = self._out_channel.front()
return general_python_service_pb2_grpc.Response(resp)
class PyServer(object):
......@@ -153,6 +157,8 @@ class PyServer(object):
self._op_threads = []
self._port = None
self._worker_num = None
self._in_channel = None
self._out_channel = None
def add_channel(self, channel):
self._channels.append(channel)
......@@ -161,36 +167,38 @@ class PyServer(object):
slef._ops.append(op)
def gen_desc(self):
print('here will generate desc for paas')
pass
def prepare_server(self, port, worker_num):
self._port = port
self._worker_num = worker_num
self.gen_desc()
def run_server(self):
inputs = []
outputs = []
inputs = set()
outputs = set()
for op in self._ops:
inputs += op.get_inputs()
outputs += op.get_outputs()
if op.with_serving():
self.prepare_serving(op)
in_channel = inputs - outputs
out_channel = outputs - inputs
if len(in_channel) != 1 or len(out_channel) != 1:
raise Exception(
"in_channel(out_channel) more than 1 or no in_channel(out_channel)"
)
self._in_channel = in_channel.pop()
self._out_channel = out_channel.pop()
self.gen_desc()
def run_server(self):
for op in self._ops:
th = multiprocessing.Process(target=op.start, args=(op, ))
th.start()
self._op_threads.append(th)
input_channel = []
for channel in inputs:
if channel not in outputs:
input_channel.append(channel)
if len(input_channel) != 1:
raise Exception("input_channel more than 1 or no input_channel")
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonService_to_server(
GeneralPythonService(input_channel[0]), server)
GeneralPythonService(self._in_channel, self._out_channel), server)
server.start()
try:
for th in self._op_threads:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册