def add_server(self, server='127.0.0.1:8010'):
    """Record one server address and probe the recorded servers.

    Args:
        server: Address string in ``ip:port`` form.
    """
    self.server_list.append(server)
    self.check_server()


def add_server_list(self, server_list):
    """Record several server addresses, then probe the recorded servers once.

    Args:
        server_list: Iterable of address strings in ``ip:port`` form.
    """
    for server_str in server_list:
        self.server_list.append(server_str)
    self.check_server()


def check_server(self):
    """Probe every address in ``self.server_list`` and collect usable ones.

    For each ``ip:port`` entry a TCP connection is opened, the bytes
    ``b'pending server'`` are sent, and a reply of up to 1024 bytes is read.
    The reply is split on tabs; the text after the ':' of the first field is
    parsed as an integer status code.  On status 0 the text after the ':' of
    the second field is compared with ``self.model_name``; when they match,
    the probed ip joined with the text after the ':' of the third field is
    stored in ``self.serving_list``.

    A server that cannot be reached, whose address is malformed, or whose
    reply cannot be parsed is logged and skipped so that one bad entry does
    not abort registration of the others.  Addresses already present in
    ``self.serving_list`` are not appended again, which keeps repeated calls
    (every ``add_server`` / ``add_server_list`` triggers one) from
    duplicating entries.
    """
    for server in self.server_list:
        try:
            address_parts = server.split(':')
            server_ip = address_parts[0]
            server_port = int(address_parts[1])

            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                client.connect((server_ip, server_port))
                # sendall retries until the whole probe message is written.
                client.sendall(b'pending server')
                response = client.recv(1024).decode()
            finally:
                # Always release the descriptor, even when the probe fails.
                client.close()

            response_list = response.split('\t')
            status_code = int(response_list[0].split(':')[1])
            if status_code != 0:
                logger.error('connect server {} failed. {}'.format(
                    server, response_list[1]))
                continue

            server_model = response_list[1].split(':')[1]
            if server_model != self.model_name:
                logger.error(
                    'model_name not match, server {} using : {} '.format(
                        server, server_model))
                continue

            serving_address = (
                server_ip + ':' + response_list[2].split(':')[1])
        except (socket.error, ValueError, IndexError) as err:
            # socket.error covers refused/unreachable/timeout on both
            # Python 2 and 3; ValueError/IndexError cover a malformed
            # address or reply (including undecodable bytes).
            logger.error('connect server {} failed. {}'.format(server, err))
            continue

        if serving_address not in self.serving_list:
            self.serving_list.append(serving_address)
self.serving_list[self.con_index], err)) + if len(self.serving_list) == 0: logger.error('All server failed, process will exit') return 'fail' else: self.con_index = random.randint(0, - len(self.server_list) - 1) + len(self.serving_list) - 1) return 'retry' elif self.load_balance == 'bind': try: - self.con_index = int(self.process_id) % len(self.server_list) + self.con_index = int(self.process_id) % len(self.serving_list) cur_con = httplib.HTTPConnection( - self.server_list[self.con_index]) + self.serving_list[self.con_index]) cur_con.request('POST', "/BertService/inference", request_msg, {"Content-Type": "application/json"}) response = cur_con.getresponse() @@ -142,13 +173,13 @@ class BertService(object): except BaseException as err: logger.warning("Infer Error with server {} : {}".format( - self.server_list[self.con_index], err)) - if len(self.server_list) == 0: + self.serving_list[self.con_index], err)) + if len(self.serving_list) == 0: logger.error('All server failed, process will exit') return 'fail' else: self.con_index = int(self.process_id) % len( - self.server_list) + self.serving_list) return 'retry' def prepare_data(self, text): @@ -184,6 +215,9 @@ class BertService(object): return request_msg def encode(self, text): + if len(self.serving_list) == 0: + logger.error('No match server.') + return -1 if type(text) != list: raise TypeError('Only support list') request_msg = self.prepare_data(text)