提交 6f10ff23 编写于 作者: B barrierye

add grpc server

上级 d5fa693a
......@@ -372,3 +372,53 @@ class Client(object):
def release(self):
self.client_handle_.destroy_predictor()
self.client_handle_ = None
class GClient(object):
def __init__(self):
self.bclient_ = Client()
self.channel_ = None
def load_client_config(self, path):
pass
def connect(self, endpoint):
self.channel_ = grpc.insecure_channel(endpoint)
self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub(
self.channel_)
def _parse_model_config(self, model_config_path):
model_conf = m_config.GeneralModelConfig()
f = open(model_config_path, 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_type_ = {}
self.type_map_ = {0: "int64", 1: "float"}
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
for i, var in enumerate(model_conf.fetch_var):
self.fetch_type_[var.alias_name] = var.fetch_type
def _pack_feed_data(self, feed, fetch):
req = gserver_general_model_service_pb2.Request()
req.fetch_var_names.extend(fetch)
feed_batch = None
if isinstance(feed, dict):
feed_batch = [feed]
elif isinstance(feed, list):
feed_batch = feed
else:
raise Exception("{} not support".format(type(feed)))
#TODO
def _unpack_resp(self, resp):
pass
def predict(self, feed, fetch):
req = self._pack_feed_data(feed, fetch)
resp = self.stub_.inference(req)
return self._unpack_resp(resp)
......@@ -25,6 +25,10 @@ from contextlib import closing
import collections
import fcntl
import numpy as np
import gserver_general_model_service_pb2
import gserver_general_model_service_pb2_grpc
class OpMaker(object):
def __init__(self):
......@@ -428,3 +432,110 @@ class Server(object):
print("Going to Run Command")
print(command)
os.system(command)
class GServerService(
gserver_general_model_service_pb2_grpc.GServerGeneralModelService):
def __init__(self, model_config_path, endpoints):
from paddle_serving_client import Client
self._parse_model_config(model_config_path)
self.bclient_ = Client()
self.bclient_.load_client_config(model_config_path)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path):
model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path),
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_type_ = {}
self.type_map_ = {0: "int64", 1: "float"}
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
for i, var in enumerate(model_conf.fetch_var):
self.fetch_type_[var.alias_name] = var.fetch_type
def _unpack_request(self, request):
fetch_names = list(request.fetch_var_names)
feed_batch = []
for feed_inst in request.feed_insts:
feed_dict = {}
for idx, name in enumerate(feed_inst.names):
data = feed_inst.data[idx]
itype = self.type_map_[self.feed_types_[name]]
feed_dict[name] = np.frombuffer(data, dtype=itype)
feed_batch.append(feed_inst)
return feed_batch, fetch_names
def _pack_resp_package(self, result):
resp = gserver_general_model_service_pb2.Response()
inst = gserver_general_model_service_pb2.Inst()
for name in fetch_names:
inst.name.append(name)
inst.data.append(result[name].tobytes())
inst.lod.append(result["{}.lod".format(name)].tobytes())
resp.fetch_insts.append(inst)
return resp
def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request)
data = self.bclient_.predict(feed=feed_dict, fetch=fetch_names)
return self._pack_resp_package(data)
class GServer(object):
def __init__(self, worker_num=2):
slef.bserver_ = Server()
self.worker_num_ = worker_num
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def load_model_config(self, model_config_path):
if not isinstance(model_config_path, str):
raise Exception("GServer only supports multi-model temporarily")
self.bserver_.load_model_config(model_config_path)
self.model_config_path_ = model_config_path
def prepare_server(self, workdir=None, port=9292, device="cpu"):
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self.port_is_available(default_port
+ i):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
self.gport_ = port
def _launch_brpc_service(self, bserver):
bserver.run_server()
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
return result != 0
def run_server(self):
p_bserver = Process(
target=self._launch_rpc_service, args=(slef.bserver_, ))
p_bserver.start()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_))
gserver_general_model_service_pb2_grpc.add_GServerGeneralModelService_to_server(
GServerService(self.model_config_path_,
"0.0.0.0:{}".format(self.port_list_[0])), server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
try:
server.join()
p_bserver.join()
except KeyboardInterrupt:
server.stop(0)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
message Inst {
repeated bytes data = 1;
repeated string names = 2;
repeated bytes lod = 3;
}
message Request {
repeated Inst feed_insts = 1;
repeated string fetch_var_names = 2;
};
message Response { repeated Inst fetch_insts = 1; }
service GServerGeneralModelService {
rpc inference(Request) returns (Response) {}
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册