未验证 提交 9171a266 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1350 from bjjwwang/v063

 [WIP]V063 modify
......@@ -28,5 +28,5 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["save_infer_model/scale_0"])
print("{} {}".format(fetch_map["save_infer_model/scale_0"][0], data[0][1][0]))
......@@ -61,6 +61,8 @@ def serve_args():
default=False,
action="store_true",
help="Use TensorRT Calibration")
parser.add_argument(
"--encryption_rpc_port", type=int, required=False, default=9292, help="Port of encryption model, only valid for arg.use_encryption_model")
parser.add_argument(
"--mem_optim_off",
default=False,
......@@ -114,6 +116,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
max_body_size = args.max_body_size
use_mkl = args.use_mkl
use_encryption_model = args.use_encryption_model
encryption_rpc_port = args.encryption_rpc_port
use_multilang = args.use_multilang
if model == "":
......@@ -160,6 +163,7 @@ def start_standard_model(serving_port): # pylint: disable=doc-string-missing
server.set_precision(args.precision)
server.set_use_calib(args.use_calib)
server.use_encryption_model(use_encryption_model)
server.encryption_rpc_port(encryption_rpc_port)
if args.product_name != None:
server.set_product_name(args.product_name)
if args.container_id != None:
......@@ -246,7 +250,8 @@ def start_gpu_card_model(index, gpuid, port, args): # pylint: disable=doc-strin
workdir=workdir,
port=port,
device=device,
use_encryption_model=args.use_encryption_model)
use_encryption_model=args.use_encryption_model,
encryption_rpc_port=args.encryption_rpc_port)
if gpuid >= 0:
server.set_gpuid(gpuid)
server.run_server()
......@@ -293,7 +298,8 @@ def start_multi_card(args, serving_port=None): # pylint: disable=doc-string-mis
class MainService(BaseHTTPRequestHandler):
def get_available_port(self):
default_port = 12000
global encryption_rpc_port
default_port = encryption_rpc_port
for i in range(1000):
if port_is_available(default_port + i):
return default_port + i
......@@ -381,6 +387,7 @@ if __name__ == "__main__":
if args.name == "None":
from .web_service import port_is_available
if args.use_encryption_model:
encryption_rpc_port = args.encryption_rpc_port
p_flag = False
p = None
serving_port = 0
......
......@@ -430,6 +430,7 @@ class Server(object):
port=9292,
device="cpu",
use_encryption_model=False,
encryption_rpc_port=9293,
cube_conf=None):
if workdir == None:
workdir = "./tmp"
......@@ -442,7 +443,7 @@ class Server(object):
if not self.port_is_available(port):
raise SystemExit("Port {} is already used".format(port))
print("set brpc port here: {}".format(port))
self.set_port(port)
self._prepare_resource(workdir, cube_conf)
self._prepare_engine(self.model_config_paths, device,
......@@ -588,6 +589,9 @@ class MultiLangServer(object):
def use_encryption_model(self, flag=False):
self.encryption_model = flag
def encryption_rpc_port(self, port=9293):
self.encryption_rpc_port=port
def set_port(self, port):
self.gport_ = port
......@@ -673,16 +677,22 @@ class MultiLangServer(object):
port=9292,
device="cpu",
use_encryption_model=False,
encryption_rpc_port=9293,
cube_conf=None):
if not self._port_is_available(port):
raise SystemExit("Port {} is already used".format(port))
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self._port_is_available(default_port
+ i):
self.port_list_.append(default_port + i)
break
if use_encryption_model is True and self._port_is_available(encryption_rpc_port) is False:
raise SystemExit("Encryption Rpc Port {} is already used".format(encryption_rpc_port))
if use_encryption_model is False:
default_port = 12000
self.port_list_ = []
for i in range(1000):
if default_port + i != port and self._port_is_available(default_port + i):
self.port_list_.append(default_port + i)
break
else:
self.port_list_ = []
self.port_list_.append(encryption_rpc_port)
self.bserver_.prepare_server(
workdir=workdir,
port=self.port_list_[0],
......
......@@ -75,7 +75,8 @@ class Op(object):
self._retry = max(1, retry)
self._batch_size = batch_size
self._auto_batching_timeout = auto_batching_timeout
self._use_encryption_model = None
self._encryption_key = ""
self._input = None
self._outputs = []
......@@ -110,7 +111,11 @@ class Op(object):
self._fetch_names = conf.get("fetch_list")
if self._client_config is None:
self._client_config = conf.get("client_config")
if self._use_encryption_model is None:
print ("config use_encryption model here", conf.get("use_encryption_model"))
self._use_encryption_model = conf.get("use_encryption_model")
if self._encryption_key is None or self._encryption_key=="":
self._encryption_key = conf.get("encryption_key")
if self._timeout is None:
self._timeout = conf["timeout"]
if self._timeout > 0:
......@@ -343,7 +348,12 @@ class Op(object):
self._fetch_names = client.fetch_names_
_LOGGER.info("Op({}) has no fetch name set. So fetch all vars")
if self.client_type != "local_predictor":
client.connect(server_endpoints)
if self._use_encryption_model is None or self._use_encryption_model is False:
client.connect(server_endpoints)
else:
print("connect to encryption rpc client")
client.use_key(self._encryption_key)
client.connect(server_endpoints, encryption=True)
return client
def get_input_ops(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册