未验证 提交 9171a266 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1350 from bjjwwang/v063

 [WIP]V063 modify
...@@ -28,5 +28,5 @@ test_reader = paddle.batch( ...@@ -28,5 +28,5 @@ test_reader = paddle.batch(
batch_size=1) batch_size=1)
for data in test_reader(): for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"]) fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["save_infer_model/scale_0"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0])) print("{} {}".format(fetch_map["save_infer_model/scale_0"][0], data[0][1][0]))
...@@ -61,6 +61,8 @@ def serve_args(): ...@@ -61,6 +61,8 @@ def serve_args():
default=False, default=False,
action="store_true", action="store_true",
help="Use TensorRT Calibration") help="Use TensorRT Calibration")
parser.add_argument(
"--encryption_rpc_port", type=int, required=False, default=9292, help="Port of encryption model, only valid for arg.use_encryption_model")
parser.add_argument( parser.add_argument(
"--mem_optim_off", "--mem_optim_off",
default=False, default=False,
...@@ -114,6 +116,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing ...@@ -114,6 +116,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
max_body_size = args.max_body_size max_body_size = args.max_body_size
use_mkl = args.use_mkl use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model use_encryption_model = args.use_encryption_model
encryption_rpc_port = args.encryption_rpc_port
use_multilang = args.use_multilang use_multilang = args.use_multilang
if model == "": if model == "":
...@@ -160,6 +163,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing ...@@ -160,6 +163,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.set_precision(args.precision) server.set_precision(args.precision)
server.set_use_calib(args.use_calib) server.set_use_calib(args.use_calib)
server.use_encryption_model(use_encryption_model) server.use_encryption_model(use_encryption_model)
server.encryption_rpc_port(encryption_rpc_port)
if args.product_name != None: if args.product_name != None:
server.set_product_name(args.product_name) server.set_product_name(args.product_name)
if args.container_id != None: if args.container_id != None:
...@@ -246,7 +250,8 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin ...@@ -246,7 +250,8 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
workdir=workdir, workdir=workdir,
port=port, port=port,
device=device, device=device,
use_encryption_model=args.use_encryption_model) use_encryption_model=args.use_encryption_model,
encryption_rpc_port=args.encryption_rpc_port)
if gpuid >= 0: if gpuid >= 0:
server.set_gpuid(gpuid) server.set_gpuid(gpuid)
server.run_server() server.run_server()
...@@ -293,7 +298,8 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis ...@@ -293,7 +298,8 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis
class MainService(BaseHTTPRequestHandler): class MainService(BaseHTTPRequestHandler):
def get_available_port(self): def get_available_port(self):
default_port = 12000 global encryption_rpc_port
default_port = encryption_rpc_port
for i in range(1000): for i in range(1000):
if port_is_available(default_port + i): if port_is_available(default_port + i):
return default_port + i return default_port + i
...@@ -381,6 +387,7 @@ if __name__ == "__main__": ...@@ -381,6 +387,7 @@ if __name__ == "__main__":
if args.name == "None": if args.name == "None":
from .web_service import port_is_available from .web_service import port_is_available
if args.use_encryption_model: if args.use_encryption_model:
encryption_rpc_port = args.encryption_rpc_port
p_flag = False p_flag = False
p = None p = None
serving_port = 0 serving_port = 0
......
...@@ -430,6 +430,7 @@ class Server(object): ...@@ -430,6 +430,7 @@ class Server(object):
port=9292, port=9292,
device="cpu", device="cpu",
use_encryption_model=False, use_encryption_model=False,
encryption_rpc_port=9293,
cube_conf=None): cube_conf=None):
if workdir == None: if workdir == None:
workdir = "./tmp" workdir = "./tmp"
...@@ -442,7 +443,7 @@ class Server(object): ...@@ -442,7 +443,7 @@ class Server(object):
if not self.port_is_available(port): if not self.port_is_available(port):
raise SystemExit("Port {} is already used".format(port)) raise SystemExit("Port {} is already used".format(port))
print("set brpc port here: {}".format(port))
self.set_port(port) self.set_port(port)
self._prepare_resource(workdir, cube_conf) self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device, self._prepare_engine(self.model_config_paths, device,
...@@ -588,6 +589,9 @@ class MultiLangServer(object): ...@@ -588,6 +589,9 @@ class MultiLangServer(object):
def use_encryption_model(self, flag=False): def use_encryption_model(self, flag=False):
self.encryption_model = flag self.encryption_model = flag
def encryption_rpc_port(self, port=9293):
self.encryption_rpc_port=port
def set_port(self, port): def set_port(self, port):
self.gport_ = port self.gport_ = port
...@@ -673,16 +677,22 @@ class MultiLangServer(object): ...@@ -673,16 +677,22 @@ class MultiLangServer(object):
port=9292, port=9292,
device="cpu", device="cpu",
use_encryption_model=False, use_encryption_model=False,
encryption_rpc_port=9293,
cube_conf=None): cube_conf=None):
if not self._port_is_available(port): if not self._port_is_available(port):
raise SystemExit("Port {} is already used".format(port)) raise SystemExit("Port {} is already used".format(port))
if use_encryption_model is True and self._port_is_available(encryption_rpc_port) is False:
raise SystemExit("Encryption Rpc Port {} is already used".format(encryption_rpc_port))
if use_encryption_model is False:
default_port = 12000 default_port = 12000
self.port_list_ = [] self.port_list_ = []
for i in range(1000): for i in range(1000):
if default_port + i != port and self._port_is_available(default_port if default_port + i != port and self._port_is_available(default_port + i):
+ i):
self.port_list_.append(default_port + i) self.port_list_.append(default_port + i)
break break
else:
self.port_list_ = []
self.port_list_.append(encryption_rpc_port)
self.bserver_.prepare_server( self.bserver_.prepare_server(
workdir=workdir, workdir=workdir,
port=self.port_list_[0], port=self.port_list_[0],
......
...@@ -75,7 +75,8 @@ class Op(object): ...@@ -75,7 +75,8 @@ class Op(object):
self._retry = max(1, retry) self._retry = max(1, retry)
self._batch_size = batch_size self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout self._auto_batching_timeout = auto_batching_timeout
self._use_encryption_model = None
self._encryption_key = ""
self._input = None self._input = None
self._outputs = [] self._outputs = []
...@@ -110,7 +111,11 @@ class Op(object): ...@@ -110,7 +111,11 @@ class Op(object):
self._fetch_names = conf.get("fetch_list") self._fetch_names = conf.get("fetch_list")
if self._client_config is None: if self._client_config is None:
self._client_config = conf.get("client_config") self._client_config = conf.get("client_config")
if self._use_encryption_model is None:
print ("config use_encryption model here", conf.get("use_encryption_model"))
self._use_encryption_model = conf.get("use_encryption_model")
if self._encryption_key is None or self._encryption_key=="":
self._encryption_key = conf.get("encryption_key")
if self._timeout is None: if self._timeout is None:
self._timeout = conf["timeout"] self._timeout = conf["timeout"]
if self._timeout > 0: if self._timeout > 0:
...@@ -343,7 +348,12 @@ class Op(object): ...@@ -343,7 +348,12 @@ class Op(object):
self._fetch_names = client.fetch_names_ self._fetch_names = client.fetch_names_
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars") _LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor": if self.client_type != "local_predictor":
if self._use_encryption_model is None or self._use_encryption_model is False:
client.connect(server_endpoints) client.connect(server_endpoints)
else:
print("connect to encryption rpc client")
client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True)
return client return client
def get_input_ops(self): def get_input_ops(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册