Commit 7abe8a1c authored by MRXLT

clean code

Parent a369de86
@@ -45,20 +45,6 @@ class BertService():
         self.pid = os.getpid()
         self.profile = True if ("FLAGS_profile_client" in os.environ and
                                 os.environ["FLAGS_profile_client"]) else False
-        '''
-        module = hub.Module(name=self.model_name)
-        inputs, outputs, program = module.context(
-            trainable=True, max_seq_len=self.max_seq_len)
-        input_ids = inputs["input_ids"]
-        position_ids = inputs["position_ids"]
-        segment_ids = inputs["segment_ids"]
-        input_mask = inputs["input_mask"]
-        self.reader = hub.reader.ClassifyReader(
-            vocab_path=module.get_vocab_path(),
-            dataset=None,
-            max_seq_len=self.max_seq_len,
-            do_lower_case=self.do_lower_case)
-        '''
         self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
         self.reader_flag = True
@@ -68,28 +54,8 @@ class BertService():
         self.client.connect(server_addr)
     def run_general(self, text, fetch):
-        '''
-        self.batch_size = len(text)
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        '''
         result = []
         prepro_start = time.time()
-        '''
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-            for si in range(self.batch_size):
-                feed = {
-                    "input_ids": token_list,
-                    "position_ids": pos_list,
-                    "segment_ids": sent_list,
-                    "input_mask": mask_list
-                }
-        prepro_end = time.time()
-        '''
         feed = self.reader.process(text)
         if self.profile:
             print("PROFILE\tpid:{}\tbert_pre_0:{} bert_pre_1:{}".format(
@@ -102,33 +68,10 @@ class BertService():
     def run_batch_general(self, text, fetch):
         self.batch_size = len(text)
-        '''
-        data_generator = self.reader.data_generator(
-            batch_size=self.batch_size, phase='predict', data=text)
-        '''
         result = []
         prepro_start = time.time()
-        '''
-        for run_step, batch in enumerate(data_generator(), start=1):
-            token_list = batch[0][0].reshape(-1).tolist()
-            pos_list = batch[0][1].reshape(-1).tolist()
-            sent_list = batch[0][2].reshape(-1).tolist()
-            mask_list = batch[0][3].reshape(-1).tolist()
-        '''
         feed_batch = []
         for si in range(self.batch_size):
-            '''
-            feed = {
-                "input_ids": token_list[si * self.max_seq_len:(si + 1) *
-                                        self.max_seq_len],
-                "position_ids":
-                pos_list[si * self.max_seq_len:(si + 1) * self.max_seq_len],
-                "segment_ids": sent_list[si * self.max_seq_len:(si + 1) *
-                                         self.max_seq_len],
-                "input_mask":
-                mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
-            }
-            '''
             feed = self.reader.process(text[si])
             feed_batch.append(feed)
         prepro_end = time.time()
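Note: the commit only deletes the commented-out PaddleHub / data_generator path; preprocessing now goes through BertReader.process() alone, once per request in run_general() and once per sentence in run_batch_general(). For orientation, a minimal usage sketch of the batch path follows. Only run_batch_general(text, fetch) and the per-sentence reader.process() call appear in the diff; the constructor arguments, the load_client() call, the config path, the endpoint, and the fetch name are assumptions for illustration.

# Minimal sketch, not taken from this commit. Hypothetical setup calls:
# the constructor/load_client() signatures and the values below are assumed.
bc = BertService(max_seq_len=20, do_lower_case=True)  # hypothetical arguments
bc.load_client("serving_client_conf.prototxt", "127.0.0.1:9292")  # assumed method/args

# run_batch_general() builds one feed dict per sentence via
# self.reader.process() and sends the whole feed_batch in a single request.
texts = ["hello world", "a second sentence"]
result = bc.run_batch_general(texts, fetch=["pooled_output"])  # fetch name assumed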