提交 1458203e 编写于 作者: B barrierye

support future var and update class name

上级 b7bd963d
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
from paddle_serving_client import GClient from paddle_serving_client import MultiLangClient
import sys import sys
client = GClient() client = MultiLangClient()
client.load_client_config(sys.argv[1]) client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9393"]) client.connect(["127.0.0.1:9393"])
...@@ -27,5 +27,6 @@ test_reader = paddle.batch( ...@@ -27,5 +27,6 @@ test_reader = paddle.batch(
batch_size=1) batch_size=1)
for data in test_reader(): for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"]) future = client.predict(feed={"x": data[0][0]}, fetch=["price"], asyn=True)
fetch_map = future.result()
print("{} {}".format(fetch_map["price"][0], data[0][1][0])) print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import sys import sys
from paddle_serving_server import OpMaker from paddle_serving_server import OpMaker
from paddle_serving_server import OpSeqMaker from paddle_serving_server import OpSeqMaker
from paddle_serving_server import GServer from paddle_serving_server import MultiLangServer
op_maker = OpMaker() op_maker = OpMaker()
read_op = op_maker.create('general_reader') read_op = op_maker.create('general_reader')
...@@ -29,7 +29,7 @@ op_seq_maker.add_op(read_op) ...@@ -29,7 +29,7 @@ op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op) op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(response_op) op_seq_maker.add_op(response_op)
server = GServer() server = MultiLangServer()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(sys.argv[1]) server.load_model_config(sys.argv[1])
server.prepare_server(workdir="work_dir1", port=9393, device="cpu") server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
......
...@@ -24,8 +24,8 @@ import sys ...@@ -24,8 +24,8 @@ import sys
from .serving_client import PredictorRes from .serving_client import PredictorRes
import grpc import grpc
import gserver_general_model_service_pb2 import multi_lang_general_model_service_pb2
import gserver_general_model_service_pb2_grpc import multi_lang_general_model_service_pb2_grpc
int_type = 0 int_type = 0
float_type = 1 float_type = 1
...@@ -378,7 +378,7 @@ class Client(object): ...@@ -378,7 +378,7 @@ class Client(object):
self.client_handle_ = None self.client_handle_ = None
class GClient(object): class MultiLangClient(object):
def __init__(self): def __init__(self):
self.channel_ = None self.channel_ = None
...@@ -389,7 +389,7 @@ class GClient(object): ...@@ -389,7 +389,7 @@ class GClient(object):
def connect(self, endpoint): def connect(self, endpoint):
self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO self.channel_ = grpc.insecure_channel(endpoint[0]) #TODO
self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub( self.stub_ = multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceStub(
self.channel_) self.channel_)
def _flatten_list(self, nested_list): def _flatten_list(self, nested_list):
...@@ -427,7 +427,7 @@ class GClient(object): ...@@ -427,7 +427,7 @@ class GClient(object):
self.lod_tensor_set_.add(var.alias_name) self.lod_tensor_set_.add(var.alias_name)
def _pack_feed_data(self, feed, fetch): def _pack_feed_data(self, feed, fetch):
req = gserver_general_model_service_pb2.Request() req = multi_lang_general_model_service_pb2.Request()
req.fetch_var_names.extend(fetch) req.fetch_var_names.extend(fetch)
req.feed_var_names.extend(feed.keys()) req.feed_var_names.extend(feed.keys())
feed_batch = None feed_batch = None
...@@ -439,9 +439,9 @@ class GClient(object): ...@@ -439,9 +439,9 @@ class GClient(object):
raise Exception("{} not support".format(type(feed))) raise Exception("{} not support".format(type(feed)))
init_feed_names = False init_feed_names = False
for feed_data in feed_batch: for feed_data in feed_batch:
inst = gserver_general_model_service_pb2.FeedInst() inst = multi_lang_general_model_service_pb2.FeedInst()
for name in req.feed_var_names: for name in req.feed_var_names:
tensor = gserver_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
var = feed_data[name] var = feed_data[name]
v_type = self.feed_types_[name] v_type = self.feed_types_[name]
if v_type == 0: # int64 if v_type == 0: # int64
...@@ -466,7 +466,7 @@ class GClient(object): ...@@ -466,7 +466,7 @@ class GClient(object):
req.insts.append(inst) req.insts.append(inst)
return req return req
def _unpack_resp(self, resp, fetch): def _unpack_resp(self, resp, fetch, need_variant_tag):
result_map = {} result_map = {}
inst = resp.outputs[0].insts[0] inst = resp.outputs[0].insts[0]
tag = resp.tag tag = resp.tag
...@@ -482,10 +482,30 @@ class GClient(object): ...@@ -482,10 +482,30 @@ class GClient(object):
result_map[name].shape = list(var.shape) result_map[name].shape = list(var.shape)
if name in self.lod_tensor_set_: if name in self.lod_tensor_set_:
result_map["{}.lod".format(name)] = np.array(list(var.lod)) result_map["{}.lod".format(name)] = np.array(list(var.lod))
return result_map, tag return result_map if not need_variant_tag else [result_map, tag]
def _done_callback_func(self, fetch, need_variant_tag):
def unpack_resp(resp):
return self._unpack_resp(resp, fetch, need_variant_tag)
def predict(self, feed, fetch, need_variant_tag=False): return unpack_resp
def predict(self, feed, fetch, need_variant_tag=False, asyn=False):
req = self._pack_feed_data(feed, fetch) req = self._pack_feed_data(feed, fetch)
resp = self.stub_.inference(req) if not asyn:
result_map, tag = self._unpack_resp(resp, fetch) resp = self.stub_.inference(req)
return result_map if not need_variant_tag else [result_map, tag] return self._unpack_resp(resp, fetch, need_variant_tag)
else:
call_future = self.stub_.inference.future(req)
return MultiLangPredictFuture(
call_future, self._done_callback_func(fetch, need_variant_tag))
class MultiLangPredictFuture(object):
def __init__(self, call_future, callback_func):
self.call_future_ = call_future
self.callback_func_ = callback_func
def result(self):
resp = self.call_future_.result()
return self.callback_func_(resp)
...@@ -27,11 +27,10 @@ import fcntl ...@@ -27,11 +27,10 @@ import fcntl
import numpy as np import numpy as np
import grpc import grpc
import gserver_general_model_service_pb2 import multi_lang_general_model_service_pb2
import gserver_general_model_service_pb2_grpc import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from concurrent import futures from concurrent import futures
import itertools
class OpMaker(object): class OpMaker(object):
...@@ -438,8 +437,8 @@ class Server(object): ...@@ -438,8 +437,8 @@ class Server(object):
os.system(command) os.system(command)
class GServerService( class MultiLangServerService(
gserver_general_model_service_pb2_grpc.GServerGeneralModelService): multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelService):
def __init__(self, model_config_path, endpoints): def __init__(self, model_config_path, endpoints):
from paddle_serving_client import Client from paddle_serving_client import Client
self._parse_model_config(model_config_path) self._parse_model_config(model_config_path)
...@@ -505,13 +504,13 @@ class GServerService( ...@@ -505,13 +504,13 @@ class GServerService(
return feed_batch, fetch_names return feed_batch, fetch_names
def _pack_resp_package(self, result, fetch_names, tag): def _pack_resp_package(self, result, fetch_names, tag):
resp = gserver_general_model_service_pb2.Response() resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily # Only one model is supported temporarily
model_output = gserver_general_model_service_pb2.ModelOutput() model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = gserver_general_model_service_pb2.FetchInst() inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names): for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name) # model_output.fetch_var_names.append(name)
tensor = gserver_general_model_service_pb2.Tensor() tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name] v_type = self.fetch_types_[name]
if v_type == 0: # int64 if v_type == 0: # int64
tensor.int64_data.extend( tensor.int64_data.extend(
...@@ -537,7 +536,7 @@ class GServerService( ...@@ -537,7 +536,7 @@ class GServerService(
return self._pack_resp_package(data, fetch_names, tag) return self._pack_resp_package(data, fetch_names, tag)
class GServer(object): class MultiLangServer(object):
def __init__(self, worker_num=2): def __init__(self, worker_num=2):
self.bserver_ = Server() self.bserver_ = Server()
self.worker_num_ = worker_num self.worker_num_ = worker_num
...@@ -547,7 +546,8 @@ class GServer(object): ...@@ -547,7 +546,8 @@ class GServer(object):
def load_model_config(self, model_config_path): def load_model_config(self, model_config_path):
if not isinstance(model_config_path, str): if not isinstance(model_config_path, str):
raise Exception("GServer only supports multi-model temporarily") raise Exception(
"MultiLangServer only supports multi-model temporarily")
self.bserver_.load_model_config(model_config_path) self.bserver_.load_model_config(model_config_path)
self.model_config_path_ = model_config_path self.model_config_path_ = model_config_path
...@@ -578,9 +578,10 @@ class GServer(object): ...@@ -578,9 +578,10 @@ class GServer(object):
p_bserver.start() p_bserver.start()
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_)) futures.ThreadPoolExecutor(max_workers=self.worker_num_))
gserver_general_model_service_pb2_grpc.add_GServerGeneralModelServiceServicer_to_server( multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
GServerService(self.model_config_path_, MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]), server) ["0.0.0.0:{}".format(self.port_list_[0])]),
server)
server.add_insecure_port('[::]:{}'.format(self.gport_)) server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start() server.start()
p_bserver.join() p_bserver.join()
......
...@@ -44,6 +44,6 @@ message ModelOutput { ...@@ -44,6 +44,6 @@ message ModelOutput {
optional string engine_name = 2; optional string engine_name = 2;
} }
service GServerGeneralModelService { service MultiLangGeneralModelService {
rpc inference(Request) returns (Response) {} rpc inference(Request) returns (Response) {}
}; };
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册