Commit 4ad40937 authored by HexToString

fix 2.0 multiThread and fix py35 encryption and support --model multiModels and support ocr C++

Parent 82707adf
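This commit lets --model accept several model directories in one launch, which is the basis of the C++ OCR pipeline wired up below. A hedged example of the new invocation, assuming the merged paddle_serving_server.serve entry point; the model directory names are illustrative:

python -m paddle_serving_server.serve --model ocr_det_model ocr_rec_model --port 9293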
......@@ -19,9 +19,6 @@ set(PADDLE_SERVING_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(PADDLE_SERVING_INSTALL_DIR ${CMAKE_BINARY_DIR}/output)
SET(CMAKE_INSTALL_RPATH "\$ORIGIN" "${CMAKE_INSTALL_RPATH}")
include(system)
SET(CMAKE_BUILD_TYPE "Debug")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
project(paddle-serving CXX C)
message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
"${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}")
......@@ -52,14 +49,21 @@ option(WITH_GPU "Compile Paddle Serving with NVIDIA GPU"
option(WITH_LITE "Compile Paddle Serving with Paddle Lite Engine" OFF)
option(WITH_XPU "Compile Paddle Serving with Baidu Kunlun" OFF)
option(WITH_PYTHON "Compile Paddle Serving with Python" ON)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(CLIENT "Compile Paddle Serving Client" OFF)
option(SERVER "Compile Paddle Serving Server" OFF)
option(APP "Compile Paddle Serving App package" OFF)
option(WITH_ELASTIC_CTR "Compile ELASTIC-CTR solution" OFF)
option(PACK "Compile for whl" OFF)
option(WITH_TRT "Compile Paddle Serving with TRT" OFF)
option(PADDLE_ON_INFERENCE "Compile for encryption" ON)
option(WITH_OPENCV "Compile Paddle Serving with OPENCV" OFF)
option(WITH_OPENCV "Compile Paddle Serving with OPENCV" OFF)
option(WITH_GDB "Compile Paddle Serving with GDB" OFF)
if (WITH_GDB)
SET(CMAKE_BUILD_TYPE "Debug")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
endif()
if (WITH_OPENCV)
SET(OPENCV_DIR "" CACHE PATH "Location of libraries")
......
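Note: with this hunk the debug flags (-O0 -g2 -ggdb) are no longer set unconditionally; they apply only when the new WITH_GDB option is enabled at configure time (cmake -DWITH_GDB=ON ..), so ordinary release builds keep the -O3 defaults.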
......@@ -29,12 +29,7 @@ test_reader = paddle.batch(
for data in test_reader():
new_data = np.zeros((1, 13)).astype("float32")
print('testclient.py-----data',data[0][0])
print('testclient.py-----shape',data[0][0].shape)
new_data[0] = data[0][0]
print('testclient.py-----newdata',new_data)
print('testclient.py-----newdata-0',new_data[0])
print('testclient.py-----newdata.shape',new_data.shape)
fetch_map = client.predict(
feed={"x": new_data}, fetch=["price"], batch=True)
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
......
......@@ -34,8 +34,10 @@ from .proto import multi_lang_general_model_service_pb2_grpc
int64_type = 0
float32_type = 1
int32_type = 2
bytes_type = 3
int_type = set([int64_type, int32_type])
float_type = set([float32_type])
string_type = set([bytes_type])
class _NOPProfiler(object):
......@@ -139,10 +141,22 @@ class Client(object):
from .serving_client import PredictorRes
self.predictorres_constructor = PredictorRes
def load_client_config(self, path):
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
from .serving_client import PredictorClient
model_conf = m_config.GeneralModelConfig()
f = open(path, 'r')
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
......@@ -151,19 +165,16 @@ class Client(object):
# get feed shapes, feed types
# map feed names to index
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
self.client_handle_.init(file_path_list)
if "FLAGS_max_body_size" not in os.environ:
os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
read_env_flags = ["profile_client", "profile_server", "max_body_size"]
self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
self.feed_names_to_idx_ = {}  # this is not useful
self.lod_tensor_set = set()
self.feed_tensor_len = {}
self.feed_tensor_len = {}  # this is only used for shape check
self.key = None
for i, var in enumerate(model_conf.feed_var):
......@@ -178,6 +189,14 @@ class Client(object):
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
self.feed_tensor_len[var.alias_name] = counter
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
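The hunks above let load_client_config accept either a single path or a list of paths; with multiple configs, feed variables come from the first config and fetch variables from the last. A minimal sketch, with hypothetical client-config directories:

from paddle_serving_client import Client

client = Client()
# A directory is resolved to <dir>/serving_server_conf.prototxt.
client.load_client_config(["ocr_det_client", "ocr_rec_client"])
client.connect(["127.0.0.1:9293"])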
......@@ -288,13 +307,17 @@ class Client(object):
raise ValueError("Feed only accepts dict and list of dict")
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
float_feed_names = []
int_shape = []
int_lod_slot_batch = []
float_slot_batch = []
float_feed_names = []
float_lod_slot_batch = []
float_shape = []
string_slot_batch = []
string_feed_names = []
string_lod_slot_batch = []
string_shape = []
fetch_names = []
counter = 0
......@@ -311,9 +334,11 @@ class Client(object):
for i, feed_i in enumerate(feed_batch):
int_slot = []
float_slot = []
int_lod_slot = []
float_slot = []
float_lod_slot = []
string_slot = []
string_lod_slot = []
for key in feed_i:
if ".lod" not in key and key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key))
......@@ -368,10 +393,24 @@ class Client(object):
else:
float_slot.append(feed_i[key])
self.all_numpy_input = False
# If the input is a string, the feed is not numpy.
elif self.feed_types_[key] in string_type:
if i == 0:
string_feed_names.append(key)
string_shape.append(self.feed_shapes_[key])
if "{}.lod".format(key) in feed_i:
string_lod_slot_batch.append(feed_i["{}.lod".format(
key)])
else:
string_lod_slot_batch.append([])
string_slot.append(feed_i[key])
self.has_numpy_input = True
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
int_lod_slot_batch.append(int_lod_slot)
float_slot_batch.append(float_slot)
float_lod_slot_batch.append(float_lod_slot)
string_slot_batch.append(string_slot)
string_lod_slot_batch.append(string_lod_slot)
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
......@@ -381,7 +420,8 @@ class Client(object):
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape,
float_lod_slot_batch, int_slot_batch, int_feed_names, int_shape,
int_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
int_lod_slot_batch, string_slot_batch, string_feed_names, string_shape,
string_lod_slot_batch, fetch_names, result_batch_handle, self.pid,
log_id)
elif self.has_numpy_input == False:
raise ValueError(
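With string slots now threaded through numpy_predict, a bytes-typed input (feed type 3) can be passed directly to predict. A hedged sketch reusing the client configured above; the feed name "image" and its lod are hypothetical and must match the client config:

with open("test.jpg", "rb") as f:
    raw_image = f.read()
# "image" is assumed to be declared with feed_type 3 (bytes) in the config;
# a string feed may carry an explicit "<name>.lod" alongside the data.
fetch_map = client.predict(
    feed={"image": raw_image, "image.lod": [0, 1]},
    fetch=["res"],
    batch=True)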
......@@ -509,8 +549,8 @@ class MultiLangClient(object):
get_client_config_req = multi_lang_general_model_service_pb2.GetClientConfigRequest(
)
resp = self.stub_.GetClientConfig(get_client_config_req)
model_config_str = resp.client_config_str
self._parse_model_config(model_config_str)
model_config_path_list = resp.client_config_str_list
self._parse_model_config(model_config_path_list)
def _flatten_list(self, nested_list):
for item in nested_list:
......@@ -520,25 +560,39 @@ class MultiLangClient(object):
else:
yield item
def _parse_model_config(self, model_config_str):
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
else:
counter = 1
for dim in self.feed_shapes_[var.alias_name]:
counter *= dim
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
......
......@@ -14,6 +14,8 @@ class OpMaker(object):
"general_single_kv": "GeneralSingleKVOp",
"general_dist_kv_infer": "GeneralDistKVInferOp",
"general_dist_kv": "GeneralDistKVOp"
"general_copy": "GeneralCopyOp",
"general_detection":"GeneralDetectionOp",
}
self.node_name_suffix_ = collections.defaultdict(int)
......@@ -45,7 +47,6 @@ class OpMaker(object):
# overall efficiency.
return google.protobuf.text_format.MessageToString(node)
class OpSeqMaker(object):
def __init__(self):
self.workflow = server_sdk.Workflow()
......@@ -78,7 +79,8 @@ class OpSeqMaker(object):
workflow_conf.workflows.extend([self.workflow])
return workflow_conf
# TODO: Currently the SDK only supports "Sequence"; OpGraphMaker is not useful.
# The config should be extended so the command line can accept list[dict] or list[list[]].
class OpGraphMaker(object):
def __init__(self):
self.workflow = server_sdk.Workflow()
......
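The new general_detection entry enables a two-model op sequence. Mirroring the serve.py loop later in this commit, a sketch of the resulting OCR chain (standard OpMaker/OpSeqMaker usage; model loading elided):

import paddle_serving_server as serving

op_maker = serving.OpMaker()
op_seq_maker = serving.OpSeqMaker()
op_seq_maker.add_op(op_maker.create('general_reader'))
op_seq_maker.add_op(op_maker.create('general_detection'))  # first model: detection
op_seq_maker.add_op(op_maker.create('general_infer'))      # second model: recognition
op_seq_maker.add_op(op_maker.create('general_response'))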
......@@ -11,38 +11,55 @@ from .proto import multi_lang_general_model_service_pb2_grpc
class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
MultiLangGeneralModelServiceServicer):
def __init__(self, model_config_path, is_multi_model, endpoints):
def __init__(self, model_config_path_list, is_multi_model, endpoints):
self.is_multi_model_ = is_multi_model
self.model_config_path_ = model_config_path
self.model_config_path_list = model_config_path_list
self.endpoints_ = endpoints
with open(self.model_config_path_) as f:
self.model_config_str_ = str(f.read())
self._parse_model_config(self.model_config_str_)
self._init_bclient(self.model_config_path_, self.endpoints_)
self._init_bclient(self.model_config_path_list, self.endpoints_)
self._parse_model_config(self.model_config_path_list)
def _init_bclient(self, model_config_path, endpoints, timeout_ms=None):
def _init_bclient(self, model_config_path_list, endpoints, timeout_ms=None):
from paddle_serving_client import Client
self.bclient_ = Client()
if timeout_ms is not None:
self.bclient_.set_rpc_timeout_ms(timeout_ms)
self.bclient_.load_client_config(model_config_path)
self.bclient_.load_client_config(model_config_path_list)
self.bclient_.connect(endpoints)
def _parse_model_config(self, model_config_str):
def _parse_model_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
model_config_path_list = [model_config_path_list]
elif isinstance(model_config_path_list, list):
pass
file_path_list = []
for single_model_config in model_config_path_list:
if os.path.isdir(single_model_config):
file_path_list.append("{}/serving_server_conf.prototxt".format(
single_model_config))
elif os.path.isfile(single_model_config):
file_path_list.append(single_model_config)
model_conf = m_config.GeneralModelConfig()
model_conf = google.protobuf.text_format.Merge(model_config_str,
model_conf)
f = open(file_path_list[0], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
if len(file_path_list) > 1:
model_conf = m_config.GeneralModelConfig()
f = open(file_path_list[-1], 'r')
model_conf = google.protobuf.text_format.Merge(
str(f.read()), model_conf)
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
......@@ -69,11 +86,11 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
v_type = self.feed_types_[name]
data = None
if is_python:
if v_type == 0:
if v_type == 0:  # int64
data = np.frombuffer(var.data, dtype="int64")
elif v_type == 1:
elif v_type == 1:  # float32
data = np.frombuffer(var.data, dtype="float32")
elif v_type == 2:
elif v_type == 2:  # int32
data = np.frombuffer(var.data, dtype="int32")
else:
raise Exception("error type.")
......@@ -82,7 +99,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
data = np.array(list(var.int64_data), dtype="int64")
elif v_type == 1: # float32
data = np.array(list(var.float_data), dtype="float32")
elif v_type == 2:
elif v_type == 2:# int32
data = np.array(list(var.int_data), dtype="int32")
else:
raise Exception("error type.")
......@@ -138,7 +155,7 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
# This process and the inference process cannot operate at the same time.
# For performance reasons, no thread lock is added for now.
timeout_ms = request.timeout_ms
self._init_bclient(self.model_config_path_, self.endpoints_, timeout_ms)
self._init_bclient(self.model_config_path_list, self.endpoints_, timeout_ms)
resp = multi_lang_general_model_service_pb2.SimpleResponse()
resp.err_code = 0
return resp
......@@ -155,6 +172,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
return self._pack_inference_response(ret, fetch_names, is_python)
def GetClientConfig(self, request, context):
# model_config_path_list is a list right now.
# A dict should be supported once OpGraphMaker is used.
resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
resp.client_config_str = self.model_config_str_
resp.client_config_str_list[:] = self.model_config_path_list
return resp
\ No newline at end of file
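Note: GetClientConfig now returns client_config_str_list (config file paths) rather than one inlined prototxt string, and MultiLangClient._parse_model_config opens those paths itself; this implicitly assumes the client can read the server's config files, i.e. a local or shared-filesystem deployment.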
......@@ -23,7 +23,6 @@ import json
import base64
import time
from multiprocessing import Process
from .web_service import WebService, port_is_available
from flask import Flask, request
import sys
if sys.version_info.major == 2:
......@@ -182,15 +181,26 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
print("You must specify your serving model")
exit(-1)
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
elif os.path.isfile(single_model_config):
raise ValueError("The input of --model should be a dir not file.")
import paddle_serving_server as serving
op_maker = serving.OpMaker()
read_op = op_maker.create('general_reader')
general_infer_op = op_maker.create('general_infer')
general_response_op = op_maker.create('general_response')
op_seq_maker = serving.OpSeqMaker()
read_op = op_maker.create('general_reader')
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(general_infer_op)
for idx, single_model in enumerate(model):
infer_op_name = "general_infer"
if len(model) == 2 and idx == 0:
infer_op_name = "general_detection"
else:
infer_op_name = "general_infer"
general_infer_op = op_maker.create(infer_op_name)
op_seq_maker.add_op(general_infer_op)
general_response_op = op_maker.create('general_response')
op_seq_maker.add_op(general_response_op)
if use_multilang:
......@@ -348,7 +358,7 @@ class MainService(BaseHTTPRequestHandler):
if __name__ == "__main__":
args = parse_args()
args = serve_args()
for single_model_config in args.model:
if os.path.isdir(single_model_config):
pass
......
......@@ -34,6 +34,7 @@ import platform
import numpy as np
import grpc
import sys
import collections
from multiprocessing import Pool, Process
from concurrent import futures
......@@ -154,13 +155,12 @@ class Server(object):
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = []
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
for engine_name, model_config_path in model_config_paths.items():
engine = server_sdk.EngineDesc()
engine.name = engine_name
# engine.reloadable_meta = model_config_path + "/fluid_time_file"
engine.reloadable_meta = self.workdir + "/fluid_time_file"
engine.reloadable_meta = model_config_path + "/fluid_time_file"
os.system("touch {}".format(engine.reloadable_meta))
engine.reloadable_type = "timestamp_ne"
engine.runtime_thread_num = 0
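(Pointing reloadable_meta at each model_config_path instead of the shared workdir gives every engine its own fluid_time_file timestamp, which is what lets several models be reload-tracked in one server process.)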
......@@ -292,7 +292,6 @@ class Server(object):
def get_device_version(self):
avx_flag = False
mkl_flag = self.mkl_flag
openblas_flag = False
r = os.system("cat /proc/cpuinfo | grep avx > /dev/null 2>&1")
if r == 0:
avx_flag = True
......@@ -387,10 +386,8 @@ class Server(object):
os.system("mkdir -p {}".format(workdir))
else:
os.system("mkdir -p {}".format(workdir))
os.system("touch {}/fluid_time_file".format(workdir))
for subdir in self.subdirectory:
os.system("mkdir {}/{}".format(workdir, subdir))
os.system("mkdir -p {}/{}".format(workdir, subdir))
os.system("touch {}/{}/fluid_time_file".format(workdir, subdir))
if not self.port_is_available(port):
......@@ -507,7 +504,7 @@ class MultiLangServer(object):
self.worker_num_ = 4
self.body_size_ = 64 * 1024 * 1024
self.concurrency_ = 100000
self.is_multi_model_ = False # for model ensemble
self.is_multi_model_ = False  # for model ensemble, which is not used right now.
def set_max_concurrency(self, concurrency):
self.concurrency_ = concurrency
......
......@@ -116,7 +116,7 @@ class WebService(object):
device = "arm"
else:
device = "cpu"
op_maker = serving.OpMaker()
op_maker = OpMaker()
op_seq_maker = OpSeqMaker()
read_op = op_maker.create('general_reader')
......