Commit ba7d67f8 authored by gongweibao

add get_config test=develop

Parent fbd1ecd9
......@@ -43,8 +43,17 @@ message Response {
message ModelOutput {
  repeated FetchInst insts = 1;
  optional string engine_name = 2;
};

message EmptyRequest {};

message ServingConfig {
  optional int32 max_batch_size = 1;
  optional string proto_txt = 2;
};

service MultiLangGeneralModelService {
  rpc inference(Request) returns (Response) {}
  rpc get_config(EmptyRequest) returns (ServingConfig) {}
};
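The new get_config RPC can be exercised directly against the generated stubs. A minimal sketch, assuming the server listens on 127.0.0.1:9393 and that the generated modules live under paddle_serving_client.proto (both endpoint and package path are assumptions, not part of this change):

import grpc
# Assumed package path for the modules generated from the proto above.
from paddle_serving_client.proto import multi_lang_general_model_service_pb2 as pb2
from paddle_serving_client.proto import multi_lang_general_model_service_pb2_grpc as pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:9393")  # example endpoint
stub = pb2_grpc.MultiLangGeneralModelServiceStub(channel)
config = stub.get_config(pb2.EmptyRequest())
print(config.max_batch_size)  # -1 when PADDLE_SERVING_MAX_BATCH_SIZE is unset
print(config.proto_txt)       # contents of serving_server_conf.prototxt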
......@@ -26,7 +26,7 @@ import grpc
from .proto import multi_lang_general_model_service_pb2
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
from .proto import multi_lang_general_model_service_pb2_grpc as grpc_pb2
int_type = 0
float_type = 1
......@@ -384,17 +384,29 @@ class Client(object):
class MultiLangClient(object):
    def __init__(self):
        self.channel_ = None
        self._config = None

    def load_client_config(self, path):
        if not isinstance(path, str):
            raise Exception("GClient only supports multi-model temporarily")
        self._parse_model_config(path)
        with open(path, 'r') as f:
            proto_txt = str(f.read())
        self._parse_model_config(proto_txt)
    def connect(self, endpoint):

    def _load_client_config(self):
        req = multi_lang_general_model_service_pb2.EmptyRequest()
        self._config = self.stub_.get_config(req)
        self._parse_model_config(self._config.proto_txt)

    def connect(self, endpoint, use_remote_config=True):
        self.channel_ = grpc.insecure_channel(endpoint[0])  # TODO
        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
        self.stub_ = grpc_pb2.MultiLangGeneralModelServiceStub(
            self.channel_)
        if use_remote_config:
            self._load_client_config()
    def _flatten_list(self, nested_list):
        for item in nested_list:
            if isinstance(item, (list, tuple)):
......@@ -403,11 +415,9 @@ class MultiLangClient(object):
            else:
                yield item
    def _parse_model_config(self, model_config_path):
    def _parse_model_config(self, proto_txt):
        model_conf = m_config.GeneralModelConfig()
        f = open(model_config_path, 'r')
        model_conf = google.protobuf.text_format.Merge(
            str(f.read()), model_conf)
        model_conf = google.protobuf.text_format.Merge(proto_txt, model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
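_parse_model_config now consumes the prototxt as an in-memory string, so both a local file and the get_config response can feed it. A minimal sketch of the text_format round trip it performs; the module path for GeneralModelConfig (imported as m_config in this file) and the file name are assumptions:

import google.protobuf.text_format
# Assumed location of the GeneralModelConfig message.
from paddle_serving_client.proto import general_model_config_pb2 as m_config

with open("serving_client_conf.prototxt", 'r') as f:
    proto_txt = str(f.read())
model_conf = m_config.GeneralModelConfig()
google.protobuf.text_format.Merge(proto_txt, model_conf)
print([var.alias_name for var in model_conf.feed_var])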
......@@ -539,7 +549,6 @@ class MultiLangClient(object):
            is_python=is_python,
            need_variant_tag=need_variant_tag))

class MultiLangPredictFuture(object):
    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future
......
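Taken together, the client changes let connect() pull the model configuration over the wire instead of requiring a local prototxt. A hedged usage sketch; the endpoint and file name are examples:

from paddle_serving_client import MultiLangClient

client = MultiLangClient()
# New default: fetch serving_server_conf.prototxt from the server via get_config.
client.connect(["127.0.0.1:9393"])

# The old behavior is still available by opting out and loading a local file:
# client = MultiLangClient()
# client.load_client_config("serving_client_conf.prototxt")
# client.connect(["127.0.0.1:9393"], use_remote_config=False)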
......@@ -27,11 +27,11 @@ import fcntl
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
from .proto import multi_lang_general_model_service_pb2 as pb2
import sys
sys.path.append(
    os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
from .proto import multi_lang_general_model_service_pb2_grpc
from .proto import multi_lang_general_model_service_pb2_grpc as pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures
......@@ -441,21 +441,25 @@ class Server(object):
class MultiLangServerService(
        multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
        pb2_grpc.MultiLangGeneralModelService):
    def __init__(self, model_config_path, endpoints):
        from paddle_serving_client import Client
        self._parse_model_config(model_config_path)
        path = "{}/serving_server_conf.prototxt".format(model_config_path)
        with open(path, 'r') as f:
            proto_txt = str(f.read())
        self._parse_model_config(proto_txt)
        self.bclient_ = Client()
        self.bclient_.load_client_config(
            "{}/serving_server_conf.prototxt".format(model_config_path))
        self.bclient_.load_client_config(path)
        self.bclient_.connect(endpoints)
    def _parse_model_config(self, model_config_path):
        self._max_batch_size = -1  # <= 0 means no limit
        self._proto_txt = proto_txt

    def _parse_model_config(self, proto_txt):
        model_conf = m_config.GeneralModelConfig()
        f = open("{}/serving_server_conf.prototxt".format(model_config_path),
                 'r')
        model_conf = google.protobuf.text_format.Merge(
            str(f.read()), model_conf)
        model_conf = google.protobuf.text_format.Merge(proto_txt, model_conf)
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.feed_types_ = {}
        self.feed_shapes_ = {}
......@@ -511,12 +515,12 @@ class MultiLangServerService(
        return feed_batch, fetch_names, is_python

    def _pack_resp_package(self, result, fetch_names, is_python, tag):
        resp = multi_lang_general_model_service_pb2.Response()
        resp = pb2.Response()
        # Only one model is supported temporarily
        model_output = multi_lang_general_model_service_pb2.ModelOutput()
        inst = multi_lang_general_model_service_pb2.FetchInst()
        model_output = pb2.ModelOutput()
        inst = pb2.FetchInst()
        for idx, name in enumerate(fetch_names):
            tensor = multi_lang_general_model_service_pb2.Tensor()
            tensor = pb2.Tensor()
            v_type = self.fetch_types_[name]
            if is_python:
                tensor.data = result[name].tobytes()
......@@ -542,6 +546,22 @@ class MultiLangServerService(
            feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
        return self._pack_resp_package(data, fetch_names, is_python, tag)
    def get_config(self, request, context):
        key = "PADDLE_SERVING_MAX_BATCH_SIZE"
        max_batch_size = os.getenv(key)
        if max_batch_size:
            try:
                self._max_batch_size = int(max_batch_size)
            except ValueError:
                print("invalid value {} for {}".format(max_batch_size, key))
        response = pb2.ServingConfig()
        response.proto_txt = self._proto_txt
        response.max_batch_size = self._max_batch_size
        return response
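For clarity, a hedged sketch of how the environment variable feeds the handler above; the value 32 is an example, not a project default:

import os

# Set before launching the server process (or before the next get_config call,
# since the handler re-reads the variable on every request):
os.environ["PADDLE_SERVING_MAX_BATCH_SIZE"] = "32"
# Clients calling get_config now receive max_batch_size == 32; if the variable
# is unset or not an integer, the default -1 (no limit) is returned instead.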
class MultiLangServer(object):
    def __init__(self, worker_num=2):
......@@ -585,7 +605,7 @@ class MultiLangServer(object):
        p_bserver.start()
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=self.worker_num_))
        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
        pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
            MultiLangServerService(self.model_config_path_,
                                   ["0.0.0.0:{}".format(self.port_list_[0])]),
            server)
......