提交 aacd85c4 编写于 作者: M MRXLT

add encryption service for gpu

上级 6866bff0
......@@ -68,6 +68,11 @@ def serve_args():
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
parser.add_argument(
"--use_encryption_model",
default=False,
action="store_true",
help="Use encryption model")
return parser.parse_args()
......@@ -244,7 +249,7 @@ class Server(object):
def set_gpuid(self, gpuid=0):
self.gpuid = gpuid
def _prepare_engine(self, model_config_paths, device):
def _prepare_engine(self, model_config_paths, device, use_encryption_model):
if self.model_toolkit_conf == None:
self.model_toolkit_conf = server_sdk.ModelToolkitConf()
......@@ -265,9 +270,15 @@ class Server(object):
engine.force_update_static_cache = False
if device == "cpu":
engine.type = "FLUID_CPU_ANALYSIS_DIR"
if use_encryption_model:
engine.type = "FLUID_CPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_CPU_ANALYSIS_DIR"
elif device == "gpu":
engine.type = "FLUID_GPU_ANALYSIS_DIR"
if use_encryption_model:
engine.type = "FLUID_GPU_ANALYSIS_ENCRPT"
else:
engine.type = "FLUID_GPU_ANALYSIS_DIR"
self.model_toolkit_conf.engines.extend([engine])
......@@ -401,7 +412,11 @@ class Server(object):
os.chdir(self.cur_path)
self.bin_path = self.server_path + "/serving"
def prepare_server(self, workdir=None, port=9292, device="cpu"):
def prepare_server(self,
workdir=None,
port=9292,
device="cpu",
use_encryption_model=False):
if workdir == None:
workdir = "./tmp"
os.system("mkdir {}".format(workdir))
......@@ -414,7 +429,8 @@ class Server(object):
self.set_port(port)
self._prepare_resource(workdir)
self._prepare_engine(self.model_config_paths, device)
self._prepare_engine(self.model_config_paths, device,
use_encryption_model)
self._prepare_infer_service(port)
self.workdir = workdir
......
......@@ -19,19 +19,21 @@ Usage:
"""
import argparse
import os
import json
import base64
from multiprocessing import Pool, Process
from paddle_serving_server_gpu import serve_args
from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-missing
def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-string-missing
gpuid = int(gpuid)
device = "gpu"
port = args.port
if gpuid == -1:
device = "cpu"
elif gpuid >= 0:
port = args.port + index
port = port + index
thread_num = args.thread
model = args.model
mem_optim = args.mem_optim
......@@ -62,14 +64,20 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
server.set_max_body_size(max_body_size)
server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device)
server.prepare_server(
workdir=workdir,
port=port,
device=device,
use_encryption_model=args.use_encryption_model)
if gpuid >= 0:
server.set_gpuid(gpuid)
server.run_server()
def start_multi_card(args): # pylint: disable=doc-string-missing
def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-missing
gpus = ""
if serving_port == None:
serving_port = args.port
if args.gpu_ids == "":
gpus = []
else:
......@@ -86,14 +94,16 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
env_gpus = []
if len(gpus) <= 0:
print("gpu_ids not set, going to run cpu service.")
start_gpu_card_model(-1, -1, args)
start_gpu_card_model(-1, -1, serving_port, args)
else:
gpu_processes = []
for i, gpu_id in enumerate(gpus):
p = Process(
target=start_gpu_card_model, args=(
target=start_gpu_card_model,
args=(
i,
gpu_id,
serving_port,
args, ))
gpu_processes.append(p)
for p in gpu_processes:
......@@ -102,10 +112,71 @@ def start_multi_card(args): # pylint: disable=doc-string-missing
p.join()
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
def start_serving(self):
start_multi_card(args, serving_port)
def get_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "w") as f:
f.write(key)
return True
def start(self, post_data):
post_data = json.loads(post_data)
global p_flag
if not p_flag:
if args.use_encryption_model:
print("waiting key for model")
if not self.get_key(post_data):
print("not found key in request")
return False
global serving_port
serving_port = self.get_available_port()
p = Process(target=self.start_serving)
p.start()
p_flag = True
else:
if not p.is_alive():
return False
return True
def do_POST(self):
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
if self.start(post_data):
response = {"endpoint_list": [serving_port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
if __name__ == "__main__":
args = serve_args()
if args.name == "None":
start_multi_card(args)
from .web_service import port_is_available
if args.use_encryption_model:
p_flag = False
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_multi_card(args)
else:
from .web_service import WebService
web_service = WebService(name=args.name)
......
......@@ -25,6 +25,16 @@ import numpy as np
import paddle_serving_server_gpu as serving
def port_is_available(port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
class WebService(object):
def __init__(self, name="default_service"):
self.name = name
......@@ -68,15 +78,6 @@ class WebService(object):
def _launch_rpc_service(self, service_idx):
self.rpc_service_list[service_idx].run_server()
def port_is_available(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('0.0.0.0', port))
if result != 0:
return True
else:
return False
def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
self.workdir = workdir
self.port = port
......@@ -85,7 +86,7 @@ class WebService(object):
self.port_list = []
default_port = 12000
for i in range(1000):
if self.port_is_available(default_port + i):
if port_is_available(default_port + i):
self.port_list.append(default_port + i)
if len(self.port_list) > len(self.gpus):
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册