Commit f88fc43f authored by wangjiawei04

implement grpc + local predictor

Parent 58a50288
......@@ -76,7 +76,7 @@ class Debugger(object):
config.switch_use_feed_fetch_ops(False)
self.predictor = create_paddle_predictor(config)
-def predict(self, feed=None, fetch=None):
+def predict(self, feed=None, fetch=None, batch=True):
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
fetch_list = []
......@@ -116,14 +116,25 @@ class Debugger(object):
input_names = self.predictor.get_input_names()
for name in input_names:
print(feed)
if isinstance(feed[name], list):
feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
name])
if self.feed_types_[name] == 0:
feed[name] = feed[name].astype("int64")
-else:
+elif self.feed_types_[name] == 1:
feed[name] = feed[name].astype("float32")
+elif self.feed_types_[name] == 2:
+feed[name] = feed[name].astype("int32")
+else:
+raise ValueError("local predictor receives wrong data type")
input_tensor = self.predictor.get_input_tensor(name)
+# TODO: set lods
+if "{}.lod".format(name) in feed:
+input_tensor.set_lod(feed["{}.lod".format(name)])
+if batch == True:
+input_tensor.copy_from_cpu(feed[name][np.newaxis,:])
+else:
+input_tensor.copy_from_cpu(feed[name])
output_tensors = []
output_names = self.predictor.get_output_names()
......
......@@ -453,14 +453,18 @@ class Server(object):
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
-def __init__(self, model_config_path, is_multi_model, endpoints):
-self.is_multi_model_ = is_multi_model
+def __init__(self, model_config_path, is_multi_model, endpoints, local_predictor=None):
self.model_config_path_ = model_config_path
-self.endpoints_ = endpoints
with open(self.model_config_path_) as f:
self.model_config_str_ = str(f.read())
self._parse_model_config(self.model_config_str_)
+self.is_multi_model_ = is_multi_model
+if local_predictor == None:
+self.local_predictor = None
+self.endpoints_ = endpoints
+self._init_bclient(self.model_config_path_, self.endpoints_)
+else:
+self.local_predictor = local_predictor
def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
from paddle_serving_client import Client
......@@ -585,8 +589,12 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
def Inference(self, request, context):
feed_dict, fetch_names, is_python = self._unpack_inference_request(
request)
+if self.local_predictor == None:
ret = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+else:
+ret = [self.local_predictor.predict(
+feed=feed_dict[0], fetch=fetch_names), "VariantTagNeeded"]
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
......@@ -596,8 +604,14 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
class MultiLangServer(object):
-def __init__(self):
+def __init__(self, use_local_predictor=False):
+self.use_local_predictor = use_local_predictor
+if use_local_predictor:
+from paddle_serving_app.local_predict import Debugger
+self.local_predictor_ = Debugger()
+else:
self.bserver_ = Server()
+self.local_predictor_ = None
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
......@@ -620,6 +634,9 @@ class MultiLangServer(object):
"max_body_size is less than default value, will use default value in service."
)
+def set_use_local_predictor(self, mode):
+self.use_local_predictor = mode
def set_port(self, port):
self.gport_ = port
......@@ -645,6 +662,7 @@ class MultiLangServer(object):
self.bserver_.use_mkl(flag)
def load_model_config(self, server_config_paths, client_config_path=None):
+if self.use_local_predictor == False:
self.bserver_.load_model_config(server_config_paths)
if client_config_path is None:
if isinstance(server_config_paths, dict):
......@@ -655,6 +673,12 @@ class MultiLangServer(object):
client_config_path = '{}/serving_server_conf.prototxt'.format(
server_config_paths)
self.bclient_config_path_ = client_config_path
+else:
+if isinstance(server_config_paths, dict):
+raise ValueError("local predictor does not support model ensemble")
+client_config_path = '{}/serving_server_conf.prototxt'.format(server_config_paths)
+self.local_predictor_.load_model_config(server_config_paths, profile=False)
+self.local_config_path_ = client_config_path
def prepare_server(self,
workdir=None,
......@@ -662,7 +686,8 @@ class MultiLangServer(object):
device="cpu",
cube_conf=None):
if not self._port_is_available(port):
raise SystemExit("Prot {} is already used".format(port))
raise SystemExit("Port {} is already used".format(port))
+if self.use_local_predictor == False:
default_port = 12000
self.port_list_ = []
for i in range(1000):
......@@ -687,20 +712,29 @@ class MultiLangServer(object):
return result != 0
def run_server(self):
+if self.use_local_predictor is False:
+print("brpc server process start")
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
options = [('grpc.max_send_message_length', self.body_size_),
('grpc.max_receive_message_length', self.body_size_)]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_),
options=options,
maximum_concurrent_rpcs=self.concurrency_)
+if self.use_local_predictor is False:
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerServiceServicer(
self.bclient_config_path_, self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])]), server)
["0.0.0.0:{}".format(self.port_list_[0])], None), server)
else:
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerServiceServicer(
self.local_config_path_, None, None, self.local_predictor_), server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
+if self.use_local_predictor is False:
p_bserver.join()
server.wait_for_termination()
......@@ -58,6 +58,11 @@ def parse_args(): # pylint: disable=doc-string-missing
default=False,
action="store_true",
help="Use Multi-language-service")
+parser.add_argument(
+"--local_predict",
+default=False,
+action="store_true",
+help="Use Local Predictor")
return parser.parse_args()
......@@ -73,6 +78,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
max_body_size = args.max_body_size
use_mkl = args.use_mkl
use_multilang = args.use_multilang
+local_predict = args.local_predict
if model == "":
print("You must specify your serving model")
......@@ -91,17 +97,20 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server = None
if use_multilang:
-server = serving.MultiLangServer()
+server = serving.MultiLangServer(local_predict)
else:
+if local_predict == True:
+raise ValueError("local predict can only run with multilang")
server = serving.Server()
+if local_predict is False:
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim)
server.set_ir_optimize(ir_optim)
server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size)
-server.set_port(port)
+server.set_port(port)
server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device)
server.run_server()
......