提交 3be11a3a 编写于 作者: H HexToString

temp fix webService and add try catch

上级 6bfaada2
...@@ -197,7 +197,6 @@ class GeneralClient(object): ...@@ -197,7 +197,6 @@ class GeneralClient(object):
req = json.dumps({}) req = json.dumps({})
r = requests.post(encrypt_url, req) r = requests.post(encrypt_url, req)
result = r.json() result = r.json()
print(result)
if "endpoint_list" not in result: if "endpoint_list" not in result:
raise ValueError("server not ready") raise ValueError("server not ready")
else: else:
...@@ -442,24 +441,36 @@ class GeneralClient(object): ...@@ -442,24 +441,36 @@ class GeneralClient(object):
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
# 当数据区长度大于512字节时才压缩. # 当数据区长度大于512字节时才压缩.
if self.try_request_gzip and self.total_data_number > 512: try:
postData = gzip.compress(bytes(postData, 'utf-8')) if self.try_request_gzip and self.total_data_number > 512:
headers["Content-Encoding"] = "gzip" origin_data = postData
if self.try_response_gzip: postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Accept-encoding"] = "gzip" headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
# 压缩异常,使用原始数据
except:
print("compress error, we will use the no-compress data")
headers.pop("Content-Encoding", "nokey")
postData = origin_data
# requests支持自动识别解压 # requests支持自动识别解压
result = requests.post(url=web_url, headers=headers, data=postData) try:
if result == None: result = requests.post(url=web_url, headers=headers, data=postData)
except:
print("http post error")
return None return None
if result.status_code == 200: else:
if result.headers["Content-Type"] == 'application/proto': if result == None:
response = general_model_service_pb2.Response() return None
response.ParseFromString(result.content) if result.status_code == 200:
return response if result.headers["Content-Type"] == 'application/proto':
else: response = general_model_service_pb2.Response()
return result.json() response.ParseFromString(result.content)
return result return response
else:
return result.json()
return result
def grpc_client_predict(self, def grpc_client_predict(self,
feed=None, feed=None,
...@@ -479,10 +490,13 @@ class GeneralClient(object): ...@@ -479,10 +490,13 @@ class GeneralClient(object):
endpoints = [self.ip + ":" + self.server_port] endpoints = [self.ip + ":" + self.server_port]
g_endpoint = 'ipv4:{}'.format(','.join(endpoints)) g_endpoint = 'ipv4:{}'.format(','.join(endpoints))
print("my endpoint is ", g_endpoint)
self.channel_ = grpc.insecure_channel(g_endpoint, options=options) self.channel_ = grpc.insecure_channel(g_endpoint, options=options)
self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub( self.stub_ = general_model_service_pb2_grpc.GeneralModelServiceStub(
self.channel_) self.channel_)
resp = self.stub_.inference(postData, timeout=self.timeout_ms) try:
resp = self.stub_.inference(postData, timeout=self.timeout_ms)
return resp except:
print("Grpc inference error occur")
return None
else:
return resp
...@@ -123,7 +123,7 @@ class WebService(object): ...@@ -123,7 +123,7 @@ class WebService(object):
workdir, workdir,
port=9292, port=9292,
gpus=None, gpus=None,
thread_num=2, thread_num=4,
mem_optim=True, mem_optim=True,
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
...@@ -236,7 +236,7 @@ class WebService(object): ...@@ -236,7 +236,7 @@ class WebService(object):
use_lite=False, use_lite=False,
use_xpu=False, use_xpu=False,
ir_optim=False, ir_optim=False,
thread_num=2, thread_num=4,
mem_optim=True, mem_optim=True,
use_trt=False, use_trt=False,
gpu_multi_stream=False, gpu_multi_stream=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册