提交 15fa6d0f 编写于 作者: B barrierye

update code

上级 0bac3e06
......@@ -13,9 +13,9 @@
# limitations under the License.
# pylint: disable=doc-string-missing
from paddle_serving_server.pyserver import Op
from paddle_serving_server.pyserver import Channel
from paddle_serving_server.pyserver import PyServer
from pyserver import Op
from pyserver import Channel
from pyserver import PyServer
# channel data: {name(str): data(bytes)}
......@@ -65,11 +65,11 @@ combine_op = CombineOp(
pyserver = PyServer()
pyserver.add_channel(read_channel)
pyserver.add_cnannel(cnn_out_channel)
pyserver.add_cnannel(bow_out_channel)
pyserver.add_cnannel(combine_out_channel)
pyserver.add_channel(cnn_out_channel)
pyserver.add_channel(bow_out_channel)
pyserver.add_channel(combine_out_channel)
pyserver.add_op(cnn_op)
pyserver.add_op(bow_op)
pyserver.add_op(combine_op)
pyserver.prepare_server(port=8080, worker_num=4)
pyserver.prepare_server(port=8080, worker_num=1)
pyserver.run_server()
......@@ -19,11 +19,11 @@ service GeneralPythonService {
}
message Request {
repeated bytes feedinsts = 1;
repeated string fetch_var_names = 2;
repeated bytes feed_insts = 1;
repeated string feed_var_names = 2;
}
message Response {
repeated bytes fetchinsts = 1;
repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2;
}
......@@ -14,7 +14,7 @@
# pylint: disable=doc-string-missing
import threading
import multiprocessing
import queue
import Queue
import os
import paddle_serving_server
from paddle_serving_client import Client
......@@ -23,11 +23,13 @@ import numpy
import grpc
import general_python_service_pb2
import general_python_service_pb2_grpc
import python_service_channel_pb2
class Channel(queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=0, batchsize=1):
super(Channel, self).__init__(maxsize=maxsize)
class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=None, batchsize=1):
Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize
self._timeout = timeout
self._batchsize = batchsize
......@@ -46,7 +48,7 @@ class Channel(queue.Queue):
self._pushbatch.append(item)
def front(self):
if consumer == 1:
if self._consumer == 1:
return self.get(timeout=self._timeout)
with self._frontlock:
if self._count == 0:
......@@ -70,13 +72,14 @@ class Op(object):
self._run = False
self.set_inputs(inputs)
self.set_outputs(outputs)
self._client = None
if client_config is not None and \
server_name is not None and \
fetch_names is not None:
self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model
self._server_port = server_port
self._device = deviceis
self._device = device
def set_client(self, client_config, server_name, fetch_names):
self._client = Client()
......@@ -141,13 +144,21 @@ class GeneralPythonService(
self._in_channel = in_channel
self._out_channel = out_channel
def Request(self, request, context):
data_dict = {}
for idx, name in enumerate(request.fetch_var_names):
data_dict[name] = request.feedinsts[idx]
self._in_channel.push(data_dict)
resp = self._out_channel.front()
return general_python_service_pb2_grpc.Response(resp)
def inference(self, request, context):
data = python_service_channel_pb2.ChannelData()
for idx, name in enumerate(request.feed_var_names):
inst = python_service_channel_pb2.Inst()
inst.data = request.feed_insts(idx)
inst.name = name
inst.id = 0 #TODO
data.insts.append(inst)
self._in_channel.push(data)
data = self._out_channel.front()
resp = general_python_service_pb2.Response()
for inst in data.insts:
resp.fetch_data.append(inst.data)
resp.fetch_var_names.append(inst.name)
return resp
class PyServer(object):
......@@ -164,7 +175,7 @@ class PyServer(object):
self._channels.append(channel)
def add_op(self, op):
slef._ops.append(op)
self._ops.append(op)
def gen_desc(self):
print('here will generate desc for paas')
......@@ -176,8 +187,8 @@ class PyServer(object):
inputs = set()
outputs = set()
for op in self._ops:
inputs += op.get_inputs()
outputs += op.get_outputs()
inputs |= set(op.get_inputs())
outputs |= set(op.get_outputs())
if op.with_serving():
self.prepare_serving(op)
in_channel = inputs - outputs
......@@ -190,14 +201,18 @@ class PyServer(object):
self._out_channel = out_channel.pop()
self.gen_desc()
def op_start_wrapper(self, op):
return op.start()
def run_server(self):
for op in self._ops:
th = multiprocessing.Process(target=op.start, args=(op, ))
th = multiprocessing.Process(
target=self.op_start_wrapper, args=(op, ))
th.start()
self._op_threads.append(th)
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonService_to_server(
general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server)
server.start()
try:
......@@ -218,4 +233,6 @@ class PyServer(object):
else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port)
print(cmd)
return
os.system(cmd)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
message ChannelData { repeated Inst insts = 1; }
message Inst {
required bytes data = 1;
required string name = 2;
required int32 id = 3;
optional string type = 4 [ default = "channel" ];
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册