提交 54166c03 编写于 作者: M MRXLT

add switch for client

上级 b0b72230
...@@ -13,16 +13,19 @@ ...@@ -13,16 +13,19 @@
# limitations under the License. # limitations under the License.
# pylint: disable=doc-string-missing # pylint: disable=doc-string-missing
import paddle_serving_client
import os import os
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import numpy as np
import time import time
import sys import sys
import requests
import json
import base64
import numpy as np
import paddle_serving_client
import google.protobuf.text_format
import grpc import grpc
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
from .proto import multi_lang_general_model_service_pb2 from .proto import multi_lang_general_model_service_pb2
sys.path.append( sys.path.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto')) os.path.join(os.path.abspath(os.path.dirname(__file__)), 'proto'))
...@@ -158,6 +161,7 @@ class Client(object): ...@@ -158,6 +161,7 @@ class Client(object):
self.fetch_names_to_idx_ = {} self.fetch_names_to_idx_ = {}
self.lod_tensor_set = set() self.lod_tensor_set = set()
self.feed_tensor_len = {} self.feed_tensor_len = {}
self.key = None
for i, var in enumerate(model_conf.feed_var): for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i self.feed_names_to_idx_[var.alias_name] = i
...@@ -190,9 +194,14 @@ class Client(object): ...@@ -190,9 +194,14 @@ class Client(object):
else: else:
self.rpc_timeout_ms = rpc_timeout self.rpc_timeout_ms = rpc_timeout
def use_key(self, key_filename):
    """Load the model key from ``key_filename`` into ``self.key``.

    The file is read in binary mode so the key is kept byte-for-byte:
    no text decoding and no newline translation is applied. The stored
    value is what ``get_serving_port`` later base64-encodes into the
    request it posts to the server.

    Args:
        key_filename: path of the file holding the key.

    Raises:
        IOError / OSError: if the file cannot be opened or read.
    """
    with open(key_filename, "rb") as f:
        self.key = f.read()
def get_serving_port(self, endpoints): def get_serving_port(self, endpoints):
import requests if self.key is not None:
import json req = json.dumps({"key": base64.b64encode(self.key)})
else:
req = json.dumps({}) req = json.dumps({})
r = requests.post("http://" + endpoints[0], req) r = requests.post("http://" + endpoints[0], req)
result = r.json() result = r.json()
...@@ -206,7 +215,7 @@ class Client(object): ...@@ -206,7 +215,7 @@ class Client(object):
] ]
return endpoints return endpoints
def connect(self, endpoints=None): def connect(self, endpoints=None, encryption=False):
# check whether current endpoint is available # check whether current endpoint is available
# init from client config # init from client config
# create predictor here # create predictor here
...@@ -216,6 +225,7 @@ class Client(object): ...@@ -216,6 +225,7 @@ class Client(object):
"You must set the endpoints parameter or use add_variant function to create a variant." "You must set the endpoints parameter or use add_variant function to create a variant."
) )
else: else:
if encryption:
endpoints = self.get_serving_port(endpoints) endpoints = self.get_serving_port(endpoints)
if self.predictor_sdk_ is None: if self.predictor_sdk_ is None:
self.add_variant('default_tag_{}'.format(id(self)), endpoints, self.add_variant('default_tag_{}'.format(id(self)), endpoints,
......
...@@ -18,11 +18,13 @@ Usage: ...@@ -18,11 +18,13 @@ Usage:
python -m paddle_serving_server.serve --model ./serving_server_model --port 9292 python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
""" """
import argparse import argparse
from web_service import WebService import sys
import json
import base64
from multiprocessing import Process
from web_service import WebService, port_is_available
from flask import Flask, request from flask import Flask, request
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
import json
import subprocess
def parse_args(): # pylint: disable=doc-string-missing def parse_args(): # pylint: disable=doc-string-missing
...@@ -64,11 +66,11 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -64,11 +66,11 @@ def parse_args(): # pylint: disable=doc-string-missing
return parser.parse_args() return parser.parse_args()
def start_standard_model(): # pylint: disable=doc-string-missing def start_standard_model(serving_port): # pylint: disable=doc-string-missing
args = parse_args() args = parse_args()
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
port = args.port port = serving_port
workdir = args.workdir workdir = args.workdir
device = args.device device = args.device
mem_optim = args.mem_optim mem_optim = args.mem_optim
...@@ -107,9 +109,70 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -107,9 +109,70 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.run_server() server.run_server()
def start_serving(): class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
    """Return the first port in 12000..12999 that ``port_is_available`` accepts.

    Ports are probed in ascending order. If none of the 1000 candidates
    is free the result is ``None``.
    """
    first_port = 12000
    candidates = (first_port + offset for offset in range(1000))
    return next((port for port in candidates if port_is_available(port)),
                None)
def start_serving(self):
    """Run the standard model server on the module-level ``serving_port``.

    Used as the ``target`` of the ``Process`` created in ``start``.
    """
    start_standard_model(serving_port)
def get_key(self, post_data):
    """Persist the key sent by the client next to the model files.

    Args:
        post_data: decoded JSON request body (a dict). The key is
            expected under ``"key"`` as a base64 string.

    Returns:
        ``False`` when the request carries no ``"key"`` entry, otherwise
        ``True`` after the decoded key has been written to
        ``<args.model>/key``.
    """
    if "key" not in post_data:
        return False
    key = base64.b64decode(post_data["key"])
    # b64decode yields raw bytes, so the file must be opened in binary mode:
    # text mode rejects bytes on Python 3 and may rewrite newline bytes.
    with open(args.model + "/key", "wb") as f:
        f.write(key)
    return True
def start(self, post_data):
    """Make sure the serving process is running.

    On the first successful call this optionally stores the client's key
    (when ``args.use_encryption_model`` is set), picks a port and spawns
    the serving process. On later calls it only checks that the process
    is still alive.

    Args:
        post_data: raw JSON request body.

    Returns:
        ``True`` if the serving process was started by this call or is
        still alive; ``False`` if the required key is missing from the
        request or the previously started process has exited.
    """
    post_data = json.loads(post_data)
    # All three names must be module globals. In particular `p` is assigned
    # in one call to start() and read in a later one; as a plain local it
    # would be unbound when the else branch runs.
    global p_flag, serving_port, p
    if not p_flag:
        if args.use_encryption_model:
            print("waiting key for model")
            if not self.get_key(post_data):
                print("not found key in request")
                return False
        serving_port = self.get_available_port()
        p = Process(target=self.start_serving)
        p.start()
        p_flag = True
    elif not p.is_alive():
        return False
    return True
def do_POST(self):
    """Handle a POST: start the serving process if needed and reply in JSON.

    The request body (``Content-Length`` bytes) is handed to ``start``.
    The reply always has status 200; success is reported as
    ``{"endpoint_list": [serving_port]}`` and failure as
    ``{"message": "start serving failed"}``.
    """
    content_length = int(self.headers['Content-Length'])
    post_data = self.rfile.read(content_length)
    if self.start(post_data):
        response = {"endpoint_list": [serving_port]}
    else:
        response = {"message": "start serving failed"}
    self.send_response(200)
    self.send_header('Content-type', 'application/json')
    self.end_headers()
    body = json.dumps(response)
    # wfile is a byte stream. json.dumps returns text on Python 3, so encode
    # it there; on Python 2 the result is already a byte string.
    if not isinstance(body, bytes):
        body = body.encode("utf-8")
    self.wfile.write(body)
if __name__ == "__main__":
args = parse_args()
if args.name == "None": if args.name == "None":
start_standard_model() if args.use_encryption_model:
p_flag = False
serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService)
print(
'Starting encryption server, waiting for key from client, use <Ctrl-C> to stop'
)
server.serve_forever()
else:
start_standard_model(args.port)
else: else:
service = WebService(name=args.name) service = WebService(name=args.name)
service.load_model_config(args.model) service.load_model_config(args.model)
...@@ -133,37 +196,3 @@ def start_serving(): ...@@ -133,37 +196,3 @@ def start_serving():
port=service.port, port=service.port,
threaded=False, threaded=False,
processes=4) processes=4)
# Earlier form of the POST handler (the hunk header above, -133,37 +196,3,
# marks these lines as dropped from this position): a POST lazily launches the
# serving process and the reply carries the endpoint list.
class MainService(BaseHTTPRequestHandler):
# Spawn the serving process the first time this is called; once p_flag is set
# later calls do nothing.
def start(self):
global p_flag
print(p_flag)
if not p_flag:
# Process is imported only when a process actually has to be created.
from multiprocessing import Process
p = Process(target=start_serving)
p.start()
p_flag = True
else:
pass
# NOTE(review): indentation is lost in this view; presumably this return sits
# at method level, so start() always reports success — confirm.
return True
# Read the request body, call start(), and answer with a JSON document.
# The status code is 200 in both the success and the failure case.
def do_POST(self):
content_length = int(self.headers['Content-Length'])
# The body is read but not passed on: start() takes no arguments here.
post_data = self.rfile.read(content_length)
if self.start():
response = {"endpoint_list": [args.port]}
else:
response = {"message": "start serving failed"}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response))
# Script entry point of the earlier revision: mark the serving process as not
# yet started, parse the command line, then serve POST requests on
# localhost:8080 until interrupted.
if __name__ == "__main__":
# p_flag and args are module globals read by MainService.
p_flag = False
args = parse_args()
server = HTTPServer(('localhost', 8080), MainService)
print('Starting server, use <Ctrl-C> to stop')
server.serve_forever()
...@@ -22,6 +22,16 @@ from contextlib import closing ...@@ -22,6 +22,16 @@ from contextlib import closing
import socket import socket
def port_is_available(port):
    """Return True when a TCP connect to ``0.0.0.0:port`` does not succeed.

    A short-lived IPv4 stream socket with a 2 second timeout attempts the
    connection. ``connect_ex`` reports 0 on success, which means something
    is already accepting connections on the port, so it is not available.
    Any non-zero result is treated as "available".
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.settimeout(2)
        return sock.connect_ex(('0.0.0.0', port)) != 0
class WebService(object): class WebService(object):
def __init__(self, name="default_service"): def __init__(self, name="default_service"):
self.name = name self.name = name
...@@ -46,15 +56,6 @@ class WebService(object): ...@@ -46,15 +56,6 @@ class WebService(object):
workdir=self.workdir, port=self.port_list[0], device=self.device) workdir=self.workdir, port=self.port_list[0], device=self.device)
server.run_server() server.run_server()
def port_is_available(self, port):
    """Return True when a TCP connect to ``0.0.0.0:port`` does not succeed.

    Method form of the port probe: an IPv4 stream socket with a 2 second
    timeout tries to connect, and a non-zero ``connect_ex`` status is
    taken to mean that nothing is listening on the port.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.settimeout(2)
        status = probe.connect_ex(('0.0.0.0', port))
    return status != 0
def prepare_server(self, workdir="", port=9393, device="cpu"): def prepare_server(self, workdir="", port=9393, device="cpu"):
self.workdir = workdir self.workdir = workdir
self.port = port self.port = port
...@@ -62,7 +63,7 @@ class WebService(object): ...@@ -62,7 +63,7 @@ class WebService(object):
default_port = 12000 default_port = 12000
self.port_list = [] self.port_list = []
for i in range(1000): for i in range(1000):
if self.port_is_available(default_port + i): if port_is_available(default_port + i):
self.port_list.append(default_port + i) self.port_list.append(default_port + i)
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册