提交 0f67439c 编写于 作者: B barrierye

add GClient and make it run successfully

上级 94b42944
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
# Example: query a Paddle Serving gRPC server (GServer) with GClient,
# predicting prices on the uci_housing test set and printing
# "<prediction> <ground truth>" per sample.
#
# Fix: all imports moved to the top of the script (the original imported
# `paddle` in the middle of the file, against PEP 8).
import sys

import paddle
from paddle_serving_client import GClient

# sys.argv[1]: path to the client-side model config file.
client = GClient()
client.load_client_config(sys.argv[1])
# connect() takes a list of endpoints; NOTE(review): the current GClient
# implementation only uses the first entry — confirm before adding more.
client.connect(["127.0.0.1:9393"])

# Shuffled uci_housing test reader, batch size 1.
test_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.uci_housing.test(), buf_size=500),
    batch_size=1)

for data in test_reader():
    # data[0][0] is the feature vector "x"; data[0][1][0] is the label.
    fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
    print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
# Example: launch a GServer exposing the standard general-model pipeline
# (general_reader -> general_infer -> general_response) over gRPC on
# port 9393, serving the model config given as the first CLI argument.
import os
import sys
from paddle_serving_server import OpMaker
from paddle_serving_server import OpSeqMaker
from paddle_serving_server import GServer

# Assemble the op sequence in pipeline order.
maker = OpMaker()
seq_maker = OpSeqMaker()
for stage in ('general_reader', 'general_infer', 'general_response'):
    seq_maker.add_op(maker.create(stage))

# Configure and start the server (CPU, workdir "work_dir1").
gserver = GServer()
gserver.set_op_sequence(seq_maker.get_op_sequence())
gserver.load_model_config(sys.argv[1])
gserver.prepare_server(workdir="work_dir1", port=9393, device="cpu")
gserver.run_server()
...@@ -23,6 +23,10 @@ import time ...@@ -23,6 +23,10 @@ import time
import sys import sys
from .serving_client import PredictorRes from .serving_client import PredictorRes
import grpc
import gserver_general_model_service_pb2
import gserver_general_model_service_pb2_grpc
int_type = 0 int_type = 0
float_type = 1 float_type = 1
...@@ -376,14 +380,15 @@ class Client(object): ...@@ -376,14 +380,15 @@ class Client(object):
class GClient(object): class GClient(object):
def __init__(self): def __init__(self):
self.bclient_ = Client()
self.channel_ = None self.channel_ = None
def load_client_config(self, path): def load_client_config(self, path):
pass if not isinstance(path, str):
raise Exception("GClient only supports multi-model temporarily")
self._parse_model_config(path)
def connect(self, endpoint): def connect(self, endpoint):
self.channel_ = grpc.insecure_channel(endpoint) self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO
self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub( self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub(
self.channel_) self.channel_)
...@@ -394,14 +399,24 @@ class GClient(object): ...@@ -394,14 +399,24 @@ class GClient(object):
str(f.read()), model_conf) str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {} self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_type_ = {} self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float"} self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape self.feed_shapes_[var.alias_name] = var.shape
if self.feed_types_[var.alias_name] == 'float':
self.feed_types_[var.alias_name] = 'float32'
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var): for i, var in enumerate(model_conf.fetch_var):
self.fetch_type_[var.alias_name] = var.fetch_type self.fetch_types_[var.alias_name] = var.fetch_type
if self.fetch_types_[var.alias_name] == 'float':
self.fetch_types_[var.alias_name] = 'float32'
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _pack_feed_data(self, feed, fetch): def _pack_feed_data(self, feed, fetch):
req = gserver_general_model_service_pb2.Request() req = gserver_general_model_service_pb2.Request()
...@@ -413,10 +428,29 @@ class GClient(object): ...@@ -413,10 +428,29 @@ class GClient(object):
feed_batch = feed feed_batch = feed
else: else:
raise Exception("{} not support".format(type(feed))) raise Exception("{} not support".format(type(feed)))
#TODO for feed_data in feed_batch:
inst = gserver_general_model_service_pb2.Inst()
for name, var in feed_data.items():
inst.names.append(name)
itype = self.type_map_[self.feed_types_[name]]
data = np.array(var, dtype=itype)
inst.data.append(data.tobytes())
req.feed_insts.append(inst)
return req
def _unpack_resp(self, resp): def _unpack_resp(self, resp):
pass result_map = {}
inst = resp.fetch_insts[0]
for i, name in enumerate(inst.names):
if name not in self.fetch_names_:
continue
itype = self.type_map_[self.fetch_types_[name]]
result_map[name] = np.frombuffer(inst.data[i], dtype=itype)
result_map[name].shape = np.frombuffer(inst.shape[i], dtype="int32")
if name in self.lod_tensor_set_:
result_map["{}.lod".format(name)] = np.frombuffer(
inst.lod[i], dtype="int32")
return result_map
def predict(self, feed, fetch): def predict(self, feed, fetch):
req = self._pack_feed_data(feed, fetch) req = self._pack_feed_data(feed, fetch)
......
...@@ -26,8 +26,11 @@ import collections ...@@ -26,8 +26,11 @@ import collections
import fcntl import fcntl
import numpy as np import numpy as np
import grpc
import gserver_general_model_service_pb2 import gserver_general_model_service_pb2
import gserver_general_model_service_pb2_grpc import gserver_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures
class OpMaker(object): class OpMaker(object):
...@@ -440,7 +443,8 @@ class GServerService( ...@@ -440,7 +443,8 @@ class GServerService(
from paddle_serving_client import Client from paddle_serving_client import Client
self._parse_model_config(model_config_path) self._parse_model_config(model_config_path)
self.bclient_ = Client() self.bclient_ = Client()
self.bclient_.load_client_config(model_config_path) self.bclient_.load_client_config(
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.connect(endpoints) self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path): def _parse_model_config(self, model_config_path):
...@@ -451,14 +455,24 @@ class GServerService( ...@@ -451,14 +455,24 @@ class GServerService(
str(f.read()), model_conf) str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {} self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_type_ = {} self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float"} self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape self.feed_shapes_[var.alias_name] = var.shape
if self.feed_types_[var.alias_name] == 'float':
self.feed_types_[var.alias_name] = 'float32'
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var): for i, var in enumerate(model_conf.fetch_var):
self.fetch_type_[var.alias_name] = var.fetch_type self.fetch_types_[var.alias_name] = var.fetch_type
if self.fetch_types_[var.alias_name] == 'float':
self.fetch_types_[var.alias_name] = 'float32'
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _unpack_request(self, request): def _unpack_request(self, request):
fetch_names = list(request.fetch_var_names) fetch_names = list(request.fetch_var_names)
...@@ -469,28 +483,35 @@ class GServerService( ...@@ -469,28 +483,35 @@ class GServerService(
data = feed_inst.data[idx] data = feed_inst.data[idx]
itype = self.type_map_[self.feed_types_[name]] itype = self.type_map_[self.feed_types_[name]]
feed_dict[name] = np.frombuffer(data, dtype=itype) feed_dict[name] = np.frombuffer(data, dtype=itype)
feed_batch.append(feed_inst) feed_batch.append(feed_dict)
return feed_batch, fetch_names return feed_batch, fetch_names
def _pack_resp_package(self, result): def _pack_resp_package(self, result, fetch_names):
resp = gserver_general_model_service_pb2.Response() resp = gserver_general_model_service_pb2.Response()
inst = gserver_general_model_service_pb2.Inst() inst = gserver_general_model_service_pb2.Inst()
for name in fetch_names: for name in fetch_names:
inst.name.append(name) inst.names.append(name)
inst.data.append(result[name].tobytes()) inst.data.append(result[name].tobytes())
inst.lod.append(result["{}.lod".format(name)].tobytes()) inst.shape.append(
np.array(
result[name].shape, dtype="int32").tobytes())
if name in self.lod_tensor_set_:
inst.lod.append(result["{}.lod".format(name)].tobytes())
else:
# TODO
inst.lod.append(bytes(0))
resp.fetch_insts.append(inst) resp.fetch_insts.append(inst)
return resp return resp
def inference(self, request, context): def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request) feed_dict, fetch_names = self._unpack_request(request)
data = self.bclient_.predict(feed=feed_dict, fetch=fetch_names) data = self.bclient_.predict(feed=feed_dict, fetch=fetch_names)
return self._pack_resp_package(data) return self._pack_resp_package(data, fetch_names)
class GServer(object): class GServer(object):
def __init__(self, worker_num=2): def __init__(self, worker_num=2):
slef.bserver_ = Server() self.bserver_ = Server()
self.worker_num_ = worker_num self.worker_num_ = worker_num
def set_op_sequence(self, op_seq): def set_op_sequence(self, op_seq):
...@@ -506,8 +527,8 @@ class GServer(object): ...@@ -506,8 +527,8 @@ class GServer(object):
default_port = 12000 default_port = 12000
self.port_list_ = [] self.port_list_ = []
for i in range(1000): for i in range(1000):
if default_port + i != port and self.port_is_available(default_port if default_port + i != port and self._port_is_available(default_port
+ i): + i):
self.port_list_.append(default_port + i) self.port_list_.append(default_port + i)
break break
self.bserver_.prepare_server( self.bserver_.prepare_server(
...@@ -525,17 +546,14 @@ class GServer(object): ...@@ -525,17 +546,14 @@ class GServer(object):
def run_server(self): def run_server(self):
p_bserver = Process( p_bserver = Process(
target=self._launch_rpc_service, args=(slef.bserver_, )) target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start() p_bserver.start()
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_)) futures.ThreadPoolExecutor(max_workers=self.worker_num_))
gserver_general_model_service_pb2_grpc.add_GServerGeneralModelService_to_server( gserver_general_model_service_pb2_grpc.add_GServerGeneralModelServiceServicer_to_server(
GServerService(self.model_config_path_, GServerService(self.model_config_path_,
"0.0.0.0:{}".format(self.port_list_[0])), server) ["0.0.0.0:{}".format(self.port_list_[0])]), server)
server.add_insecure_port('[::]:{}'.format(self.gport_)) server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start() server.start()
try: p_bserver.join()
server.join() server.wait_for_termination()
p_bserver.join()
except KeyboardInterrupt:
server.stop(0)
...@@ -18,6 +18,7 @@ message Inst { ...@@ -18,6 +18,7 @@ message Inst {
repeated bytes data = 1; repeated bytes data = 1;
repeated string names = 2; repeated string names = 2;
repeated bytes lod = 3; repeated bytes lod = 3;
repeated bytes shape = 4;
} }
message Request { message Request {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册