提交 ba7d67f8 编写于 作者: G gongweibao

add get_config test=develop

上级 fbd1ecd9
...@@ -43,8 +43,17 @@ message Response { ...@@ -43,8 +43,17 @@ message Response {
message ModelOutput { message ModelOutput {
repeated FetchInst insts = 1; repeated FetchInst insts = 1;
optional string engine_name = 2; optional string engine_name = 2;
};
message EmptyRequest{
};
message ServingConfig{
int32 max_batch_size = 1;
string proto_txt = 2;
} }
service MultiLangGeneralModelService { service MultiLangGeneralModelService {
rpc inference(Request) returns (Response) {} rpc inference(Request) returns (Response) {}
rpc get_config(EmptyRequest) returns (ServingConfig) {}
}; };
...@@ -26,7 +26,7 @@ import grpc ...@@ -26,7 +26,7 @@ import grpc
from .proto import multi_lang_general_model_service_pb2 from .proto import multi_lang_general_model_service_pb2
sys.path.append( sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto')) os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc from .proto import grpc_pb2 as grpc_pb2
int_type = 0 int_type = 0
float_type = 1 float_type = 1
...@@ -384,17 +384,29 @@ class Client(object): ...@@ -384,17 +384,29 @@ class Client(object):
class MultiLangClient(object): class MultiLangClient(object):
def __init__(self): def __init__(self):
self.channel_ = None self.channel_ = None
self._config = None
def load_client_config(self, path): def load_client_config(self, path):
if not isinstance(path, str): if not isinstance(path, str):
raise Exception("GClient only supports multi-model temporarily") raise Exception("GClient only supports multi-model temporarily")
self._parse_model_config(path) with open(path, 'r') as f:
proto_txt = str(f.read())
self._parse_model_config(proto_txt)
def connect(self, endpoint): def _load_client_config(self, stub):
req= grpc_pb2.ServingConfig()
self._config = self.stub_.get_client_proto_text(req)
self._parse_model_config(config.proto_txt)
def connect(self, endpoint, use_remote_config=True):
self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO
self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub( self.stub_ = grpc_pb2.MultiLangGeneralModelServiceStub(
self.channel_) self.channel_)
if use_remote_config:
self._load_client_config(stub)
def _flatten_list(self, nested_list): def _flatten_list(self, nested_list):
for item in nested_list: for item in nested_list:
if isinstance(item, (list, tuple)): if isinstance(item, (list, tuple)):
...@@ -403,11 +415,9 @@ class MultiLangClient(object): ...@@ -403,11 +415,9 @@ class MultiLangClient(object):
else: else:
yield item yield item
def _parse_model_config(self, model_config_path): def _parse_model_config(self, proto_txt):
model_conf = m_config.GeneralModelConfig() model_conf = m_config.GeneralModelConfig()
f = open(model_config_path, 'r') model_conf = google.protobuf.text_format.Merge(proto_txt, model_conf)
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {} self.feed_types_ = {}
self.feed_shapes_ = {} self.feed_shapes_ = {}
...@@ -539,7 +549,6 @@ class MultiLangClient(object): ...@@ -539,7 +549,6 @@ class MultiLangClient(object):
is_python=is_python, is_python=is_python,
need_variant_tag=need_variant_tag)) need_variant_tag=need_variant_tag))
class MultiLangPredictFuture(object): class MultiLangPredictFuture(object):
def __init__(self, call_future, callback_func): def __init__(self, call_future, callback_func):
self.call_future_ = call_future self.call_future_ = call_future
......
...@@ -27,11 +27,11 @@ import fcntl ...@@ -27,11 +27,11 @@ import fcntl
import numpy as np import numpy as np
import grpc import grpc
from .proto import multi_lang_general_model_service_pb2 from .proto import multi_lang_general_model_service_pb2 as pb2
import sys import sys
sys.path.append( sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto')) os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc from .proto import multi_lang_general_model_service_pb2_grpc as grpc_pb2
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from concurrent import futures from concurrent import futures
...@@ -441,21 +441,25 @@ class Server(object): ...@@ -441,21 +441,25 @@ class Server(object):
class MultiLangServerService( class MultiLangServerService(
multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService): pb2_grpc.MultiLangGeneralModelService):
def __init__(self, model_config_path, endpoints): def __init__(self, model_config_path, endpoints):
from paddle_serving_client import Client from paddle_serving_client import Client
self._parse_model_config(model_config_path)
path = "{}/serving_server_conf.prototxt".format(model_config_path)
with open(path, 'r') as f:
proto_txt = str(f.read())
self._parse_model_config(proto_txt)
self.bclient_ = Client() self.bclient_ = Client()
self.bclient_.load_client_config( self.bclient_.load_client_config(path)
"{}/serving_server_conf.prototxt".format(model_config_path))
self.bclient_.connect(endpoints) self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path): self._max_batch_size = -1 # <=0:infinite
self._proto_txt = proto_txt
def _parse_model_config(self, proto_txt):
model_conf = m_config.GeneralModelConfig() model_conf = m_config.GeneralModelConfig()
f = open("{}/serving_server_conf.prototxt".format(model_config_path), model_conf = google.protobuf.text_format.Merge(proto_txt), model_conf)
'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {} self.feed_types_ = {}
self.feed_shapes_ = {} self.feed_shapes_ = {}
...@@ -511,12 +515,12 @@ class MultiLangServerService( ...@@ -511,12 +515,12 @@ class MultiLangServerService(
return feed_batch, fetch_names, is_python return feed_batch, fetch_names, is_python
def _pack_resp_package(self, result, fetch_names, is_python, tag): def _pack_resp_package(self, result, fetch_names, is_python, tag):
resp = multi_lang_general_model_service_pb2.Response() resp = pb2.Response()
# Only one model is supported temporarily # Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput() model_output = pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst() inst = pb2.FetchInst()
for idx, name in enumerate(fetch_names): for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor() tensor = pb2.Tensor()
v_type = self.fetch_types_[name] v_type = self.fetch_types_[name]
if is_python: if is_python:
tensor.data = result[name].tobytes() tensor.data = result[name].tobytes()
...@@ -542,6 +546,22 @@ class MultiLangServerService( ...@@ -542,6 +546,22 @@ class MultiLangServerService(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True) feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, is_python, tag) return self._pack_resp_package(data, fetch_names, is_python, tag)
def get_config(self, request, context):
key = "PADDLE_SERVING_MAX_BATCH_SIZE"
max_batch_size = os.getenv(key)
if max_batch_size:
try:
max_batch_size=int(max_batch_size)
self._max_batch_size = max_batch_size
except Exception as e:
print("invalid value:{} of {}".format(max_batch_size, key))
response = pb2.ServingConfig()
response.proto_txt = self.proto_txt
response.max_batch_size = self._max_batch_size
return response
class MultiLangServer(object): class MultiLangServer(object):
def __init__(self, worker_num=2): def __init__(self, worker_num=2):
...@@ -585,7 +605,7 @@ class MultiLangServer(object): ...@@ -585,7 +605,7 @@ class MultiLangServer(object):
p_bserver.start() p_bserver.start()
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_)) futures.ThreadPoolExecutor(max_workers=self.worker_num_))
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server( pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(self.model_config_path_, MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]), ["0.0.0.0:{}".format(self.port_list_[0])]),
server) server)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册