diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index f201eefc449b3aea11db6ae209d79fb6acb05173..4eedcf719f0bba9cbe01c43053b8896deeb095e1 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -372,3 +372,53 @@ class Client(object): def release(self): self.client_handle_.destroy_predictor() self.client_handle_ = None + + +class GClient(object): + def __init__(self): + self.bclient_ = Client() + self.channel_ = None + + def load_client_config(self, path): + pass + + def connect(self, endpoint): + self.channel_ = grpc.insecure_channel(endpoint) + self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub( + self.channel_) + + def _parse_model_config(self, model_config_path): + model_conf = m_config.GeneralModelConfig() + f = open(model_config_path, 'r') + model_conf = google.protobuf.text_format.Merge( + str(f.read()), model_conf) + self.feed_names_ = [var.alias_name for var in model_conf.feed_var] + self.feed_types_ = {} + self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] + self.fetch_type_ = {} + self.type_map_ = {0: "int64", 1: "float"} + for i, var in enumerate(model_conf.feed_var): + self.feed_types_[var.alias_name] = var.feed_type + self.feed_shapes_[var.alias_name] = var.shape + for i, var in enumerate(model_conf.fetch_var): + self.fetch_type_[var.alias_name] = var.fetch_type + + def _pack_feed_data(self, feed, fetch): + req = gserver_general_model_service_pb2.Request() + req.fetch_var_names.extend(fetch) + feed_batch = None + if isinstance(feed, dict): + feed_batch = [feed] + elif isinstance(feed, list): + feed_batch = feed + else: + raise Exception("{} not support".format(type(feed))) + #TODO + + def _unpack_resp(self, resp): + pass + + def predict(self, feed, fetch): + req = self._pack_feed_data(feed, fetch) + resp = self.stub_.inference(req) + return self._unpack_resp(resp) diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py index 7356de2c2feac126272cf9a771a03146a87ef541..f584ec21a7741954952d700aef5751c4b0c4a2ac 100644 --- a/python/paddle_serving_server/__init__.py +++ b/python/paddle_serving_server/__init__.py @@ -25,6 +25,10 @@ from contextlib import closing import collections import fcntl +import numpy as np +import gserver_general_model_service_pb2 +import gserver_general_model_service_pb2_grpc + class OpMaker(object): def __init__(self): @@ -428,3 +432,110 @@ class Server(object): print("Going to Run Command") print(command) os.system(command) + + +class GServerService( + gserver_general_model_service_pb2_grpc.GServerGeneralModelService): + def __init__(self, model_config_path, endpoints): + from paddle_serving_client import Client + self._parse_model_config(model_config_path) + self.bclient_ = Client() + self.bclient_.load_client_config(model_config_path) + self.bclient_.connect(endpoints) + + def _parse_model_config(self, model_config_path): + model_conf = m_config.GeneralModelConfig() + f = open("{}/serving_server_conf.prototxt".format(model_config_path), + 'r') + model_conf = google.protobuf.text_format.Merge( + str(f.read()), model_conf) + self.feed_names_ = [var.alias_name for var in model_conf.feed_var] + self.feed_types_ = {} + self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var] + self.fetch_type_ = {} + self.type_map_ = {0: "int64", 1: "float"} + for i, var in enumerate(model_conf.feed_var): + self.feed_types_[var.alias_name] = var.feed_type + self.feed_shapes_[var.alias_name] = var.shape + for i, var in enumerate(model_conf.fetch_var): + self.fetch_type_[var.alias_name] = var.fetch_type + + def _unpack_request(self, request): + fetch_names = list(request.fetch_var_names) + feed_batch = [] + for feed_inst in request.feed_insts: + feed_dict = {} + for idx, name in enumerate(feed_inst.names): + data = feed_inst.data[idx] + itype = self.type_map_[self.feed_types_[name]] + feed_dict[name] = np.frombuffer(data, dtype=itype) + feed_batch.append(feed_inst) + return feed_batch, fetch_names + + def _pack_resp_package(self, result): + resp = gserver_general_model_service_pb2.Response() + inst = gserver_general_model_service_pb2.Inst() + for name in fetch_names: + inst.name.append(name) + inst.data.append(result[name].tobytes()) + inst.lod.append(result["{}.lod".format(name)].tobytes()) + resp.fetch_insts.append(inst) + return resp + + def inference(self, request, context): + feed_dict, fetch_names = self._unpack_request(request) + data = self.bclient_.predict(feed=feed_dict, fetch=fetch_names) + return self._pack_resp_package(data) + + +class GServer(object): + def __init__(self, worker_num=2): + slef.bserver_ = Server() + self.worker_num_ = worker_num + + def set_op_sequence(self, op_seq): + self.bserver_.set_op_sequence(op_seq) + + def load_model_config(self, model_config_path): + if not isinstance(model_config_path, str): + raise Exception("GServer only supports multi-model temporarily") + self.bserver_.load_model_config(model_config_path) + self.model_config_path_ = model_config_path + + def prepare_server(self, workdir=None, port=9292, device="cpu"): + default_port = 12000 + self.port_list_ = [] + for i in range(1000): + if default_port + i != port and self.port_is_available(default_port + + i): + self.port_list_.append(default_port + i) + break + self.bserver_.prepare_server( + workdir=workdir, port=self.port_list_[0], device=device) + self.gport_ = port + + def _launch_brpc_service(self, bserver): + bserver.run_server() + + def _port_is_available(self, port): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex(('0.0.0.0', port)) + return result != 0 + + def run_server(self): + p_bserver = Process( + target=self._launch_rpc_service, args=(slef.bserver_, )) + p_bserver.start() + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=self.worker_num_)) + gserver_general_model_service_pb2_grpc.add_GServerGeneralModelService_to_server( + GServerService(self.model_config_path_, + "0.0.0.0:{}".format(self.port_list_[0])), server) + server.add_insecure_port('[::]:{}'.format(self.gport_)) + server.start() + try: + server.join() + p_bserver.join() + except KeyboardInterrupt: + server.stop(0) diff --git a/python/paddle_serving_server/gserver_general_model_service.proto b/python/paddle_serving_server/gserver_general_model_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..1366095b4abafcd64602bd51c80275609c598235 --- /dev/null +++ b/python/paddle_serving_server/gserver_general_model_service.proto @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +message Inst { + repeated bytes data = 1; + repeated string names = 2; + repeated bytes lod = 3; +} + +message Request { + repeated Inst feed_insts = 1; + repeated string fetch_var_names = 2; +}; + +message Response { repeated Inst fetch_insts = 1; } + +service GServerGeneralModelService { + rpc inference(Request) returns (Response) {} +};