提交 f6fb1c20 编写于 作者: M MRXLT

add model_name check

上级 e4263533
......@@ -66,7 +66,7 @@ class ServingCommand(BaseCommand):
from paddle_gpu_serving.run import BertServer
bs = BertServer(with_gpu=args.use_gpu)
bs.with_model(model_name=args.modules[0])
bs.run(gpu_index=args.gpu, port=args.port)
bs.run(gpu_index=args.gpu, port=int(args.port))
@staticmethod
def is_port_occupied(ip, port):
......
......@@ -18,6 +18,7 @@ import paddlehub as hub
import ujson
import random
from paddlehub.common.logger import logger
import socket
_ver = sys.version_info
is_py2 = (_ver[0] == 2)
......@@ -51,6 +52,7 @@ class BertService(object):
self.con_index = 0
self.load_balance = load_balance
self.server_list = []
self.serving_list = []
self.feed_var_names = ''
self.retry = retry
......@@ -71,29 +73,58 @@ class BertService(object):
def add_server(self, server='127.0.0.1:8010'):
self.server_list.append(server)
self.check_server()
def add_server_list(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
self.check_server()
def check_server(self):
for server in self.server_list:
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_ip = server.split(':')[0]
server_port = int(server.split(':')[1])
client.connect((server_ip, server_port))
client.send(b'pending server')
response = client.recv(1024).decode()
response_list = response.split('\t')
status_code = int(response_list[0].split(':')[1])
if status_code == 0:
server_model = response_list[1].split(':')[1]
if server_model == self.model_name:
serving_port = response_list[2].split(':')[1]
serving_ip = server_ip
self.serving_list.append(serving_ip + ':' + serving_port)
else:
logger.error(
'model_name not match, server {} using : {} '.format(
server, server_model))
else:
error_msg = response_list[1]
logger.error('connect server {} failed. {}'.format(
server, error_msg))
def request_server(self, request_msg):
if self.load_balance == 'round_robin':
try:
cur_con = httplib.HTTPConnection(
self.server_list[self.con_index])
self.serving_list[self.con_index])
cur_con.request('POST', "/BertService/inference", request_msg,
{"Content-Type": "application/json"})
response = cur_con.getresponse()
response_msg = response.read()
response_msg = ujson.loads(response_msg)
self.con_index += 1
self.con_index = self.con_index % len(self.server_list)
self.con_index = self.con_index % len(self.serving_list)
return response_msg
except BaseException as err:
logger.warning("Infer Error with server {} : {}".format(
self.server_list[self.con_index], err))
if len(self.server_list) == 0:
self.serving_list[self.con_index], err))
if len(self.serving_list) == 0:
logger.error('All server failed, process will exit')
return 'fail'
else:
......@@ -103,10 +134,10 @@ class BertService(object):
elif self.load_balance == 'random':
try:
random.seed()
self.con_index = random.randint(0, len(self.server_list) - 1)
self.con_index = random.randint(0, len(self.serving_list) - 1)
logger.info(self.con_index)
cur_con = httplib.HTTPConnection(
self.server_list[self.con_index])
self.serving_list[self.con_index])
cur_con.request('POST', "/BertService/inference", request_msg,
{"Content-Type": "application/json"})
response = cur_con.getresponse()
......@@ -117,21 +148,21 @@ class BertService(object):
except BaseException as err:
logger.warning("Infer Error with server {} : {}".format(
self.server_list[self.con_index], err))
if len(self.server_list) == 0:
self.serving_list[self.con_index], err))
if len(self.serving_list) == 0:
logger.error('All server failed, process will exit')
return 'fail'
else:
self.con_index = random.randint(0,
len(self.server_list) - 1)
len(self.serving_list) - 1)
return 'retry'
elif self.load_balance == 'bind':
try:
self.con_index = int(self.process_id) % len(self.server_list)
self.con_index = int(self.process_id) % len(self.serving_list)
cur_con = httplib.HTTPConnection(
self.server_list[self.con_index])
self.serving_list[self.con_index])
cur_con.request('POST', "/BertService/inference", request_msg,
{"Content-Type": "application/json"})
response = cur_con.getresponse()
......@@ -142,13 +173,13 @@ class BertService(object):
except BaseException as err:
logger.warning("Infer Error with server {} : {}".format(
self.server_list[self.con_index], err))
if len(self.server_list) == 0:
self.serving_list[self.con_index], err))
if len(self.serving_list) == 0:
logger.error('All server failed, process will exit')
return 'fail'
else:
self.con_index = int(self.process_id) % len(
self.server_list)
self.serving_list)
return 'retry'
def prepare_data(self, text):
......@@ -184,6 +215,9 @@ class BertService(object):
return request_msg
def encode(self, text):
if len(self.serving_list) == 0:
logger.error('No match server.')
return -1
if type(text) != list:
raise TypeError('Only support list')
request_msg = self.prepare_data(text)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册