diff --git a/python/examples/bert/bert_rpc_client.py b/python/examples/bert/bert_rpc_client.py
index a3a86aca72514d506572f1edb6a4697c19215de3..959143d1c5703bbcd1615e90414e5bd4327f0017 100644
--- a/python/examples/bert/bert_rpc_client.py
+++ b/python/examples/bert/bert_rpc_client.py
@@ -45,20 +45,6 @@ class BertService():
         self.pid = os.getpid()
         self.profile = True if ("FLAGS_profile_client" in os.environ and
                                 os.environ["FLAGS_profile_client"]) else False
-        '''
-        module = hub.Module(name=self.model_name)
-        inputs, outputs, program = module.context(
-            trainable=True, max_seq_len=self.max_seq_len)
-        input_ids = inputs["input_ids"]
-        position_ids = inputs["position_ids"]
-        segment_ids = inputs["segment_ids"]
-        input_mask = inputs["input_mask"]
-        self.reader = hub.reader.ClassifyReader(
-            vocab_path=module.get_vocab_path(),
-            dataset=None,
-            max_seq_len=self.max_seq_len,
-            do_lower_case=self.do_lower_case)
-        '''
         self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
         self.reader_flag = True
 
@@ -68,28 +54,8 @@ class BertService():
         self.client.connect(server_addr)
 
     def run_general(self, text, fetch):
-        '''
-        self.batch_size = len(text)
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        '''
         result = []
         prepro_start = time.time()
-        '''
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-            for si in range(self.batch_size):
-                feed = {
-                    "input_ids": token_list,
-                    "position_ids": pos_list,
-                    "segment_ids": sent_list,
-                    "input_mask": mask_list
-                }
-                prepro_end = time.time()
-        '''
         feed = self.reader.process(text)
         if self.profile:
             print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
@@ -102,33 +68,10 @@ class BertService():
 
     def run_batch_general(self, text, fetch):
         self.batch_size = len(text)
-        '''
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        '''
         result = []
         prepro_start = time.time()
-        '''
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-        '''
         feed_batch = []
         for si in range(self.batch_size):
-            '''
-            feed = {
-                "input_ids": token_list[si * self.max_seq_len:(si + 1) *
-                                        self.max_seq_len],
-                "position_ids":
-                pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
-                "segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
-                                         self.max_seq_len],
-                "input_mask":
-                mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
-            }
-            '''
             feed = self.reader.process(text[si])
             feed_batch.append(feed)
             prepro_end = time.time()