提交 aacd85c4 编写于 作者: M MRXLT

add encryption service for gpu

上级 6866bff0
...@@ -68,6 +68,11 @@ def serve_args(): ...@@ -68,6 +68,11 @@ def serve_args():
type=int, type=int,
default=512 * 1024 * 1024, default=512 * 1024 * 1024,
help="Limit sizes of messages") help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
return parser.parse_args() return parser.parse_args()
...@@ -244,7 +249,7 @@ class Server(object): ...@@ -244,7 +249,7 @@ class Server(object):
def set_gpuid(self, gpuid=0): def set_gpuid(self, gpuid=0):
self.gpuid = gpuid self.gpuid = gpuid
def _prepare_engine(self, model_config_paths, device): def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None: if self.model_toolkit_conf == None:
self.model_toolkit_conf = server_sdk.ModelToolkitConf() self.model_toolkit_conf = server_sdk.ModelToolkitConf()
...@@ -265,9 +270,15 @@ class Server(object): ...@@ -265,9 +270,15 @@ class Server(object):
engine.force_update_static_cache = False engine.force_update_static_cache = False
if device == "cpu": if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS_DIR" if use_encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu": elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS_DIR" if use_encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine]) self.model_toolkit_conf.engines.extend([engine])
...@@ -401,7 +412,11 @@ class Server(object): ...@@ -401,7 +412,11 @@ class Server(object):
os.chdir(self.cur_path) os.chdir(self.cur_path)
self.bin_path = self.server_path + "/serving" self.bin_path = self.server_path + "/serving"
def prepare_server(self, workdir=None, port=9292, device="cpu"): def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
use_encryption_model=False):
if workdir == None: if workdir == None:
workdir = "./tmp" workdir = "./tmp"
os.system("mkdir {}".format(workdir)) os.system("mkdir {}".format(workdir))
...@@ -414,7 +429,8 @@ class Server(object): ...@@ -414,7 +429,8 @@ class Server(object):
self.set_port(port) self.set_port(port)
self._prepare_resource(workdir) self._prepare_resource(workdir)
self._prepare_engine(self.model_config_paths, device) self._prepare_engine(self.model_config_paths, device,
use_encryption_model)
self._prepare_infer_service(port) self._prepare_infer_service(port)
self.workdir = workdir self.workdir = workdir
......
...@@ -19,19 +19,21 @@ Usage: ...@@ -19,19 +19,21 @@ Usage:
""" """
import argparse import argparse
import os import os
import json
import base64
from multiprocessing import Pool, Process from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args from paddle_serving_server_gpu import serve_args
from flask import Flask, request from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid) gpuid = int(gpuid)
device = "gpu" device = "gpu"
port = args.port
if gpuid == -1: if gpuid == -1:
device = "cpu" device = "cpu"
elif gpuid >= 0: elif gpuid >= 0:
port = args.port + index port = port + index
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
mem_optim = args.mem_optim mem_optim = args.mem_optim
...@@ -62,14 +64,20 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss ...@@ -62,14 +64,20 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
server.set_max_body_size(max_body_size) server.set_max_body_size(max_body_size)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(
workdir=workdir,
port=port,
device=device,
use_encryption_model=args.use_encryption_model)
if gpuid >= 0: if gpuid >= 0:
server.set_gpuid(gpuid) server.set_gpuid(gpuid)
server.run_server() server.run_server()
def start_multi_card(args): # pylint: disable=doc-string-missing def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
gpus = "" gpus = ""
if serving_port == None:
serving_port = args.port
if args.gpu_ids == "": if args.gpu_ids == "":
gpus = [] gpus = []
else: else:
...@@ -86,14 +94,16 @@ def start_multi_card(args): # pylint: disable=doc-string-missing ...@@ -86,14 +94,16 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
env_gpus = [] env_gpus = []
if len(gpus) <= 0: if len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.") print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, args) start_gpu_card_model(-1, -1, serving_port, args)
else: else:
gpu_processes = [] gpu_processes = []
for i, gpu_id in enumerate(gpus): for i, gpu_id in enumerate(gpus):
p = Process( p = Process(
target=start_gpu_card_model, args=( target=start_gpu_card_model,
args=(
i, i,
gpu_id, gpu_id,
serving_port,
args, )) args, ))
gpu_processes.append(p) gpu_processes.append(p)
for p in gpu_processes: for p in gpu_processes:
...@@ -102,10 +112,71 @@ def start_multi_card(args): # pylint: disable=doc-string-missing ...@@ -102,10 +112,71 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
p.join() p.join()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_multi_card(args, serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
p_flag = True
else:
if not p.is_alive():
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__": if __name__ == "__main__":
args = serve_args() args = serve_args()
if args.name == "None": if args.name == "None":
start_multi_card(args) from .web_service import port_is_available
if args.use_encryption_model:
p_flag = False
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_multi_card(args)
else: else:
from .web_service import WebService from .web_service import WebService
web_service = WebService(name=args.name) web_service = WebService(name=args.name)
......
...@@ -25,6 +25,16 @@ import numpy as np ...@@ -25,6 +25,16 @@ import numpy as np
import paddle_serving_server_gpu as serving import paddle_serving_server_gpu as serving
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object): class WebService(object):
def __init__(self, name="default_service"): def __init__(self, name="default_service"):
self.name = name self.name = name
...@@ -68,15 +78,6 @@ class WebService(object): ...@@ -68,15 +78,6 @@ class WebService(object):
def _launch_rpc_service(self, service_idx): def _launch_rpc_service(self, service_idx):
self.rpc_service_list[service_idx].run_server() self.rpc_service_list[service_idx].run_server()
def port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0): def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
self.workdir = workdir self.workdir = workdir
self.port = port self.port = port
...@@ -85,7 +86,7 @@ class WebService(object): ...@@ -85,7 +86,7 @@ class WebService(object):
self.port_list = [] self.port_list = []
default_port = 12000 default_port = 12000
for i in range(1000): for i in range(1000):
if self.port_is_available(default_port + i): if port_is_available(default_port + i):
self.port_list.append(default_port + i) self.port_list.append(default_port + i)
if len(self.port_list) > len(self.gpus): if len(self.port_list) > len(self.gpus):
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册