提交 59d9ad8c 编写于 作者: H HexToString

update http_session

上级 3be11a3a
......@@ -84,7 +84,7 @@ class GeneralClient(object):
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
self.timeout_ms = 200000
self.timeout_s = 20
self.ip = ip
self.port = port
self.server_port = port
......@@ -96,6 +96,17 @@ class GeneralClient(object):
self.http_proto = True
self.max_body_size = 512 * 1024 * 1024
self.use_grpc_client = False
# 使用连接池能够不用反复建立连接
self.requests_session = requests.session()
# 初始化grpc_stub
options = [('grpc.max_receive_message_length', self.max_body_size),
('grpc.max_send_message_length', self.max_body_size)]
endpoints = [self.ip + ":" + self.server_port]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_)
def load_client_config(self, model_config_path_list):
if isinstance(model_config_path_list, str):
......@@ -155,21 +166,24 @@ class GeneralClient(object):
def set_max_body_size(self, max_body_size):
self.max_body_size = max_body_size
self.init_grpc_stub()
def set_timeout_ms(self, timeout_ms):
if not isinstance(timeout_ms, int):
raise ValueError("timeout_ms must be int type.")
def set_timeout_s(self, timeout_s):
if not isinstance(timeout_s, int):
raise ValueError("timeout_s must be int type.")
else:
self.timeout_ms = timeout_ms
self.timeout_s = timeout_s
def set_ip(self, ip):
self.ip = ip
self.init_grpc_stub()
def set_service_name(self, service_name):
self.service_name = service_name
def set_port(self, port):
self.port = port
self.init_grpc_stub()
def set_request_compress(self, try_request_gzip):
self.try_request_gzip = try_request_gzip
......@@ -195,13 +209,13 @@ class GeneralClient(object):
req = json.dumps({"key": base64.b64encode(self.key).decode()})
else:
req = json.dumps({})
r = requests.post(encrypt_url, req)
result = r.json()
if "endpoint_list" not in result:
raise ValueError("server not ready")
else:
self.server_port = str(result["endpoint_list"][0])
print("rpc port is ", self.server_port)
with requests.post(encrypt_url, data=req, timeout=self.timeout_s) as r:
result = r.json()
if "endpoint_list" not in result:
raise ValueError("server not ready")
else:
self.server_port = str(result["endpoint_list"][0])
print("rpc port is ", self.server_port)
def get_feed_names(self):
return self.feed_names_
......@@ -444,7 +458,10 @@ class GeneralClient(object):
try:
if self.try_request_gzip and self.total_data_number > 512:
origin_data = postData
postData = gzip.compress(bytes(postData, 'utf-8'))
if http_proto:
postData = gzip.compress(postData)
else:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
......@@ -453,10 +470,13 @@ class GeneralClient(object):
print("compress error, we will use the no-compress data")
headers.pop("Content-Encoding", "nokey")
postData = origin_data
# requests支持自动识别解压
try:
result = requests.post(url=web_url, headers=headers, data=postData)
result = self.requests_session.post(
url=web_url,
headers=headers,
data=postData,
timeout=self.timeout_s)
except:
print("http post error")
return None
......@@ -484,6 +504,15 @@ class GeneralClient(object):
postData = self.process_proto_data(feed_dict, fetch_list, batch, log_id)
try:
resp = self.stub_.inference(postData, timeout=self.timeout_s)
except:
print("Grpc inference error occur")
return None
else:
return resp
def init_grpc_stub(self):
# https://github.com/tensorflow/serving/issues/1382
options = [('grpc.max_receive_message_length', self.max_body_size),
('grpc.max_send_message_length', self.max_body_size)]
......@@ -493,10 +522,7 @@ class GeneralClient(object):
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_)
try:
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
except:
print("Grpc inference error occur")
return None
else:
return resp
def __del__(self):
self.requests_session.close()
self.channel_.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册