提交 15fa6d0f 编写于 作者: B barrierye

update code

上级 0bac3e06
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_server.pyserver import Op from pyserver import Op
from paddle_serving_server.pyserver import Channel from pyserver import Channel
from paddle_serving_server.pyserver import PyServer from pyserver import PyServer
# channel data: {name(str): data(bytes)} # channel data: {name(str): data(bytes)}
...@@ -65,11 +65,11 @@ combine_op = CombineOp( ...@@ -65,11 +65,11 @@ combine_op = CombineOp(
pyserver = PyServer() pyserver = PyServer()
pyserver.add_channel(read_channel) pyserver.add_channel(read_channel)
pyserver.add_cnannel(cnn_out_channel) pyserver.add_channel(cnn_out_channel)
pyserver.add_cnannel(bow_out_channel) pyserver.add_channel(bow_out_channel)
pyserver.add_cnannel(combine_out_channel) pyserver.add_channel(combine_out_channel)
pyserver.add_op(cnn_op) pyserver.add_op(cnn_op)
pyserver.add_op(bow_op) pyserver.add_op(bow_op)
pyserver.add_op(combine_op) pyserver.add_op(combine_op)
pyserver.prepare_server(port=8080, worker_num=4) pyserver.prepare_server(port=8080, worker_num=1)
pyserver.run_server() pyserver.run_server()
...@@ -19,11 +19,11 @@ service GeneralPythonService { ...@@ -19,11 +19,11 @@ service GeneralPythonService {
} }
message Request { message Request {
repeated bytes feedinsts = 1; repeated bytes feed_insts = 1;
repeated string fetch_var_names = 2; repeated string feed_var_names = 2;
} }
message Response { message Response {
repeated bytes fetchinsts = 1; repeated bytes fetch_insts = 1;
repeated string fetch_var_names = 2; repeated string fetch_var_names = 2;
} }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
import threading import threading
import multiprocessing import multiprocessing
import queue import Queue
import os import os
import paddle_serving_server import paddle_serving_server
from paddle_serving_client import Client from paddle_serving_client import Client
...@@ -23,11 +23,13 @@ import numpy ...@@ -23,11 +23,13 @@ import numpy
import grpc import grpc
import general_python_service_pb2 import general_python_service_pb2
import general_python_service_pb2_grpc import general_python_service_pb2_grpc
import python_service_channel_pb2
class Channel(queue.Queue): class Channel(Queue.Queue):
def __init__(self, consumer=1, maxsize=0, timeout=0, batchsize=1): def __init__(self, consumer=1, maxsize=0, timeout=None, batchsize=1):
super(Channel, self).__init__(maxsize=maxsize) Queue.Queue.__init__(self, maxsize=maxsize)
# super(Channel, self).__init__(maxsize=maxsize)
self._maxsize = maxsize self._maxsize = maxsize
self._timeout = timeout self._timeout = timeout
self._batchsize = batchsize self._batchsize = batchsize
...@@ -46,7 +48,7 @@ class Channel(queue.Queue): ...@@ -46,7 +48,7 @@ class Channel(queue.Queue):
self._pushbatch.append(item) self._pushbatch.append(item)
def front(self): def front(self):
if consumer == 1: if self._consumer == 1:
return self.get(timeout=self._timeout) return self.get(timeout=self._timeout)
with self._frontlock: with self._frontlock:
if self._count == 0: if self._count == 0:
...@@ -70,13 +72,14 @@ class Op(object): ...@@ -70,13 +72,14 @@ class Op(object):
self._run = False self._run = False
self.set_inputs(inputs) self.set_inputs(inputs)
self.set_outputs(outputs) self.set_outputs(outputs)
self._client = None
if client_config is not None and \ if client_config is not None and \
server_name is not None and \ server_name is not None and \
fetch_names is not None: fetch_names is not None:
self.set_client(client_config, server_name, fetch_names) self.set_client(client_config, server_name, fetch_names)
self._server_model = server_model self._server_model = server_model
self._server_port = server_port self._server_port = server_port
self._device = deviceis self._device = device
def set_client(self, client_config, server_name, fetch_names): def set_client(self, client_config, server_name, fetch_names):
self._client = Client() self._client = Client()
...@@ -141,13 +144,21 @@ class GeneralPythonService( ...@@ -141,13 +144,21 @@ class GeneralPythonService(
self._in_channel = in_channel self._in_channel = in_channel
self._out_channel = out_channel self._out_channel = out_channel
def Request(self, request, context): def inference(self, request, context):
data_dict = {} data = python_service_channel_pb2.ChannelData()
for idx, name in enumerate(request.fetch_var_names): for idx, name in enumerate(request.feed_var_names):
data_dict[name] = request.feedinsts[idx] inst = python_service_channel_pb2.Inst()
self._in_channel.push(data_dict) inst.data = request.feed_insts(idx)
resp = self._out_channel.front() inst.name = name
return general_python_service_pb2_grpc.Response(resp) inst.id = 0 #TODO
data.insts.append(inst)
self._in_channel.push(data)
data = self._out_channel.front()
resp = general_python_service_pb2.Response()
for inst in data.insts:
resp.fetch_data.append(inst.data)
resp.fetch_var_names.append(inst.name)
return resp
class PyServer(object): class PyServer(object):
...@@ -164,7 +175,7 @@ class PyServer(object): ...@@ -164,7 +175,7 @@ class PyServer(object):
self._channels.append(channel) self._channels.append(channel)
def add_op(self, op): def add_op(self, op):
slef._ops.append(op) self._ops.append(op)
def gen_desc(self): def gen_desc(self):
print('here will generate desc for paas') print('here will generate desc for paas')
...@@ -176,8 +187,8 @@ class PyServer(object): ...@@ -176,8 +187,8 @@ class PyServer(object):
inputs = set() inputs = set()
outputs = set() outputs = set()
for op in self._ops: for op in self._ops:
inputs += op.get_inputs() inputs |= set(op.get_inputs())
outputs += op.get_outputs() outputs |= set(op.get_outputs())
if op.with_serving(): if op.with_serving():
self.prepare_serving(op) self.prepare_serving(op)
in_channel = inputs - outputs in_channel = inputs - outputs
...@@ -190,14 +201,18 @@ class PyServer(object): ...@@ -190,14 +201,18 @@ class PyServer(object):
self._out_channel = out_channel.pop() self._out_channel = out_channel.pop()
self.gen_desc() self.gen_desc()
def op_start_wrapper(self, op):
return op.start()
def run_server(self): def run_server(self):
for op in self._ops: for op in self._ops:
th = multiprocessing.Process(target=op.start, args=(op, )) th = multiprocessing.Process(
target=self.op_start_wrapper, args=(op, ))
th.start() th.start()
self._op_threads.append(th) self._op_threads.append(th)
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self._worker_num)) futures.ThreadPoolExecutor(max_workers=self._worker_num))
general_python_service_pb2_grpc.add_GeneralPythonService_to_server( general_python_service_pb2_grpc.add_GeneralPythonServiceServicer_to_server(
GeneralPythonService(self._in_channel, self._out_channel), server) GeneralPythonService(self._in_channel, self._out_channel), server)
server.start() server.start()
try: try:
...@@ -218,4 +233,6 @@ class PyServer(object): ...@@ -218,4 +233,6 @@ class PyServer(object):
else: else:
cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format( cmd = "python -m paddle_serving_server_gpu.serve --model {} --thread 4 --port {} &>/dev/null &".format(
model_path, port) model_path, port)
print(cmd)
return
os.system(cmd) os.system(cmd)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
message ChannelData { repeated Inst insts = 1; }
message Inst {
required bytes data = 1;
required string name = 2;
required int32 id = 3;
optional string type = 4 [ default = "channel" ];
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册