提交 88762e26 编写于 作者: M MRXLT

add check

上级 9c12f7c3
...@@ -21,6 +21,7 @@ import argparse ...@@ -21,6 +21,7 @@ import argparse
import sys import sys
import json import json
import base64 import base64
import time
from multiprocessing import Process from multiprocessing import Process
from web_service import WebService, port_is_available from web_service import WebService, port_is_available
from flask import Flask, request from flask import Flask, request
...@@ -128,6 +129,15 @@ class MainService(BaseHTTPRequestHandler): ...@@ -128,6 +129,15 @@ class MainService(BaseHTTPRequestHandler):
f.write(key) f.write(key)
return True return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data): def start(self, post_data):
post_data = json.loads(post_data) post_data = json.loads(post_data)
global p_flag global p_flag
...@@ -138,12 +148,20 @@ class MainService(BaseHTTPRequestHandler): ...@@ -138,12 +148,20 @@ class MainService(BaseHTTPRequestHandler):
print("not found key in request") print("not found key in request")
return False return False
global serving_port global serving_port
global p
serving_port = self.get_available_port() serving_port = self.get_available_port()
p = Process(target=self.start_serving) p = Process(target=self.start_serving)
p.start() p.start()
p_flag = True time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else: else:
if not p.is_alive(): if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False return False
return True return True
...@@ -165,6 +183,7 @@ if __name__ == "__main__": ...@@ -165,6 +183,7 @@ if __name__ == "__main__":
if args.name == "None": if args.name == "None":
if args.use_encryption_model: if args.use_encryption_model:
p_flag = False p_flag = False
p = None
serving_port = 0 serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService) server = HTTPServer(('localhost', int(args.port)), MainService)
print( print(
......
...@@ -131,6 +131,15 @@ class MainService(BaseHTTPRequestHandler): ...@@ -131,6 +131,15 @@ class MainService(BaseHTTPRequestHandler):
f.write(key) f.write(key)
return True return True
def check_key(self, post_data):
if "key" not in post_data:
return False
else:
key = base64.b64decode(post_data["key"])
with open(args.model + "/key", "r") as f:
cur_key = f.read()
return (key == cur_key)
def start(self, post_data): def start(self, post_data):
post_data = json.loads(post_data) post_data = json.loads(post_data)
global p_flag global p_flag
...@@ -141,12 +150,20 @@ class MainService(BaseHTTPRequestHandler): ...@@ -141,12 +150,20 @@ class MainService(BaseHTTPRequestHandler):
print("not found key in request") print("not found key in request")
return False return False
global serving_port global serving_port
global p
serving_port = self.get_available_port() serving_port = self.get_available_port()
p = Process(target=self.start_serving) p = Process(target=self.start_serving)
p.start() p.start()
p_flag = True time.sleep(3)
if p.is_alive():
p_flag = True
else:
return False
else: else:
if not p.is_alive(): if p.is_alive():
if not self.check_key(post_data):
return False
else:
return False return False
return True return True
...@@ -169,6 +186,7 @@ if __name__ == "__main__": ...@@ -169,6 +186,7 @@ if __name__ == "__main__":
from .web_service import port_is_available from .web_service import port_is_available
if args.use_encryption_model: if args.use_encryption_model:
p_flag = False p_flag = False
p = None
serving_port = 0 serving_port = 0
server = HTTPServer(('localhost', int(args.port)), MainService) server = HTTPServer(('localhost', int(args.port)), MainService)
print( print(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册