提交 3be11a3a 编写于 作者: H HexToString

temp fix webService and add try catch

上级 6bfaada2
......@@ -197,7 +197,6 @@ class GeneralClient(object):
req = json.dumps({})
r = requests.post(encrypt_url, req)
result = r.json()
print(result)
if "endpoint_list" not in result:
raise ValueError("server not ready")
else:
......@@ -442,24 +441,36 @@ class GeneralClient(object):
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# 当数据区长度大于512字节时才压缩.
if self.try_request_gzip and self.total_data_number > 512:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
try:
if self.try_request_gzip and self.total_data_number > 512:
origin_data = postData
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# 压缩异常,使用原始数据
except:
print("compress error, we will use the no-compress data")
headers.pop("Content-Encoding", "nokey")
postData = origin_data
# requests支持自动识别解压
result = requests.post(url=web_url, headers=headers, data=postData)
if result == None:
try:
result = requests.post(url=web_url, headers=headers, data=postData)
except:
print("http post error")
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
else:
if result == None:
return None
if result.status_code == 200:
if result.headers["Content-Type"] == 'application/proto':
response = general_model_service_pb2.Response()
response.ParseFromString(result.content)
return response
else:
return result.json()
return result
def grpc_client_predict(self,
feed=None,
......@@ -479,10 +490,13 @@ class GeneralClient(object):
endpoints = [self.ip + ":" + self.server_port]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
print("my endpoint is ", g_endpoint)
self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_)
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
return resp
try:
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
except:
print("Grpc inference error occur")
return None
else:
return resp
......@@ -123,7 +123,7 @@ class WebService(object):
workdir,
port=9292,
gpus=None,
thread_num=2,
thread_num=4,
mem_optim=True,
use_lite=False,
use_xpu=False,
......@@ -236,7 +236,7 @@ class WebService(object):
use_lite=False,
use_xpu=False,
ir_optim=False,
thread_num=2,
thread_num=4,
mem_optim=True,
use_trt=False,
gpu_multi_stream=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册