Commit 4854efb2 authored by barrierye

support future var and update class name

Parent 5f10282a
@@ -13,10 +13,10 @@
 # limitations under the License.
 # pylint: disable=doc-string-missing
-from paddle_serving_client import GClient
+from paddle_serving_client import MultiLangClient
 import sys

-client = GClient()
+client = MultiLangClient()
 client.load_client_config(sys.argv[1])
 client.connect(["127.0.0.1:9393"])
@@ -27,5 +27,6 @@ test_reader = paddle.batch(
     batch_size=1)
 for data in test_reader():
-    fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
+    future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True)
+    fetch_map = future.result()
     print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
@@ -17,7 +17,7 @@ import os
 import sys
 from paddle_serving_server import OpMaker
 from paddle_serving_server import OpSeqMaker
-from paddle_serving_server import GServer
+from paddle_serving_server import MultiLangServer

 op_maker = OpMaker()
 read_op = op_maker.create('general_reader')
@@ -29,7 +29,7 @@ op_seq_maker.add_op(read_op)
 op_seq_maker.add_op(general_infer_op)
 op_seq_maker.add_op(response_op)

-server = GServer()
+server = MultiLangServer()
 server.set_op_sequence(op_seq_maker.get_op_sequence())
 server.load_model_config(sys.argv[1])
 server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
......
@@ -24,8 +24,8 @@ import sys
 from .serving_client import PredictorRes
 import grpc
-import gserver_general_model_service_pb2
-import gserver_general_model_service_pb2_grpc
+import multi_lang_general_model_service_pb2
+import multi_lang_general_model_service_pb2_grpc

 int_type = 0
 float_type = 1
@@ -378,7 +378,7 @@ class Client(object):
         self.client_handle_ = None

-class GClient(object):
+class MultiLangClient(object):
     def __init__(self):
         self.channel_ = None
@@ -389,7 +389,7 @@ class GClient(object):
     def connect(self, endpoint):
         self.channel_ = grpc.insecure_channel(endpoint[0])  #TODO
-        self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub(
+        self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
             self.channel_)

     def _flatten_list(self, nested_list):
@@ -427,7 +427,7 @@ class GClient(object):
                 self.lod_tensor_set_.add(var.alias_name)

     def _pack_feed_data(self, feed, fetch):
-        req = gserver_general_model_service_pb2.Request()
+        req = multi_lang_general_model_service_pb2.Request()
         req.fetch_var_names.extend(fetch)
         req.feed_var_names.extend(feed.keys())
         feed_batch = None
@@ -439,9 +439,9 @@ class GClient(object):
             raise Exception("{} not support".format(type(feed)))
         init_feed_names = False
         for feed_data in feed_batch:
-            inst = gserver_general_model_service_pb2.FeedInst()
+            inst = multi_lang_general_model_service_pb2.FeedInst()
             for name in req.feed_var_names:
-                tensor = gserver_general_model_service_pb2.Tensor()
+                tensor = multi_lang_general_model_service_pb2.Tensor()
                 var = feed_data[name]
                 v_type = self.feed_types_[name]
                 if v_type == 0:  # int64
@@ -466,7 +466,7 @@ class GClient(object):
             req.insts.append(inst)
         return req

-    def _unpack_resp(self, resp, fetch):
+    def _unpack_resp(self, resp, fetch, need_variant_tag):
         result_map = {}
         inst = resp.outputs[0].insts[0]
         tag = resp.tag
@@ -482,10 +482,30 @@ class GClient(object):
                 result_map[name].shape = list(var.shape)
             if name in self.lod_tensor_set_:
                 result_map["{}.lod".format(name)] = np.array(list(var.lod))
-        return result_map, tag
+        return result_map if not need_variant_tag else [result_map, tag]
+
+    def _done_callback_func(self, fetch, need_variant_tag):
+        def unpack_resp(resp):
+            return self._unpack_resp(resp, fetch, need_variant_tag)
+
+        return unpack_resp

-    def predict(self, feed, fetch, need_variant_tag=False):
+    def predict(self, feed, fetch, need_variant_tag=False, asyn=False):
         req = self._pack_feed_data(feed, fetch)
-        resp = self.stub_.inference(req)
-        result_map, tag = self._unpack_resp(resp, fetch)
-        return result_map if not need_variant_tag else [result_map, tag]
+        if not asyn:
+            resp = self.stub_.inference(req)
+            return self._unpack_resp(resp, fetch, need_variant_tag)
+        else:
+            call_future = self.stub_.inference.future(req)
+            return MultiLangPredictFuture(
+                call_future, self._done_callback_func(fetch, need_variant_tag))
+
+
+class MultiLangPredictFuture(object):
+    def __init__(self, call_future, callback_func):
+        self.call_future_ = call_future
+        self.callback_func_ = callback_func
+
+    def result(self):
+        resp = self.call_future_.result()
+        return self.callback_func_(resp)
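`MultiLangPredictFuture` is a thin wrapper around the call future returned by `self.stub_.inference.future(req)`, deferring `_unpack_resp` until `result()` is called. As a sketch only (not part of this commit), the same wrapper could also expose fully non-blocking completion via `add_done_callback`, which grpc call futures provide:

```python
# Hypothetical extension of MultiLangPredictFuture, for illustration only.
class MultiLangPredictFutureWithCallback(object):
    def __init__(self, call_future, callback_func):
        self.call_future_ = call_future      # grpc future for the in-flight RPC
        self.callback_func_ = callback_func  # unpacks a raw Response into a result map

    def result(self):
        # Same behavior as the commit: block, then unpack.
        return self.callback_func_(self.call_future_.result())

    def add_done_callback(self, fn):
        # Non-blocking: fn receives the unpacked result once the RPC finishes.
        def done(grpc_future):
            fn(self.callback_func_(grpc_future.result()))

        self.call_future_.add_done_callback(done)
```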
@@ -27,11 +27,10 @@ import fcntl
 import numpy as np
 import grpc
-import gserver_general_model_service_pb2
-import gserver_general_model_service_pb2_grpc
+import multi_lang_general_model_service_pb2
+import multi_lang_general_model_service_pb2_grpc
 from multiprocessing import Pool, Process
 from concurrent import futures
-import itertools

 class OpMaker(object):
@@ -438,8 +437,8 @@ class Server(object):
         os.system(command)

-class GServerService(
-        gserver_general_model_service_pb2_grpc.GServerGeneralModelService):
+class MultiLangServerService(
+        multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
     def __init__(self, model_config_path, endpoints):
         from paddle_serving_client import Client
         self._parse_model_config(model_config_path)
@@ -505,13 +504,13 @@ class GServerService(
         return feed_batch, fetch_names

     def _pack_resp_package(self, result, fetch_names, tag):
-        resp = gserver_general_model_service_pb2.Response()
+        resp = multi_lang_general_model_service_pb2.Response()
         # Only one model is supported temporarily
-        model_output = gserver_general_model_service_pb2.ModelOutput()
-        inst = gserver_general_model_service_pb2.FetchInst()
+        model_output = multi_lang_general_model_service_pb2.ModelOutput()
+        inst = multi_lang_general_model_service_pb2.FetchInst()
         for idx, name in enumerate(fetch_names):
             # model_output.fetch_var_names.append(name)
-            tensor = gserver_general_model_service_pb2.Tensor()
+            tensor = multi_lang_general_model_service_pb2.Tensor()
             v_type = self.fetch_types_[name]
             if v_type == 0:  # int64
                 tensor.int64_data.extend(
@@ -537,7 +536,7 @@ class GServerService(
         return self._pack_resp_package(data, fetch_names, tag)

-class GServer(object):
+class MultiLangServer(object):
     def __init__(self, worker_num=2):
         self.bserver_ = Server()
         self.worker_num_ = worker_num
@@ -547,7 +546,8 @@ class GServer(object):
     def load_model_config(self, model_config_path):
         if not isinstance(model_config_path, str):
-            raise Exception("GServer only supports multi-model temporarily")
+            raise Exception(
+                "MultiLangServer only supports multi-model temporarily")
         self.bserver_.load_model_config(model_config_path)
         self.model_config_path_ = model_config_path
@@ -578,9 +578,10 @@ class GServer(object):
         p_bserver.start()

         server = grpc.server(
             futures.ThreadPoolExecutor(max_workers=self.worker_num_))
-        gserver_general_model_service_pb2_grpc.add_GServerGeneralModelServiceServicer_to_server(
-            GServerService(self.model_config_path_,
-                           ["0.0.0.0:{}".format(self.port_list_[0])]), server)
+        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
+            MultiLangServerService(self.model_config_path_,
+                                   ["0.0.0.0:{}".format(self.port_list_[0])]),
+            server)
         server.add_insecure_port('[::]:{}'.format(self.gport_))
         server.start()
         p_bserver.join()
......
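The `run_server` hunk above follows the standard grpcio servicer pattern: build a `grpc.server` on a thread pool, register the servicer generated from the proto, bind a port, and serve. A self-contained sketch of that pattern with a stub service (the `Servicer` base-class and registration names follow standard grpc codegen for a service named `MultiLangGeneralModelService`; none of this is PaddleServing code):

```python
from concurrent import futures

import grpc
import multi_lang_general_model_service_pb2 as pb2
import multi_lang_general_model_service_pb2_grpc as pb2_grpc


class StubService(pb2_grpc.MultiLangGeneralModelServiceServicer):
    def inference(self, request, context):
        # A real servicer would run the model and fill outputs and tag;
        # an empty Response is enough to exercise the wiring.
        return pb2.Response()


server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(StubService(), server)
server.add_insecure_port("[::]:9393")
server.start()
server.wait_for_termination()  # requires grpcio >= 1.24
```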
@@ -44,6 +44,6 @@ message ModelOutput {
   optional string engine_name = 2;
 }

-service GServerGeneralModelService {
+service MultiLangGeneralModelService {
   rpc inference(Request) returns (Response) {}
 };
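The `multi_lang_general_model_service_pb2*` modules imported in the Python hunks are generated from this proto. Assuming the file is named `multi_lang_general_model_service.proto` (inferred from the module names; the actual path is not shown in this diff), the bindings can be regenerated with the grpcio-tools package:

```python
# Regenerate the message and stub modules; requires grpcio-tools.
from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "-I.",
    "--python_out=.",
    "--grpc_python_out=.",
    "multi_lang_general_model_service.proto",  # inferred filename
])
```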