Commit 4ad40937 authored by HexToString

fix 2.0 multiThread, fix py35 encryption, support --model with multiple models, and support OCR C++

Parent 82707adf
......@@ -19,9 +19,6 @@ set(PADDLE_SERVING_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(PADDLE_SERVING_INSTALL_DIR ${CMAKE_BINARY_DIR}/output)
SET(CMAKE_INSTALL_RPATH "\$ORIGIN" "${CMAKE_INSTALL_RPATH}")
include(system)
SET(CMAKE_BUILD_TYPE "Debug")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
project(paddle-serving CXX C)
message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
"${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
......@@ -60,6 +57,13 @@ option(PACK "Compile for whl"
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_OPENCV "Compile Paddle Serving with OPENCV" OFF)
option(WITH_GDB "Compile Paddle Serving with GDB" OFF)
if (WITH_GDB)
SET(CMAKE_BUILD_TYPE "Debug")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
endif()
if (WITH_OPENCV)
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
......
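The new WITH_GDB switch keeps the Debug build type and the -O0/-ggdb flags out of default builds; a debuggable build can be configured with, e.g., cmake -DWITH_GDB=ON (assuming the usual out-of-source CMake setup).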
......@@ -29,12 +29,7 @@ test_reader = paddle.batch(
for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32")
print('testclient.py-----data',data[0][0])
print('testclient.py-----shape',data[0][0].shape)
new_data[0] = data[0][0]
print('testclient.py-----newdata',new_data)
print('testclient.py-----newdata-0',new_data[0])
print('testclient.py-----newdata.shape',new_data.shape)
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......
......@@ -34,8 +34,10 @@ from .proto import multi_lang_general_model_service_pb2_grpc
int64_type = 0
float32_type = 1
int32_type = 2
bytes_type = 3
int_type = set([int64_type, int32_type])
float_type = set([float32_type])
string_type = set([bytes_type])
class _NOPProfiler(object):
......@@ -139,10 +141,22 @@ class Client(object):
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path):
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
from .serving_client import PredictorClient
model_conf = m_config.GeneralModelConfig()
f = open(path, 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
......@@ -151,19 +165,16 @@ class Client(object):
# get feed shapes, feed types
# map feed names to index
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
self.client_handle_.init(file_path_list)
if "FLAGS_max_body_size" not in os.environ:
os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
read_env_flags = ["profile_client", "profile_server", "max_body_size"]
self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
self.feed_names_to_idx_ = {}  # this is not useful
self.lod_tensor_set = set()
self.feed_tensor_len = {}
self.feed_tensor_len = {}  # this is only used for shape check
self.key = None
for i, var in enumerate(model_conf.feed_var):
......@@ -178,6 +189,14 @@ class Client(object):
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
self.feed_tensor_len[var.alias_name] = counter
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
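With this change, load_client_config accepts either a single config path or a list of them; when several are given, feed vars come from the first config and fetch vars from the last. A minimal usage sketch (the directory names are hypothetical; each is expected to contain the serving_server_conf.prototxt written at model-save time):

from paddle_serving_client import Client

client = Client()
# A single model: a plain string still works.
client.load_client_config("uci_housing_client")

# A two-stage pipeline (e.g. OCR detection + recognition): pass a list.
# Feed vars are read from ocr_det_client, fetch vars from ocr_rec_client.
client.load_client_config(["ocr_det_client", "ocr_rec_client"])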
......@@ -288,13 +307,17 @@ class Client(object):
raise ValueError("Feed only accepts dict and list of dict")
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
float_feed_names = []
int_shape = []
int_lod_slot_batch = []
float_slot_batch = []
float_feed_names = []
float_lod_slot_batch = []
float_shape = []
string_slot_batch = []
string_feed_names = []
string_lod_slot_batch = []
string_shape = []
fetch_names = []
counter = 0
......@@ -311,9 +334,11 @@ class Client(object):
for i, feed_i in enumerate(feed_batch):
int_slot = []
float_slot = []
int_lod_slot = []
float_slot = []
float_lod_slot = []
string_slot = []
string_lod_slot = []
for key in feed_i:
if ".lod" not in key and key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key))
......@@ -368,10 +393,24 @@ class Client(object):
else:
float_slot.append(feed_i[key])
self.all_numpy_input = False
# If the input is a string, the feed is not a numpy array.
elif self.feed_types_[key] in string_type:
if i == 0:
string_feed_names.append(key)
string_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i:
string_lod_slot_batch.append(feed_i["{}.lod".format(
key)])
else:
string_lod_slot_batch.append([])
string_slot.append(feed_i[key])
self.has_numpy_input = True
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
int_lod_slot_batch.append(int_lod_slot)
float_slot_batch.append(float_slot)
float_lod_slot_batch.append(float_lod_slot)
string_slot_batch.append(string_slot)
string_lod_slot_batch.append(string_lod_slot)
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
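A hedged sketch of feeding the new string (bytes_type) input through predict, following the "{key}.lod" convention above; the feed name "words" and fetch name "prediction" are hypothetical and must match the client config:

# "words" is assumed to be declared with feed_type 3 (bytes_type).
fetch_map = client.predict(
    feed={"words": "hello world", "words.lod": [0, 1]},
    fetch=["prediction"],
    batch=True)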
......@@ -381,7 +420,8 @@ class Client(object):
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape,
float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
int_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
int_lod_slot_batch, string_slot_batch, string_feed_names, string_shape,
string_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
log_id)
elif self.has_numpy_input == False:
raise ValueError(
......@@ -509,8 +549,8 @@ class MultiLangClient(object):
get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
)
resp = self.stub_.GetClientConfig(get_client_config_req)
model_config_str = resp.client_config_str
self._parse_model_config(model_config_str)
model_config_path_list = resp.client_config_str_list
self._parse_model_config(model_config_path_list)
def _flatten_list(self, nested_list):
for item in nested_list:
......@@ -520,25 +560,39 @@ class MultiLangClient(object):
else:
yield item
def _parse_model_config(self, model_config_str):
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
......
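On the gRPC path, the client config now arrives from the server as a list via GetClientConfig, so a MultiLangClient needs no local config argument; a minimal sketch (the endpoint is hypothetical):

from paddle_serving_client import MultiLangClient

client = MultiLangClient()
# connect() calls GetClientConfig and passes the returned
# client_config_str_list to _parse_model_config() shown above.
client.connect(["127.0.0.1:9393"])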
......@@ -14,6 +14,8 @@ class OpMaker(object):
"general_single_kv": "GeneralSingleKVOp",
"general_dist_kv_infer": "GeneralDistKVInferOp",
"general_dist_kv": "GeneralDistKVOp"
"general_copy": "GeneralCopyOp",
"general_detection":"GeneralDetectionOp",
}
self.node_name_suffix_ = collections.defaultdict(int)
......@@ -45,7 +47,6 @@ class OpMaker(object):
# overall efficiency.
return google.protobuf.text_format.MessageToString(node)
class OpSeqMaker(object):
def __init__(self):
self.workflow = server_sdk.Workflow()
......@@ -78,7 +79,8 @@ class OpSeqMaker(object):
workflow_conf.workflows.extend([self.workflow])
return workflow_conf
# TODO: Currently, the SDK only supports "Sequence"; OpGraphMaker is not useful yet.
# The config should be changed to adapt the command line for list[dict] or list[list[]].
class OpGraphMaker(object):
def __init__(self):
self.workflow = server_sdk.Workflow()
......
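A sketch of wiring the new general_detection op into a two-model OCR sequence with OpSeqMaker, mirroring the loop added to serve.py below (assumed layout: detection model first, recognition model second):

import paddle_serving_server as serving

op_maker = serving.OpMaker()
op_seq_maker = serving.OpSeqMaker()
op_seq_maker.add_op(op_maker.create('general_reader'))
op_seq_maker.add_op(op_maker.create('general_detection'))  # OCR det stage
op_seq_maker.add_op(op_maker.create('general_infer'))      # OCR rec stage
op_seq_maker.add_op(op_maker.create('general_response'))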
......@@ -11,38 +11,55 @@ from .proto import multi_lang_general_model_service_pb2_grpc
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path, is_multi_model, endpoints):
def __init__(self, model_config_path_list, is_multi_model, endpoints):
self.is_multi_model_ = is_multi_model
self.model_config_path_ = model_config_path
self.model_config_path_list = model_config_path_list
self.endpoints_ = endpoints
with open(self.model_config_path_) as f:
self.model_config_str_ = str(f.read())
self._parse_model_config(self.model_config_str_)
self._init_bclient(self.model_config_path_, self.endpoints_)
self._init_bclient(self.model_config_path_list, self.endpoints_)
self._parse_model_config(self.model_config_path_list)
def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
def _init_bclient(self, model_config_path_list, endpoints, timeout_ms=None):
from paddle_serving_client import Client
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(model_config_path)
self.bclient_.load_client_config(model_config_path_list)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_str):
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
......@@ -69,11 +86,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:
if v_type == 0:  # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
elif v_type == 1:  # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2:
elif v_type == 2:  # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
......@@ -82,7 +99,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:
elif v_type == 2:  # int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
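The branches above repeat per type code; an equivalent table-driven helper (illustrative only, not part of this commit) would be:

import numpy as np

def decode_tensor(buf, v_type):
    # Type codes match client.py: 0 = int64, 1 = float32, 2 = int32.
    dtype = {0: "int64", 1: "float32", 2: "int32"}.get(v_type)
    if dtype is None:
        raise Exception("error type.")
    return np.frombuffer(buf, dtype=dtype)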
......@@ -138,7 +155,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
# This process and the inference process cannot run at the same time.
# For performance reasons, do not add a thread lock for now.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
self._init_bclient(self.model_config_path_list, self.endpoints_, timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
......@@ -155,6 +172,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
# model_config_path_list is a list right now.
# dict support should be added when OpGraphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str = self.model_config_str_
resp.client_config_str_list[:] = self.model_config_path_list
return resp
\ No newline at end of file
......@@ -23,7 +23,6 @@ import json
import base64
import time
from multiprocessing import Process
from .web_service import WebService, port_is_available
from flask import Flask, request
import sys
if sys.version_info.major == 2:
......@@ -182,15 +181,26 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server as serving
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = serving.OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
for idx, single_model in enumerate(model):
if len(model) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
if use_multilang:
......@@ -348,7 +358,7 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
args = parse_args()
args = serve_args()
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
......
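With --model now accepting several directories, a two-stage OCR server can be launched with, for example, python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9292 (the model directory names are hypothetical); the first model is wired to general_detection and the second to general_infer, as in the loop above.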
......@@ -34,6 +34,7 @@ import platform
import numpy as np
import grpc
import sys
import collections
from multiprocessing import Pool, Process
from concurrent import futures
......@@ -154,13 +155,12 @@ class Server(object):
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = []
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
for engine_name, model_config_path in model_config_paths.items():
engine = server_sdk.EngineDesc()
engine.name = engine_name
# engine.reloadable_meta = model_config_path + "/fluid_time_file"
engine.reloadable_meta = self.workdir + "/fluid_time_file"
engine.reloadable_meta = model_config_path + "/fluid_time_file"
os.system("touch {}".format(engine.reloadable_meta))
engine.reloadable_type = "timestamp_ne"
engine.runtime_thread_num = 0
......@@ -292,7 +292,6 @@ class Server(object):
def get_device_version(self):
avx_flag = False
mkl_flag = self.mkl_flag
openblas_flag = False
r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
if r == 0:
avx_flag = True
......@@ -387,10 +386,8 @@ class Server(object):
os.system("mkdir -p {}".format(workdir))
else:
os.system("mkdir -p {}".format(workdir))
os.system("touch {}/fluid_time_file".format(workdir))
for subdir in self.subdirectory:
os.system("mkdir {}/{}".format(workdir, subdir))
os.system("mkdir -p {}/{}".format(workdir, subdir))
os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))
if not self.port_is_available(port):
......@@ -507,7 +504,7 @@ class MultiLangServer(object):
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
self.is_multi_model_ = False # for model ensemble
self.is_multi_model_ = False  # for model ensemble, which is not used right now
def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency
......
......@@ -116,7 +116,7 @@ class WebService(object):
device = "arm"
else:
device = "cpu"
op_maker = serving.OpMaker()
op_maker = OpMaker()
op_seq_maker = OpSeqMaker()
read_op = op_maker.create('general_reader')
......