Commit 33338d8a authored by H HexToString

fix multi model python-part

Parent 53413a48
......@@ -59,7 +59,7 @@ message SimpleResponse { required int32 err_code = 1; }
message GetClientConfigRequest {}
message GetClientConfigResponse { required string client_config_str = 1; }
message GetClientConfigResponse { repeated string client_config_str_list = 1; }
service MultiLangGeneralModelService {
rpc Inference(InferenceRequest) returns (InferenceResponse) {}
......
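The proto change above swaps the single `client_config_str` for a repeated `client_config_str_list`, so `GetClientConfig` can return one entry per model in a multi-model pipeline. A minimal sketch of reading the new field on the gRPC client side, assuming the generated message and stub classes for this proto (the import path is an assumption):

```python
# Sketch only: read the repeated client_config_str_list returned by GetClientConfig.
from paddle_serving_client.proto import multi_lang_general_model_service_pb2 as pb2  # import path assumed

def fetch_client_configs(stub):
    req = pb2.GetClientConfigRequest()
    resp = stub.GetClientConfig(req)
    # Each entry is the client config (path) of one model in the pipeline.
    return list(resp.client_config_str_list)
```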
......@@ -141,15 +141,24 @@ class Client(object):
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path):
if isinstance(path, str):
path_list = [path]
elif isinstance(path, list):
path_list = path
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
from .serving_client import PredictorClient
model_conf = m_config.GeneralModelConfig()
f = open(path_list[0], 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
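A hedged usage sketch of the updated `load_client_config`: it now accepts a single path or a list, and a directory entry is resolved to `<dir>/serving_server_conf.prototxt`. The model directory names and endpoint below are hypothetical:

```python
from paddle_serving_client import Client

client = Client()
# Either a single str or a list of model config dirs/.prototxt files works now;
# directories are resolved to <dir>/serving_server_conf.prototxt internally.
client.load_client_config(["ocr_det_model", "ocr_rec_model"])  # hypothetical dirs
client.connect(["127.0.0.1:9292"])                             # hypothetical endpoint
```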
......@@ -158,7 +167,7 @@ class Client(object):
# get feed shapes, feed types
# map feed names to index
self.client_handle_ = PredictorClient()
self.client_handle_.init(path_list)
self.client_handle_.init(file_path_list)
if "FLAGS_max_body_size" not in os.environ:
os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
read_env_flags = ["profile_client", "profile_server", "max_body_size"]
......@@ -166,9 +175,9 @@ class Client(object):
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_names_to_idx_ = {}
self.feed_names_to_idx_ = {}  # currently unused
self.lod_tensor_set = set()
self.feed_tensor_len = {}
self.feed_tensor_len = {}  # only used for shape checking
self.key = None
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
......@@ -183,9 +192,9 @@ class Client(object):
counter *= dim
self.feed_tensor_len[var.alias_name] = counter
if len(path_list) > 1:
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(path_list[-1], 'r')
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
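The convention used here for multiple configs is that feed variables come from the first config and fetch variables from the last one. A minimal sketch of that rule, with `parse_config` standing in (hypothetically) for the protobuf `Merge` calls above:

```python
def resolve_feed_fetch(file_path_list, parse_config):
    # Feed vars are read from the first model's config, fetch vars from the last one's.
    first_conf = parse_config(file_path_list[0])
    last_conf = parse_config(file_path_list[-1])
    feed_names = [var.alias_name for var in first_conf.feed_var]
    fetch_names = [var.alias_name for var in last_conf.fetch_var]
    return feed_names, fetch_names
```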
......@@ -545,8 +554,8 @@ class MultiLangClient(object):
get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
)
resp = self.stub_.GetClientConfig(get_client_config_req)
model_config_str = resp.client_config_str
self._parse_model_config(model_config_str)
model_config_path_list = resp.client_config_str_list
self._parse_model_config(model_config_path_list)
def _flatten_list(self, nested_list):
for item in nested_list:
......@@ -556,25 +565,42 @@ class MultiLangClient(object):
else:
yield item
def _parse_model_config(self, model_config_str):
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
......
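The directory-to-`serving_server_conf.prototxt` resolution above is repeated in `Client.load_client_config`, `MultiLangClient._parse_model_config`, and the server-side servicer below; a shared helper would look roughly like this (a sketch, not part of the commit):

```python
import os

def resolve_config_paths(model_config_path_list):
    if isinstance(model_config_path_list, str):
        model_config_path_list = [model_config_path_list]
    file_path_list = []
    for single_model_config in model_config_path_list:
        if os.path.isdir(single_model_config):
            file_path_list.append(
                "{}/serving_server_conf.prototxt".format(single_model_config))
        elif os.path.isfile(single_model_config):
            file_path_list.append(single_model_config)
    return file_path_list
```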
......@@ -114,7 +114,8 @@ class OpSeqMaker(object):
workflow_conf.workflows.extend([self.workflow])
return workflow_conf
# TODO: Currently the SDK only supports "Sequence", so OpGraphMaker is unused.
# The config should be extended so the command line can accept list[dict] or list[list[]].
class OpGraphMaker(object):
def __init__(self):
self.workflow = server_sdk.Workflow()
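Since only sequential workflows are supported, a two-model pipeline is still expressed with `OpSeqMaker`. A minimal sketch, assuming the op names used elsewhere in this commit (`general_reader`, `general_detection`, `general_infer`, `general_response`):

```python
import paddle_serving_server as serving

op_maker = serving.OpMaker()
op_seq_maker = serving.OpSeqMaker()
op_seq_maker.add_op(op_maker.create('general_reader'))
op_seq_maker.add_op(op_maker.create('general_detection'))  # first model (e.g. OCR detection)
op_seq_maker.add_op(op_maker.create('general_infer'))      # second model (e.g. OCR recognition)
op_seq_maker.add_op(op_maker.create('general_response'))
workflow_conf = op_seq_maker.get_op_sequence()
```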
......@@ -182,6 +183,9 @@ class Server(object):
"max_body_size is less than default value, will use default value in service."
)
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def set_port(self, port):
self.port = port
......@@ -200,9 +204,6 @@ class Server(object):
def set_ir_optimize(self, flag=False):
self.ir_optimization = flag
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def set_product_name(self, product_name=None):
if product_name == None:
raise ValueError("product_name can't be None.")
......@@ -301,6 +302,15 @@ class Server(object):
# the resource.prototxt file to determine the input and output
# format of the workflow. To ensure that the input and output
# of multiple models are the same.
if isinstance(model_config_paths_args, str):
model_config_paths_args = [model_config_paths_args]
for single_model_config in model_config_paths_args:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
if isinstance(model_config_paths_args, list):
# If there is only one model path, use the default infer_op.
# Because there are several infer_op type, we need to find
......@@ -309,6 +319,8 @@ class Server(object):
'GeneralInferOp', 'GeneralDistKVInferOp',
'GeneralDistKVQuantInferOp','GeneralDetectionOp',
]
# Currently only a single workflow is supported.
# TODO: support multi-workflow.
model_config_paths_list_idx = 0
for node in self.workflow_conf.workflows[0].nodes:
if node.type in default_engine_types:
......@@ -327,6 +339,7 @@ class Server(object):
model_config_paths_list_idx += 1
if model_config_paths_list_idx == len(model_config_paths_args):
break
# Right now, this dict branch is not used.
elif isinstance(model_config_paths_args, dict):
self.model_config_paths = collections.OrderedDict()
for node_str, path in model_config_paths_args.items():
......@@ -339,7 +352,7 @@ class Server(object):
self.model_conf[node.name] = google.protobuf.text_format.Merge(
str(f.read()), m_config.GeneralModelConfig())
else:
raise Exception("The type of model_config_paths must be str or "
raise Exception("The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(model_config_paths_args)))
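A hedged usage sketch of the updated `Server.load_model_config`: it accepts a single directory, a list of directories, or the older `{op: path}` dict, and rejects plain files. Directory names and port are hypothetical; `op_seq_maker` is the sequence built in the sketch above:

```python
from paddle_serving_server import Server  # import path assumed

server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.load_model_config(["ocr_det_model", "ocr_rec_model"])  # one entry per infer op in the workflow
server.prepare_server(workdir="workdir", port=9393, device="cpu")
server.run_server()
```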
......@@ -490,3 +503,324 @@ class Server(object):
print("Going to Run Command")
print(command)
os.system(command)
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path_list, is_multi_model, endpoints):
self.is_multi_model_ = is_multi_model
self.model_config_path_list = model_config_path_list
self.endpoints_ = endpoints
self._init_bclient(self.model_config_path_list, self.endpoints_)
self._parse_model_config(self.model_config_path_list)
def _init_bclient(self, model_config_path_list, endpoints, timeout_ms=None):
from paddle_serving_client import Client
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(model_config_path_list)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
def _flatten_list(self, nested_list):
for item in nested_list:
if isinstance(item, (list, tuple)):
for sub_item in self._flatten_list(item):
yield sub_item
else:
yield item
def _unpack_inference_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
is_python = request.is_python
log_id = request.log_id
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
var = feed_inst.tensor_array[idx]
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0: # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1: # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2: # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
else:
if v_type == 0: # int64
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2: # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
data.shape = list(feed_inst.tensor_array[idx].shape)
feed_dict[name] = data
if len(var.lod) > 0:
feed_dict["{}.lod".format(name)] = var.lod
feed_batch.append(feed_dict)
return feed_batch, fetch_names, is_python, log_id
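`_unpack_inference_request` reverses the packing done on the client. A sketch of the corresponding `is_python=True` packing, where numpy arrays travel as raw bytes plus a shape (the `FeedInst` message name is assumed to mirror the `FetchInst` used below, and the import path is an assumption):

```python
import numpy as np
from paddle_serving_client.proto import multi_lang_general_model_service_pb2 as pb2  # import path assumed

def pack_feed_inst(feed_dict, feed_names):
    inst = pb2.FeedInst()  # assumed counterpart of FetchInst
    for name in feed_names:
        tensor = pb2.Tensor()
        data = np.ascontiguousarray(feed_dict[name])
        tensor.data = data.tobytes()           # decoded with np.frombuffer on the server
        tensor.shape.extend(list(data.shape))  # restored via data.shape = list(...)
        lod_key = "{}.lod".format(name)
        if lod_key in feed_dict:
            tensor.lod.extend(feed_dict[lod_key])
        inst.tensor_array.append(tensor)
    return inst
```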
def _pack_inference_response(self, ret, fetch_names, is_python):
resp = multi_lang_general_model_service_pb2.InferenceResponse()
if ret is None:
resp.err_code = 1
return resp
results, tag = ret
resp.tag = tag
resp.err_code = 0
if not self.is_multi_model_:
results = {'general_infer_0': results}
for model_name, model_result in results.items():
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if is_python:
tensor.data = model_result[name].tobytes()
else:
if v_type == 0: # int64
tensor.int64_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 1: # float32
tensor.float_data.extend(model_result[name].reshape(-1)
.tolist())
elif v_type == 2: # int32
tensor.int_data.extend(model_result[name].reshape(-1)
.tolist())
else:
raise Exception("error type.")
tensor.shape.extend(list(model_result[name].shape))
if "{}.lod".format(name) in model_result:
tensor.lod.extend(model_result["{}.lod".format(name)]
.tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
model_output.engine_name = model_name
resp.outputs.append(model_output)
return resp
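On the receiving side, the response built by `_pack_inference_response` can be decoded by reinterpreting the raw bytes with the dtype implied by `fetch_type` and reshaping with the shipped shape. A sketch, assuming the same type codes used above (0=int64, 1=float32, 2=int32):

```python
import numpy as np

_DTYPE_OF = {0: "int64", 1: "float32", 2: "int32"}  # mirrors the fetch_type codes above

def unpack_model_output(model_output, fetch_names, fetch_types):
    result = {}
    inst = model_output.insts[0]  # the server above appends a single FetchInst per model
    for idx, name in enumerate(fetch_names):
        tensor = inst.tensor_array[idx]
        data = np.frombuffer(tensor.data, dtype=_DTYPE_OF[fetch_types[name]])
        result[name] = data.reshape(list(tensor.shape))
        if len(tensor.lod) > 0:
            result["{}.lod".format(name)] = np.array(list(tensor.lod))
    return result
```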
def SetTimeout(self, request, context):
# This process and the Inference process cannot run at the same time.
# For performance reasons, no thread lock is added for now.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_list, self.endpoints_, timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
def Inference(self, request, context):
feed_batch, fetch_names, is_python, log_id = \
self._unpack_inference_request(request)
ret = self.bclient_.predict(
feed=feed_batch,
fetch=fetch_names,
batch=True,
need_variant_tag=True,
log_id=log_id)
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
# model_config_path_list is a list right now.
# dict support should be added once OpGraphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str_list[:] = self.model_config_path_list
return resp
class MultiLangServer(object):
def __init__(self):
self.bserver_ = Server()
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
self.is_multi_model_ = False  # for model ensembles, which are not supported right now
def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency
self.bserver_.set_max_concurrency(concurrency)
def set_num_threads(self, threads):
self.worker_num_ = threads
self.bserver_.set_num_threads(threads)
def set_max_body_size(self, body_size):
self.bserver_.set_max_body_size(body_size)
if body_size >= self.body_size_:
self.body_size_ = body_size
else:
print(
"max_body_size is less than default value, will use default value in service."
)
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def set_port(self, port):
self.gport_ = port
def set_reload_interval(self, interval):
self.bserver_.set_reload_interval(interval)
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def set_op_graph(self, op_graph):
self.bserver_.set_op_graph(op_graph)
def set_memory_optimize(self, flag=False):
self.bserver_.set_memory_optimize(flag)
def set_ir_optimize(self, flag=False):
self.bserver_.set_ir_optimize(flag)
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def use_mkl(self, flag):
self.bserver_.use_mkl(flag)
def load_model_config(self, server_config_dir_paths, client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
else:
raise Exception("The type of model_config_paths must be str or list"
", not {}.".format(
type(server_config_dir_paths)))
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
self.bserver_.load_model_config(server_config_dir_paths)
if client_config_path is None:
# dict is not supported right now.
if isinstance(server_config_dir_paths, dict):
self.is_multi_model_ = True
client_config_path = []
for server_config_path_items in list(server_config_dir_paths.items()):
client_config_path.append( server_config_path_items[1] )
elif isinstance(server_config_dir_paths, list):
self.is_multi_model_ = False
client_config_path = server_config_dir_paths
else:
raise Exception("The type of model_config_paths must be str or list or "
"dict({op: model_path}), not {}.".format(
type(server_config_dir_paths)))
if isinstance(client_config_path, str):
client_config_path = [client_config_path]
elif isinstance(client_config_path, list):
pass
else:  # dict is not supported right now.
raise Exception("The type of client_config_path must be str or list or "
"dict({op: model_path}), not {}.".format(
type(client_config_path)))
if len(client_config_path) != len(server_config_dir_paths):
raise Warning("The len(client_config_path) is {}, != len(server_config_dir_paths) {}."
.format( len(client_config_path), len(server_config_dir_paths) )
)
self.bclient_config_path_list = client_config_path
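A hedged end-to-end sketch of the multi-model `MultiLangServer` path introduced here (directory names and port are hypothetical; the import path is an assumption):

```python
from paddle_serving_server import MultiLangServer  # import path assumed

multi_server = MultiLangServer()
multi_server.set_op_sequence(op_seq_maker.get_op_sequence())  # sequential workflow as sketched earlier
multi_server.load_model_config(["ocr_det_model", "ocr_rec_model"])
multi_server.prepare_server(workdir="workdir", port=9393, device="cpu")
multi_server.run_server()
```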
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
cube_conf=None):
if not self._port_is_available(port):
raise SystemExit("Prot {} is already used".format(port))
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self._port_is_available(default_port
+ i):
self.port_list_.append(default_port + i)
break
self.bserver_.prepare_server(
workdir=workdir,
port=self.port_list_[0],
device=device,
cube_conf=cube_conf)
self.set_port(port)
def _launch_brpc_service(self, bserver):
bserver.run_server()
def _port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
return result != 0
def run_server(self):
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
options = [('grpc.max_send_message_length', self.body_size_),
('grpc.max_receive_message_length', self.body_size_)]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_),
options=options,
maximum_concurrent_rpcs=self.concurrency_)
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerServiceServicer(
self.bclient_config_path_list, self.is_multi_model_,
["0.0.0.0:{}".format(self.port_list_[0])]), server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
p_bserver.join()
server.wait_for_termination()
\ No newline at end of file
......@@ -22,6 +22,7 @@ import sys
import json
import base64
import time
import os
from multiprocessing import Process
from .web_service import WebService, port_is_available
from flask import Flask, request
......@@ -103,6 +104,12 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
if model == "":
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server as serving
op_maker = serving.OpMaker()
......@@ -162,8 +169,11 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "wb") as f:
f.write(key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "wb") as f:
f.write(key)
return True
def check_key(self, post_data):
......@@ -171,9 +181,14 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "rb") as f:
cur_key = f.read()
return (key == cur_key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "rb") as f:
cur_key = f.read()
if key != cur_key:
return False
return True
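The encryption path now assumes one `key` file per model directory, and every stored key must match the key posted by the client. A minimal sketch of that check (directory names are hypothetical):

```python
import base64
import os

def keys_match(model_dirs, posted_key_b64):
    key = base64.b64decode(posted_key_b64.encode())
    for model_dir in model_dirs:
        with open(os.path.join(model_dir, "key"), "rb") as f:
            if f.read() != key:
                return False
    return True
```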
def start(self, post_data):
post_data = json.loads(post_data.decode('utf-8'))
......@@ -218,6 +233,12 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
args = parse_args()
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
if args.name == "None":
if args.use_encryption_model:
p_flag = False
......@@ -232,6 +253,7 @@ if __name__ == "__main__":
start_standard_model(args.port)
else:
service = WebService(name=args.name)
service.load_model_config(args.model)
service.prepare_server(
workdir=args.workdir, port=args.port, device=args.device)
......
......@@ -21,6 +21,7 @@ from paddle_serving_client import Client
from contextlib import closing
import socket
import numpy as np
import os
from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op
......@@ -59,39 +60,65 @@ class WebService(object):
def run_service(self):
self._server.run_server()
def load_model_config(self, model_config):
print("This API will be deprecated later. Please do not use it")
self.model_config = model_config
import os
def load_model_config(self, server_config_dir_paths, client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
self.server_config_dir_paths = server_config_dir_paths
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
if os.path.isdir(model_config):
client_config = "{}/serving_server_conf.prototxt".format(
model_config)
elif os.path.isfile(model_config):
client_config = model_config
file_path_list = []
for single_model_config in self.server_config_dir_paths:
file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
model_conf = m_config.GeneralModelConfig()
f = open(client_config, 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_vars = {var.name: var for var in model_conf.feed_var}
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
if client_config_path == None:
self.client_config_path = self.server_config_dir_paths
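A hedged usage sketch of the updated `WebService.load_model_config`: it now takes one or more server config directories, reading feed vars from the first config and fetch vars from the last. Names below are hypothetical:

```python
from paddle_serving_server.web_service import WebService  # import path assumed

service = WebService(name="ocr")
service.load_model_config(["ocr_det_model", "ocr_rec_model"])
service.prepare_server(workdir="workdir", port=9292, device="cpu")
# The RPC and web workers are then started as usual.
```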
def _launch_rpc_service(self):
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(self.server_config_dir_paths):
infer_op_name = "general_infer"
if len(self.server_config_dir_paths) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(16)
server.set_memory_optimize(self.mem_optim)
server.set_ir_optimize(self.ir_optim)
server.load_model_config(self.model_config)
server.load_model_config(self.server_config_dir_paths)  # the brpc Server accepts server_config_dir_paths
server.prepare_server(
workdir=self.workdir, port=self.port_list[0], device=self.device)
server.run_server()
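The op-name selection in `_launch_rpc_service` reduces to a simple rule: with exactly two models the first is served by `general_detection`, everything else by `general_infer`. A minimal sketch of that rule:

```python
def select_infer_op(idx, num_models):
    # With exactly two models (e.g. OCR det + rec), the first one uses the
    # detection op; all other models use the generic infer op.
    if num_models == 2 and idx == 0:
        return "general_detection"
    return "general_infer"
```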
......@@ -126,8 +153,7 @@ class WebService(object):
def _launch_web_service(self):
self.client = Client()
self.client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
self.client.load_client_config(self.client_config_path)
self.client.connect(["0.0.0.0:{}".format(self.port_list[0])])
def get_prediction(self, request):
......@@ -198,8 +224,11 @@ class WebService(object):
def _launch_local_predictor(self):
from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor()
self.client.load_model_config(
"{}".format(self.model_config), use_gpu=False)
# LocalPredictor acts like a server, but from the WebService's point of view
# it is the request initiator, i.e. a client.
# LocalPredictor only supports a single model directory path (str),
# so the input must be self.server_config_dir_paths[0].
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
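A short sketch of the local-predictor path: `LocalPredictor` accepts a single model directory (str), so only the first entry of `server_config_dir_paths` is used. The directory name below is hypothetical and the `predict` call signature is assumed:

```python
from paddle_serving_app.local_predict import LocalPredictor

predictor = LocalPredictor()
predictor.load_model_config("ocr_det_model", use_gpu=False)  # hypothetical directory
# Prediction then follows the usual client-style API (signature assumed):
# result = predictor.predict(feed={"image": img_array}, fetch=["fetch_name"], batch=True)
```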
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
......
This diff is collapsed.
......@@ -53,15 +53,29 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server_gpu as serving
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = serving.OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
if len(model) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
if use_multilang:
......@@ -156,8 +170,11 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "wb") as f:
f.write(key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "wb") as f:
f.write(key)
return True
def check_key(self, post_data):
......@@ -165,9 +182,14 @@ class MainService(BaseHTTPRequestHandler):
return False
else:
key = base64.b64decode(post_data["key"].encode())
with open(args.model + "/key", "rb") as f:
cur_key = f.read()
return (key == cur_key)
for single_model_config in args.model:
if os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
with open(single_model_config + "/key", "rb") as f:
cur_key = f.read()
if key != cur_key:
return False
return True
def start(self, post_data):
post_data = json.loads(post_data.decode('utf-8'))
......@@ -211,6 +233,12 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
args = serve_args()
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
if args.name == "None":
from .web_service import port_is_available
if args.use_encryption_model:
......
......@@ -19,9 +19,9 @@ from contextlib import closing
from multiprocessing import Pool, Process, Queue
from paddle_serving_client import Client
from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
from paddle_serving_server_gpu.serve import start_multi_card
import socket
import sys
import os
import numpy as np
import paddle_serving_server_gpu as serving
......@@ -66,23 +66,39 @@ class WebService(object):
def run_service(self):
self._server.run_server()
def load_model_config(self, model_config):
print("This API will be deprecated later. Please do not use it")
self.model_config = model_config
import os
def load_model_config(self, server_config_dir_paths, client_config_path=None):
if isinstance(server_config_dir_paths, str):
server_config_dir_paths = [server_config_dir_paths]
elif isinstance(server_config_dir_paths, list):
pass
for single_model_config in server_config_dir_paths:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
self.server_config_dir_paths = server_config_dir_paths
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
if os.path.isdir(model_config):
client_config = "{}/serving_server_conf.prototxt".format(
model_config)
elif os.path.isfile(model_config):
client_config = model_config
file_path_list = []
for single_model_config in self.server_config_dir_paths:
file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
model_conf = m_config.GeneralModelConfig()
f = open(client_config, 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_vars = {var.name: var for var in model_conf.feed_var}
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_vars = {var.name: var for var in model_conf.fetch_var}
if client_config_path == None:
self.client_config_path = self.server_config_dir_paths
def set_gpus(self, gpus):
print("This API will be deprecated later. Please do not use it")
......@@ -103,14 +119,22 @@ class WebService(object):
device = "arm"
else:
device = "cpu"
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_maker = OpMaker()
op_seq_maker = OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(self.server_config_dir_paths):
infer_op_name = "general_infer"
if len(self.server_config_dir_paths) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
server = Server()
......@@ -125,7 +149,7 @@ class WebService(object):
if use_xpu:
server.set_xpu()
server.load_model_config(self.model_config)
server.load_model_config(self.server_config_dir_paths)
if gpuid >= 0:
server.set_gpuid(gpuid)
server.prepare_server(workdir=workdir, port=port, device=device)
......@@ -193,8 +217,7 @@ class WebService(object):
def _launch_web_service(self):
gpu_num = len(self.gpus)
self.client = Client()
self.client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
self.client.load_client_config(self.client_config_path)
endpoints = ""
if gpu_num > 0:
for i in range(gpu_num):
......@@ -277,8 +300,11 @@ class WebService(object):
def _launch_local_predictor(self, gpu):
from paddle_serving_app.local_predict import LocalPredictor
self.client = LocalPredictor()
self.client.load_model_config(
"{}".format(self.model_config), use_gpu=True, gpu_id=self.gpus[0])
# LocalPredictor acts like a server, but from the WebService's point of view
# it is the request initiator, i.e. a client.
# LocalPredictor only supports a single model directory path (str),
# so the input must be self.server_config_dir_paths[0].
self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
def run_web_service(self):
print("This API will be deprecated later. Please do not use it")
......