Commit f88fc43f authored by wangjiawei04

implement grpc + local predictor

Parent 58a50288
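As context for the diff below, here is a rough end-to-end sketch of how the new local-predictor path is meant to be driven. The model directory, port, and the command-line invocation are illustrative assumptions, not part of this commit.

    from paddle_serving_server import MultiLangServer

    # Start the multi-language gRPC server backed by the in-process local
    # predictor instead of a separate brpc worker process.
    server = MultiLangServer(use_local_predictor=True)
    server.load_model_config("uci_housing_model")  # assumed model directory
    server.prepare_server(workdir="workdir", port=9393, device="cpu")
    server.run_server()

    # Roughly equivalent CLI form (assumed invocation):
    #   python -m paddle_serving_server.serve --model uci_housing_model \
    #       --port 9393 --use_multilang --local_predict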
@@ -76,7 +76,7 @@ class Debugger(object):
         config.switch_use_feed_fetch_ops(False)
         self.predictor = create_paddle_predictor(config)
 
-    def predict(self, feed=None, fetch=None):
+    def predict(self, feed=None, fetch=None, batch=True):
         if feed is None or fetch is None:
             raise ValueError("You should specify feed and fetch for prediction")
         fetch_list = []
@@ -116,15 +116,26 @@ class Debugger(object):
         input_names = self.predictor.get_input_names()
         for name in input_names:
+            print(feed)
             if isinstance(feed[name], list):
                 feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                     name])
             if self.feed_types_[name] == 0:
                 feed[name] = feed[name].astype("int64")
-            else:
+            elif self.feed_types_[name] == 1:
                 feed[name] = feed[name].astype("float32")
+            elif self.feed_types_[name] == 2:
+                feed[name] = feed[name].astype("int32")
+            else:
+                raise ValueError("local predictor receives wrong data type")
             input_tensor = self.predictor.get_input_tensor(name)
-            input_tensor.copy_from_cpu(feed[name])
+            # TODO: set lods
+            if "{}.lod".format(name) in feed:
+                input_tensor.set_lod(feed["{}.lod".format(name)])
+            if batch == True:
+                input_tensor.copy_from_cpu(feed[name][np.newaxis, :])
+            else:
+                input_tensor.copy_from_cpu(feed[name])
 
         output_tensors = []
         output_names = self.predictor.get_output_names()
         for output_name in output_names:
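A rough sketch of how a caller might exercise the new batch flag and the "<name>.lod" feed convention handled above; the model directory, feed/fetch names, and shapes are illustrative assumptions.

    import numpy as np
    from paddle_serving_app.local_predict import Debugger

    predictor = Debugger()
    predictor.load_model_config("uci_housing_model", profile=False)  # assumed dir

    # Single sample: with batch=True the predictor inserts the batch axis itself
    # (feed[name][np.newaxis, :]) before copy_from_cpu.
    feed = {"x": np.random.rand(13).astype("float32")}  # assumed feed name/shape
    fetch_map = predictor.predict(feed=feed, fetch=["price"], batch=True)

    # Variable-length input: a "<name>.lod" entry in the feed dict is forwarded
    # to input_tensor.set_lod, e.g. feed["words.lod"] = [0, seq_len].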
......
@@ -453,14 +453,18 @@ class Server(object):
 
 class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                                      MultiLangGeneralModelServiceServicer):
-    def __init__(self, model_config_path, is_multi_model, endpoints):
-        self.is_multi_model_ = is_multi_model
+    def __init__(self, model_config_path, is_multi_model, endpoints, local_predictor=None):
         self.model_config_path_ = model_config_path
-        self.endpoints_ = endpoints
         with open(self.model_config_path_) as f:
             self.model_config_str_ = str(f.read())
         self._parse_model_config(self.model_config_str_)
-        self._init_bclient(self.model_config_path_, self.endpoints_)
+        self.is_multi_model_ = is_multi_model
+        if local_predictor == None:
+            self.local_predictor = None
+            self.endpoints_ = endpoints
+            self._init_bclient(self.model_config_path_, self.endpoints_)
+        else:
+            self.local_predictor = local_predictor
 
     def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
         from paddle_serving_client import Client
@@ -585,8 +589,12 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
     def Inference(self, request, context):
         feed_dict, fetch_names, is_python = self._unpack_inference_request(
             request)
-        ret = self.bclient_.predict(
-            feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+        if self.local_predictor == None:
+            ret = self.bclient_.predict(
+                feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+        else:
+            ret = [self.local_predictor.predict(
+                feed=feed_dict[0], fetch=fetch_names), "VariantTagNeeded"]
         return self._pack_inference_response(ret, fetch_names, is_python)
 
     def GetClientConfig(self, request, context):
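On the client side the gRPC protocol is unchanged, so the same multi-language client should work against either backend. A hedged sketch assuming the usual MultiLangClient usage from paddle_serving_client, with illustrative feed/fetch names:

    import numpy as np
    from paddle_serving_client import MultiLangClient

    client = MultiLangClient()
    client.connect(["127.0.0.1:9393"])  # assumed server endpoint

    fetch_map = client.predict(
        feed={"x": np.random.rand(13).astype("float32")},  # assumed feed name
        fetch=["price"])                                    # assumed fetch name
    print(fetch_map)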
@@ -596,8 +604,14 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
 
 class MultiLangServer(object):
-    def __init__(self):
-        self.bserver_ = Server()
+    def __init__(self, use_local_predictor=False):
+        self.use_local_predictor = use_local_predictor
+        if use_local_predictor:
+            from paddle_serving_app.local_predict import Debugger
+            self.local_predictor_ = Debugger()
+        else:
+            self.bserver_ = Server()
+            self.local_predictor_ = None
         self.worker_num_ = 4
         self.body_size_ = 64 * 1024 * 1024
         self.concurrency_ = 100000
@@ -620,6 +634,9 @@ class MultiLangServer(object):
             "max_body_size is less than default value, will use default value in service."
         )
 
+    def set_use_local_predictor(self, mode):
+        self.use_local_predictor = mode
+
     def set_port(self, port):
         self.gport_ = port
@@ -645,16 +662,23 @@ class MultiLangServer(object):
         self.bserver_.use_mkl(flag)
 
     def load_model_config(self, server_config_paths, client_config_path=None):
-        self.bserver_.load_model_config(server_config_paths)
-        if client_config_path is None:
-            if isinstance(server_config_paths, dict):
-                self.is_multi_model_ = True
-                client_config_path = '{}/serving_server_conf.prototxt'.format(
-                    list(server_config_paths.items())[0][1])
-            else:
-                client_config_path = '{}/serving_server_conf.prototxt'.format(
-                    server_config_paths)
-        self.bclient_config_path_ = client_config_path
+        if self.use_local_predictor == False:
+            self.bserver_.load_model_config(server_config_paths)
+            if client_config_path is None:
+                if isinstance(server_config_paths, dict):
+                    self.is_multi_model_ = True
+                    client_config_path = '{}/serving_server_conf.prototxt'.format(
+                        list(server_config_paths.items())[0][1])
+                else:
+                    client_config_path = '{}/serving_server_conf.prototxt'.format(
+                        server_config_paths)
+            self.bclient_config_path_ = client_config_path
+        else:
+            if isinstance(server_config_paths, dict):
+                raise ValueError("local predictor does not support model ensemble")
+            client_config_path = '{}/serving_server_conf.prototxt'.format(server_config_paths)
+            self.local_predictor_.load_model_config(server_config_paths, profile=False)
+            self.local_config_path_ = client_config_path
 
     def prepare_server(self,
                        workdir=None,
@@ -662,19 +686,20 @@ class MultiLangServer(object):
                        device="cpu",
                        cube_conf=None):
         if not self._port_is_available(port):
-            raise SystemExit("Prot {} is already used".format(port))
-        default_port = 12000
-        self.port_list_ = []
-        for i in range(1000):
-            if default_port + i != port and self._port_is_available(default_port
-                                                                    + i):
-                self.port_list_.append(default_port + i)
-                break
-        self.bserver_.prepare_server(
-            workdir=workdir,
-            port=self.port_list_[0],
-            device=device,
-            cube_conf=cube_conf)
+            raise SystemExit("Port {} is already used".format(port))
+        if self.use_local_predictor == False:
+            default_port = 12000
+            self.port_list_ = []
+            for i in range(1000):
+                if default_port + i != port and self._port_is_available(default_port
+                                                                        + i):
+                    self.port_list_.append(default_port + i)
+                    break
+            self.bserver_.prepare_server(
+                workdir=workdir,
+                port=self.port_list_[0],
+                device=device,
+                cube_conf=cube_conf)
         self.set_port(port)
 
     def _launch_brpc_service(self, bserver):
@@ -687,20 +712,29 @@ class MultiLangServer(object):
         return result != 0
 
     def run_server(self):
-        p_bserver = Process(
-            target=self._launch_brpc_service, args=(self.bserver_, ))
-        p_bserver.start()
+        if self.use_local_predictor is False:
+            print("brpc server process start")
+            p_bserver = Process(
+                target=self._launch_brpc_service, args=(self.bserver_, ))
+            p_bserver.start()
         options = [('grpc.max_send_message_length', self.body_size_),
                    ('grpc.max_receive_message_length', self.body_size_)]
         server = grpc.server(
             futures.ThreadPoolExecutor(max_workers=self.worker_num_),
             options=options,
             maximum_concurrent_rpcs=self.concurrency_)
-        multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
-            MultiLangServerServiceServicer(
-                self.bclient_config_path_, self.is_multi_model_,
-                ["0.0.0.0:{}".format(self.port_list_[0])]), server)
+        if self.use_local_predictor is False:
+            multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
+                MultiLangServerServiceServicer(
+                    self.bclient_config_path_, self.is_multi_model_,
+                    ["0.0.0.0:{}".format(self.port_list_[0])], None), server)
+        else:
+            multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
+                MultiLangServerServiceServicer(
+                    self.local_config_path_, None, None, self.local_predictor_), server)
         server.add_insecure_port('[::]:{}'.format(self.gport_))
         server.start()
-        p_bserver.join()
+        if self.use_local_predictor is False:
+            p_bserver.join()
         server.wait_for_termination()
@@ -58,6 +58,11 @@ def parse_args():  # pylint: disable=doc-string-missing
         default=False,
         action="store_true",
         help="Use Multi-language-service")
+    parser.add_argument(
+        "--local_predict",
+        default=False,
+        action="store_true",
+        help="Use Local Predictor")
     return parser.parse_args()
@@ -73,6 +78,7 @@ def start_standard_model():  # pylint: disable=doc-string-missing
     max_body_size = args.max_body_size
     use_mkl = args.use_mkl
     use_multilang = args.use_multilang
+    local_predict = args.local_predict
 
     if model == "":
         print("You must specify your serving model")
@@ -91,17 +97,20 @@ def start_standard_model():  # pylint: disable=doc-string-missing
     server = None
     if use_multilang:
-        server = serving.MultiLangServer()
+        server = serving.MultiLangServer(local_predict)
     else:
+        if local_predict == True:
+            raise ValueError("local predict can only run with multilang")
         server = serving.Server()
-    server.set_op_sequence(op_seq_maker.get_op_sequence())
-    server.set_num_threads(thread_num)
-    server.set_memory_optimize(mem_optim)
-    server.set_ir_optimize(ir_optim)
-    server.use_mkl(use_mkl)
-    server.set_max_body_size(max_body_size)
-    server.set_port(port)
+    if local_predict is False:
+        server.set_op_sequence(op_seq_maker.get_op_sequence())
+        server.set_num_threads(thread_num)
+        server.set_memory_optimize(mem_optim)
+        server.set_ir_optimize(ir_optim)
+        server.use_mkl(use_mkl)
+        server.set_max_body_size(max_body_size)
+        server.set_port(port)
     server.load_model_config(model)
     server.prepare_server(workdir=workdir, port=port, device=device)
     server.run_server()
......