提交 d45340de 编写于 作者: H HexToString

temp fix encryption

上级 69135767
...@@ -31,8 +31,8 @@ message( "WITH_GPU = ${WITH_GPU}") ...@@ -31,8 +31,8 @@ message( "WITH_GPU = ${WITH_GPU}")
# Paddle Version should be one of: # Paddle Version should be one of:
# latest: latest develop build # latest: latest develop build
# version number like 1.5.2 # version number like 1.5.2
SET(PADDLE_VERSION "2.0.0-rc1") #SET(PADDLE_VERSION "2.0.0-rc1")
SET(PADDLE_VERSION "latest")
if (WITH_GPU) if (WITH_GPU)
if (WITH_TRT) if (WITH_TRT)
SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}-gpu-cuda10.1-cudnn7-avx-mkl-trt6") SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}-gpu-cuda10.1-cudnn7-avx-mkl-trt6")
...@@ -114,7 +114,7 @@ ADD_LIBRARY(openblas STATIC IMPORTED GLOBAL) ...@@ -114,7 +114,7 @@ ADD_LIBRARY(openblas STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET openblas PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/openblas/lib/libopenblas.a) SET_PROPERTY(TARGET openblas PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/openblas/lib/libopenblas.a)
ADD_LIBRARY(paddle_fluid SHARED IMPORTED GLOBAL) ADD_LIBRARY(paddle_fluid SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/lib/libpaddle_fluid.so) SET_PROPERTY(TARGET paddle_fluid PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/lib/libpaddle_fluid.a)
if (WITH_TRT) if (WITH_TRT)
ADD_LIBRARY(nvinfer SHARED IMPORTED GLOBAL) ADD_LIBRARY(nvinfer SHARED IMPORTED GLOBAL)
...@@ -127,10 +127,13 @@ endif() ...@@ -127,10 +127,13 @@ endif()
ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL) ADD_LIBRARY(xxhash STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a) SET_PROPERTY(TARGET xxhash PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/xxhash/lib/libxxhash.a)
ADD_LIBRARY(cryptopp STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET cryptopp PROPERTY IMPORTED_LOCATION ${PADDLE_INSTALL_DIR}/third_party/install/cryptopp/lib/libcryptopp.a)
LIST(APPEND external_project_dependencies paddle) LIST(APPEND external_project_dependencies paddle)
LIST(APPEND paddle_depend_libs LIST(APPEND paddle_depend_libs
xxhash) xxhash cryptopp)
if(WITH_TRT) if(WITH_TRT)
LIST(APPEND paddle_depend_libs LIST(APPEND paddle_depend_libs
......
...@@ -263,6 +263,62 @@ class Parameter { ...@@ -263,6 +263,62 @@ class Parameter {
float* _params; float* _params;
}; };
class FluidCpuAnalysisEncryptCore : public FluidFamilyCore {
public:
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
int create(const predictor::InferEngineCreationParams& params) {
std::string data_path = params.get_path();
if (access(data_path.c_str(), F_OK) == -1) {
LOG(ERROR) << "create paddle predictor failed, path note exits: "
<< data_path;
return -1;
}
std::string model_buffer, params_buffer, key_buffer;
ReadBinaryFile(data_path + "encrypt_model", &model_buffer);
ReadBinaryFile(data_path + "encrypt_params", &params_buffer);
ReadBinaryFile(data_path + "key", &key_buffer);
VLOG(2) << "prepare for encryption model";
auto cipher = paddle::MakeCipher("");
std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
std::string real_params_buffer = cipher->Decrypt(params_buffer, key_buffer);
Config analysis_config;
//paddle::AnalysisConfig analysis_config;
analysis_config.SetModelBuffer(&real_model_buffer[0],
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
analysis_config.DisableGpu();
analysis_config.SetCpuMathLibraryNumThreads(1);
if (params.enable_memory_optimization()) {
analysis_config.EnableMemoryOptim();
}
analysis_config.SwitchSpecifyInputNames(true);
AutoLock lock(GlobalPaddleCreateMutex::instance());
VLOG(2) << "decrypt model file sucess";
_core =
CreatePredictor(analysis_config);
if (NULL == _core.get()) {
LOG(ERROR) << "create paddle predictor failed, path: " << data_path;
return -1;
}
VLOG(2) << "create paddle predictor sucess, path: " << data_path;
return 0;
}
};
} // namespace fluid_cpu } // namespace fluid_cpu
} // namespace paddle_serving } // namespace paddle_serving
} // namespace baidu } // namespace baidu
...@@ -30,6 +30,13 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME( ...@@ -30,6 +30,13 @@ REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::InferEngine, ::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_ANALYSIS_DIR"); "FLUID_CPU_ANALYSIS_DIR");
#if 1
REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
::baidu::paddle_serving::predictor::FluidInferEngine<
FluidCpuAnalysisEncryptCore>,
::baidu::paddle_serving::predictor::InferEngine,
"FLUID_CPU_ANALYSIS_ENCRYPT");
#endif
} // namespace fluid_cpu } // namespace fluid_cpu
} // namespace paddle_serving } // namespace paddle_serving
} // namespace baidu } // namespace baidu
...@@ -19,6 +19,9 @@ from .proto import sdk_configure_pb2 as sdk ...@@ -19,6 +19,9 @@ from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format import google.protobuf.text_format
import numpy as np import numpy as np
import requests
import json
import base64
import time import time
import sys import sys
...@@ -161,6 +164,7 @@ class Client(object): ...@@ -161,6 +164,7 @@ class Client(object):
self.fetch_names_to_idx_ = {} self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set() self.lod_tensor_set = set()
self.feed_tensor_len = {} self.feed_tensor_len = {}
self.key = None
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i self.feed_names_to_idx_[var.alias_name] = i
...@@ -193,7 +197,28 @@ class Client(object): ...@@ -193,7 +197,28 @@ class Client(object):
else: else:
self.rpc_timeout_ms = rpc_timeout self.rpc_timeout_ms = rpc_timeout
def connect(self, endpoints=None): def use_key(self, key_filename):
with open(key_filename, "r") as f:
self.key = f.read()
def get_serving_port(self, endpoints):
if self.key is not None:
req = json.dumps({"key": base64.b64encode(self.key)})
else:
req = json.dumps({})
r = requests.post("http://" + endpoints[0], req)
result = r.json()
print(result)
if "endpoint_list" not in result:
raise ValueError("server not ready")
else:
endpoints = [
endpoints[0].split(":")[0] + ":" +
str(result["endpoint_list"][0])
]
return endpoints
def connect(self, endpoints=None, encryption=False):
# check whether current endpoint is available # check whether current endpoint is available
# init from client config # init from client config
# create predictor here # create predictor here
...@@ -203,6 +228,8 @@ class Client(object): ...@@ -203,6 +228,8 @@ class Client(object):
"You must set the endpoints parameter or use add_variant function to create a variant." "You must set the endpoints parameter or use add_variant function to create a variant."
) )
else: else:
if encryption:
endpoints = self.get_serving_port(endpoints)
if self.predictor_sdk_ is None: if self.predictor_sdk_ is None:
self.add_variant('default_tag_{}'.format(id(self)), endpoints, self.add_variant('default_tag_{}'.format(id(self)), endpoints,
100) 100)
......
...@@ -21,10 +21,14 @@ from paddle.fluid.framework import Program ...@@ -21,10 +21,14 @@ from paddle.fluid.framework import Program
from paddle.fluid import CPUPlace from paddle.fluid import CPUPlace
from paddle.fluid.io import save_inference_model from paddle.fluid.io import save_inference_model
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.core import CipherUtils
from paddle.fluid.core import CipherFactory
from paddle.fluid.core import Cipher
from ..proto import general_model_config_pb2 as model_conf from ..proto import general_model_config_pb2 as model_conf
import os import os
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import errno
from paddle.jit import to_static from paddle.jit import to_static
def save_dygraph_model(serving_model_folder, client_config_folder, model): def save_dygraph_model(serving_model_folder, client_config_folder, model):
...@@ -112,7 +116,10 @@ def save_model(server_model_folder, ...@@ -112,7 +116,10 @@ def save_model(server_model_folder,
client_config_folder, client_config_folder,
feed_var_dict, feed_var_dict,
fetch_var_dict, fetch_var_dict,
main_program=None): main_program=None,
encryption=False,
key_len=128,
encrypt_conf=None):
executor = Executor(place=CPUPlace()) executor = Executor(place=CPUPlace())
feed_var_names = [feed_var_dict[x].name for x in feed_var_dict] feed_var_names = [feed_var_dict[x].name for x in feed_var_dict]
...@@ -122,14 +129,31 @@ def save_model(server_model_folder, ...@@ -122,14 +129,31 @@ def save_model(server_model_folder,
target_vars.append(fetch_var_dict[key]) target_vars.append(fetch_var_dict[key])
target_var_names.append(key) target_var_names.append(key)
save_inference_model( if not encryption:
server_model_folder, save_inference_model(
feed_var_names, server_model_folder,
target_vars, feed_var_names,
executor, target_vars,
model_filename="__model__", executor,
params_filename="__params__", model_filename="__model__",
main_program=main_program) params_filename="__params__",
main_program=main_program)
else:
if encrypt_conf == None:
aes_cipher = CipherFactory.create_cipher()
else:
#todo: more encryption algorithms
pass
key = CipherUtils.gen_key_to_file(128, "key")
params = fluid.io.save_persistables(
executor=executor, dirname=None, main_program=main_program)
model = main_program.desc.serialize_to_string()
if not os.path.exists(server_model_folder):
os.makedirs(server_model_folder)
os.chdir(server_model_folder)
aes_cipher.encrypt_to_file(params, key, "encrypt_params")
aes_cipher.encrypt_to_file(model, key, "encrypt_model")
os.chdir("..")
config = model_conf.GeneralModelConfig() config = model_conf.GeneralModelConfig()
...@@ -201,7 +225,11 @@ def inference_model_to_serving(dirname, ...@@ -201,7 +225,11 @@ def inference_model_to_serving(dirname,
serving_server="serving_server", serving_server="serving_server",
serving_client="serving_client", serving_client="serving_client",
model_filename=None, model_filename=None,
params_filename=None): params_filename=None,
encryption=False,
key_len=128,
encrypt_conf=None):
paddle.enable_static()
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
inference_program, feed_target_names, fetch_targets = \ inference_program, feed_target_names, fetch_targets = \
...@@ -212,7 +240,7 @@ def inference_model_to_serving(dirname, ...@@ -212,7 +240,7 @@ def inference_model_to_serving(dirname,
} }
fetch_dict = {x.name: x for x in fetch_targets} fetch_dict = {x.name: x for x in fetch_targets}
save_model(serving_server, serving_client, feed_dict, fetch_dict, save_model(serving_server, serving_client, feed_dict, fetch_dict,
inference_program) inference_program, encryption, key_len, encrypt_conf)
feed_names = feed_dict.keys() feed_names = feed_dict.keys()
fetch_names = fetch_dict.keys() fetch_names = fetch_dict.keys()
return feed_names, fetch_names return feed_names, fetch_names
...@@ -157,7 +157,8 @@ class Server(object): ...@@ -157,7 +157,8 @@ class Server(object):
self.cur_path = os.getcwd() self.cur_path = os.getcwd()
self.use_local_bin = False self.use_local_bin = False
self.mkl_flag = False self.mkl_flag = False
self.product_name = None self.encryption_model = False
self.product_name = None
self.container_id = None self.container_id = None
self.model_config_paths = None # for multi-model in a workflow self.model_config_paths = None # for multi-model in a workflow
...@@ -196,6 +197,8 @@ class Server(object): ...@@ -196,6 +197,8 @@ class Server(object):
def set_ir_optimize(self, flag=False): def set_ir_optimize(self, flag=False):
self.ir_optimization = flag self.ir_optimization = flag
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def set_product_name(self, product_name=None): def set_product_name(self, product_name=None):
if product_name == None: if product_name == None:
...@@ -236,9 +239,15 @@ class Server(object): ...@@ -236,9 +239,15 @@ class Server(object):
suffix = "_DIR" suffix = "_DIR"
if device == "cpu": if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS" + suffix if self.encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_CPU_ANALYSIS" + suffix
elif device == "gpu": elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS" + suffix if self.encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRYPT"
else:
engine.type = "FLUID_GPU_ANALYSIS" + suffix
self.model_toolkit_conf.engines.extend([engine]) self.model_toolkit_conf.engines.extend([engine])
......
...@@ -18,8 +18,14 @@ Usage: ...@@ -18,8 +18,14 @@ Usage:
python -m paddle_serving_server.serve --model ./serving_server_model --port 9292 python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
""" """
import argparse import argparse
from .web_service import WebService import sys
import json
import base64
import time
from multiprocessing import Process
from web_service import WebService, port_is_available
from flask import Flask, request from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def parse_args(): # pylint: disable=doc-string-missing def parse_args(): # pylint: disable=doc-string-missing
...@@ -53,6 +59,11 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -53,6 +59,11 @@ def parse_args(): # pylint: disable=doc-string-missing
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
parser.add_argument( parser.add_argument(
"--use_multilang", "--use_multilang",
default=False, default=False,
...@@ -71,17 +82,18 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -71,17 +82,18 @@ def parse_args(): # pylint: disable=doc-string-missing
return parser.parse_args() return parser.parse_args()
def start_standard_model(): # pylint: disable=doc-string-missing def start_standard_model(serving_port): # pylint: disable=doc-string-missing
args = parse_args() args = parse_args()
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
port = args.port port = serving_port
workdir = args.workdir workdir = args.workdir
device = args.device device = args.device
mem_optim = args.mem_optim_off is False mem_optim = args.mem_optim_off is False
ir_optim = args.ir_optim ir_optim = args.ir_optim
max_body_size = args.max_body_size max_body_size = args.max_body_size
use_mkl = args.use_mkl use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model
use_multilang = args.use_multilang use_multilang = args.use_multilang
if model == "": if model == "":
...@@ -111,6 +123,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -111,6 +123,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.use_mkl(use_mkl) server.use_mkl(use_mkl)
server.set_max_body_size(max_body_size) server.set_max_body_size(max_body_size)
server.set_port(port) server.set_port(port)
server.use_encryption_model(use_encryption_model)
if args.product_name != None: if args.product_name != None:
server.set_product_name(args.product_name) server.set_product_name(args.product_name)
if args.container_id != None: if args.container_id != None:
...@@ -120,12 +133,88 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -120,12 +133,88 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
server.run_server() server.run_server()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_standard_model(serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
global p
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else:
if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
if args.name == "None": if args.name == "None":
start_standard_model() if args.use_encryption_model:
p_flag = False
p = None
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_standard_model(args.port)
else: else:
service = WebService(name=args.name) service = WebService(name=args.name)
service.load_model_config(args.model) service.load_model_config(args.model)
......
...@@ -25,6 +25,16 @@ from paddle_serving_server import pipeline ...@@ -25,6 +25,16 @@ from paddle_serving_server import pipeline
from paddle_serving_server.pipeline import Op from paddle_serving_server.pipeline import Op
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object): class WebService(object):
def __init__(self, name="default_service"): def __init__(self, name="default_service"):
self.name = name self.name = name
...@@ -110,7 +120,7 @@ class WebService(object): ...@@ -110,7 +120,7 @@ class WebService(object):
self.mem_optim = mem_optim self.mem_optim = mem_optim
self.ir_optim = ir_optim self.ir_optim = ir_optim
for i in range(1000): for i in range(1000):
if self.port_is_available(default_port + i): if port_is_available(default_port + i):
self.port_list.append(default_port + i) self.port_list.append(default_port + i)
break break
......
...@@ -70,6 +70,11 @@ def serve_args(): ...@@ -70,6 +70,11 @@ def serve_args():
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
parser.add_argument( parser.add_argument(
"--use_multilang", "--use_multilang",
default=False, default=False,
...@@ -279,7 +284,7 @@ class Server(object): ...@@ -279,7 +284,7 @@ class Server(object):
def set_trt(self): def set_trt(self):
self.use_trt = True self.use_trt = True
def _prepare_engine(self, model_config_paths, device): def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None: if self.model_toolkit_conf == None:
self.model_toolkit_conf = server_sdk.ModelToolkitConf() self.model_toolkit_conf = server_sdk.ModelToolkitConf()
...@@ -301,9 +306,15 @@ class Server(object): ...@@ -301,9 +306,15 @@ class Server(object):
engine.use_trt = self.use_trt engine.use_trt = self.use_trt
if device == "cpu": if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS_DIR" if use_encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu": elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS_DIR" if use_encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine]) self.model_toolkit_conf.engines.extend([engine])
...@@ -460,6 +471,7 @@ class Server(object): ...@@ -460,6 +471,7 @@ class Server(object):
workdir=None, workdir=None,
port=9292, port=9292,
device="cpu", device="cpu",
use_encryption_model=False,
cube_conf=None): cube_conf=None):
if workdir == None: if workdir == None:
workdir = "./tmp" workdir = "./tmp"
...@@ -473,7 +485,8 @@ class Server(object): ...@@ -473,7 +485,8 @@ class Server(object):
self.set_port(port) self.set_port(port)
self._prepare_resource(workdir, cube_conf) self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device) self._prepare_engine(self.model_config_paths, device,
use_encryption_model)
self._prepare_infer_service(port) self._prepare_infer_service(port)
self.workdir = workdir self.workdir = workdir
......
...@@ -19,19 +19,21 @@ Usage: ...@@ -19,19 +19,21 @@ Usage:
""" """
import argparse import argparse
import os import os
import json
import base64
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args from paddle_serving_server_gpu import serve_args
from flask import Flask, request from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid) gpuid = int(gpuid)
device = "gpu" device = "gpu"
port = args.port
if gpuid == -1: if gpuid == -1:
device = "cpu" device = "cpu"
elif gpuid >= 0: elif gpuid >= 0:
port = args.port + index port = port + index
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
mem_optim = args.mem_optim_off is False mem_optim = args.mem_optim_off is False
...@@ -73,14 +75,19 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss ...@@ -73,14 +75,19 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
server.set_container_id(args.container_id) server.set_container_id(args.container_id)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(
workdir=workdir,
port=port,
device=device,
use_encryption_model=args.use_encryption_model)
if gpuid >= 0: if gpuid >= 0:
server.set_gpuid(gpuid) server.set_gpuid(gpuid)
server.run_server() server.run_server()
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
def start_multi_card(args): # pylint: disable=doc-string-missing
gpus = "" gpus = ""
if serving_port == None:
serving_port = args.port
if args.gpu_ids == "": if args.gpu_ids == "":
gpus = [] gpus = []
else: else:
...@@ -97,14 +104,16 @@ def start_multi_card(args): # pylint: disable=doc-string-missing ...@@ -97,14 +104,16 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
env_gpus = [] env_gpus = []
if len(gpus) <= 0: if len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.") print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, args) start_gpu_card_model(-1, -1, serving_port, args)
else: else:
gpu_processes = [] gpu_processes = []
for i, gpu_id in enumerate(gpus): for i, gpu_id in enumerate(gpus):
p = Process( p = Process(
target=start_gpu_card_model, args=( target=start_gpu_card_model,
args=(
i, i,
gpu_id, gpu_id,
serving_port,
args, )) args, ))
gpu_processes.append(p) gpu_processes.append(p)
for p in gpu_processes: for p in gpu_processes:
...@@ -112,11 +121,89 @@ def start_multi_card(args): # pylint: disable=doc-string-missing ...@@ -112,11 +121,89 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
for p in gpu_processes: for p in gpu_processes:
p.join() p.join()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_multi_card(args, serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
global p
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else:
if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__": if __name__ == "__main__":
args = serve_args() args = serve_args()
if args.name == "None": if args.name == "None":
start_multi_card(args) from .web_service import port_is_available
if args.use_encryption_model:
p_flag = False
p = None
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_multi_card(args)
else: else:
from .web_service import WebService from .web_service import WebService
web_service = WebService(name=args.name) web_service = WebService(name=args.name)
......
...@@ -28,6 +28,16 @@ from paddle_serving_server_gpu import pipeline ...@@ -28,6 +28,16 @@ from paddle_serving_server_gpu import pipeline
from paddle_serving_server_gpu.pipeline import Op from paddle_serving_server_gpu.pipeline import Op
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object): class WebService(object):
def __init__(self, name="default_service"): def __init__(self, name="default_service"):
self.name = name self.name = name
...@@ -136,7 +146,7 @@ class WebService(object): ...@@ -136,7 +146,7 @@ class WebService(object):
self.port_list = [] self.port_list = []
default_port = 12000 default_port = 12000
for i in range(1000): for i in range(1000):
if self.port_is_available(default_port + i): if port_is_available(default_port + i):
self.port_list.append(default_port + i) self.port_list.append(default_port + i)
if len(self.port_list) > len(self.gpus): if len(self.port_list) > len(self.gpus):
break break
......
...@@ -39,6 +39,8 @@ RUN yum -y install wget && \ ...@@ -39,6 +39,8 @@ RUN yum -y install wget && \
make clean && \ make clean && \
echo 'export PATH=/usr/local/python3.6/bin:$PATH' >> /root/.bashrc && \ echo 'export PATH=/usr/local/python3.6/bin:$PATH' >> /root/.bashrc && \
echo 'export LD_LIBRARY_PATH=/usr/local/python3.6/lib:$LD_LIBRARY_PATH' >> /root/.bashrc && \ echo 'export LD_LIBRARY_PATH=/usr/local/python3.6/lib:$LD_LIBRARY_PATH' >> /root/.bashrc && \
pip install requests && \
pip3 install requests && \
source /root/.bashrc && \ source /root/.bashrc && \
cd .. && rm -rf Python-3.6.8* && \ cd .. && rm -rf Python-3.6.8* && \
wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \ wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \
......
...@@ -49,6 +49,8 @@ RUN yum -y install wget && \ ...@@ -49,6 +49,8 @@ RUN yum -y install wget && \
cd .. && rm -rf protobuf-* && \ cd .. && rm -rf protobuf-* && \
yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \ yum -y install epel-release && yum -y install patchelf libXext libSM libXrender && \
yum clean all && \ yum clean all && \
pip install requests && \
pip3 install requests && \
localedef -c -i en_US -f UTF-8 en_US.UTF-8 && \ localedef -c -i en_US -f UTF-8 en_US.UTF-8 && \
echo "export LANG=en_US.utf8" >> /root/.bashrc && \ echo "export LANG=en_US.utf8" >> /root/.bashrc && \
echo "export LANGUAGE=en_US.utf8" >> /root/.bashrc echo "export LANGUAGE=en_US.utf8" >> /root/.bashrc
...@@ -23,7 +23,8 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ ...@@ -23,7 +23,8 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel >/dev/null \ RUN yum -y install python-devel sqlite-devel >/dev/null \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \ && python get-pip.py >/dev/null \
&& rm get-pip.py && rm get-pip.py \
&& pip install requests
RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \ RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 \
&& yum -y install bzip2 >/dev/null \ && yum -y install bzip2 >/dev/null \
...@@ -34,6 +35,9 @@ RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2 ...@@ -34,6 +35,9 @@ RUN wget http://nixos.org/releases/patchelf/patchelf-0.10/patchelf-0.10.tar.bz2
&& cd .. \ && cd .. \
&& rm -rf patchelf-0.10* && rm -rf patchelf-0.10*
RUN yum install -y python3 python3-devel \
&& pip3 install requests
RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \ RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protobuf-all-3.11.2.tar.gz && \
tar zxf protobuf-all-3.11.2.tar.gz && \ tar zxf protobuf-all-3.11.2.tar.gz && \
cd protobuf-3.11.2 && \ cd protobuf-3.11.2 && \
...@@ -41,8 +45,6 @@ RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/p ...@@ -41,8 +45,6 @@ RUN wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/p
make clean && \ make clean && \
cd .. && rm -rf protobuf-* cd .. && rm -rf protobuf-*
RUN yum install -y python3 python3-devel
RUN yum -y update >/dev/null \ RUN yum -y update >/dev/null \
&& yum -y install dnf >/dev/null \ && yum -y install dnf >/dev/null \
&& yum -y install dnf-plugins-core >/dev/null \ && yum -y install dnf-plugins-core >/dev/null \
......
...@@ -30,11 +30,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ ...@@ -30,11 +30,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \ RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \ && python get-pip.py >/dev/null \
&& rm get-pip.py && rm get-pip.py \
&& pip install requests
RUN yum install -y python3 python3-devel \ RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\ && yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all && yum clean all \
&& pip3 install requests
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \ RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \ && echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
...@@ -29,11 +29,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ ...@@ -29,11 +29,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \ RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \ && python get-pip.py >/dev/null \
&& rm get-pip.py && rm get-pip.py \
&& pip install requests
RUN yum install -y python3 python3-devel \ RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\ && yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all && yum clean all \
&& pip3 install requests
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \ RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \ && echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
...@@ -19,11 +19,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \ ...@@ -19,11 +19,13 @@ RUN wget https://dl.google.com/go/go1.14.linux-amd64.tar.gz >/dev/null \
RUN yum -y install python-devel sqlite-devel \ RUN yum -y install python-devel sqlite-devel \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \ && curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py >/dev/null \
&& python get-pip.py >/dev/null \ && python get-pip.py >/dev/null \
&& rm get-pip.py && rm get-pip.py \
&& pip install requests
RUN yum install -y python3 python3-devel \ RUN yum install -y python3 python3-devel \
&& yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\ && yum -y install epel-release && yum -y install patchelf libXext libSM libXrender\
&& yum clean all && yum clean all \
&& pip3 install requests
RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \ RUN localedef -c -i en_US -f UTF-8 en_US.UTF-8 \
&& echo "export LANG=en_US.utf8" >> /root/.bashrc \ && echo "export LANG=en_US.utf8" >> /root/.bashrc \
......
...@@ -485,6 +485,42 @@ function python_test_lac() { ...@@ -485,6 +485,42 @@ function python_test_lac() {
cd .. cd ..
} }
function python_test_encryption(){
#pwd: /Serving/python/examples
cd encryption
sh get_data.sh
local TYPE=$1
export SERVING_BIN=${SERIVNG_WORKDIR}/build-server-${TYPE}/core/general-server/serving
case $TYPE in
CPU)
#check_cmd "python encrypt.py"
#sleep 5
check_cmd "python -m paddle_serving_server.serve --model encrypt_server/ --port 9300 --use_encryption_model > /dev/null &"
sleep 5
check_cmd "python test_client.py encrypt_client/serving_client_conf.prototxt"
kill_server_process
;;
GPU)
#check_cmd "python encrypt.py"
#sleep 5
check_cmd "python -m paddle_serving_server_gpu.serve --model encrypt_server/ --port 9300 --use_encryption_model --gpu_ids 0"
sleep 5
check_cmd "python test_client.py encrypt_client/serving_client_conf.prototxt"
kill_servere_process
;;
*)
echo "error type"
exit 1
;;
esac
echo "encryption $TYPE test finished as expected"
setproxy
unset SERVING_BIN
cd ..
}
function java_run_test() { function java_run_test() {
# pwd: /Serving # pwd: /Serving
local TYPE=$1 local TYPE=$1
...@@ -921,6 +957,7 @@ function python_run_test() { ...@@ -921,6 +957,7 @@ function python_run_test() {
python_test_lac $TYPE # pwd: /Serving/python/examples python_test_lac $TYPE # pwd: /Serving/python/examples
python_test_multi_process $TYPE # pwd: /Serving/python/examples python_test_multi_process $TYPE # pwd: /Serving/python/examples
python_test_multi_fetch $TYPE # pwd: /Serving/python/examples python_test_multi_fetch $TYPE # pwd: /Serving/python/examples
python_test_encryption $TYPE # pwd: /Serving/python/examples
python_test_yolov4 $TYPE # pwd: /Serving/python/examples python_test_yolov4 $TYPE # pwd: /Serving/python/examples
python_test_grpc_impl $TYPE # pwd: /Serving/python/examples python_test_grpc_impl $TYPE # pwd: /Serving/python/examples
python_test_resnet50 $TYPE # pwd: /Serving/python/examples python_test_resnet50 $TYPE # pwd: /Serving/python/examples
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册