Question about asynchronous data reading
Created by: chuj625
Hi, I need to wrap ERNIE's word segmentation as an interface for others to call. While working with PaddleNLP/lexical_analysis/run_ernie_sequence_labeling.py, I found that prediction is driven by an asynchronous file-reading mechanism. First, I am puzzled why an asynchronous file-reading approach is used here and don't fully understand it. Second, I have no idea how to change it so that input is passed as an argument and the interface can be called repeatedly.
Any guidance would be appreciated.
Attached is my understanding of part of the code; if anything is wrong, please correct me. Many thanks.
# For this part I have already wrapped a new class that turns [string] -> reader.data_generator
# reader = task_reader.SequenceLabelReader(
reader = ListReader(
    vocab_path=args.vocab_path,
    label_map_config=args.label_map_config,
    max_seq_len=args.max_seq_len,
    do_lower_case=args.do_lower_case,
    in_tokens=False,
    random_seed=args.random_seed)
lines = read_file(args.infer_set)  # reads the file and returns ["a b c"]
...
with ...
    with ...
        # lines is supplied here
        infer_pyreader.decorate_tensor_provider(
            reader.data_generator(
                lines, args.batch_size, phase='infer', epoch=1, shuffle=False
            )
        )
        ...
        infer_pyreader.start()
        while True:
            try:
                # The problem: there is no way here to produce predictions
                # repeatedly whenever the data in lines changes
                (words, crf_decode) = exe.run(
                    infer_program,
                    fetch_list=[infer_ret["words"], infer_ret["crf_decode"]],
                    return_numpy=False)
                # Note that words has been clipped if longer than args.max_seq_len
                results = utils.parse_result(words, crf_decode, dataset)
                for result in results:
                    print(type(result), result)
            except fluid.core.EOFException:
                infer_pyreader.reset()
                break
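For the repeated-call part, this is a minimal sketch of what I imagine the wrapper could look like, assuming infer_pyreader.decorate_tensor_provider can simply be called again with a new generator before each start() (all names are taken from the snippet above; please correct me if re-decorating is not allowed):

def ernie_segment(lines):
    """Run prediction on an in-memory list of strings and return results."""
    all_results = []
    # re-bind the data source to the new lines for this call
    infer_pyreader.decorate_tensor_provider(
        reader.data_generator(
            lines, args.batch_size, phase='infer', epoch=1, shuffle=False))
    infer_pyreader.start()
    while True:
        try:
            words, crf_decode = exe.run(
                infer_program,
                fetch_list=[infer_ret["words"], infer_ret["crf_decode"]],
                return_numpy=False)
            all_results.extend(utils.parse_result(words, crf_decode, dataset))
        except fluid.core.EOFException:
            infer_pyreader.reset()
            break
    return all_results

# hypothetical usage: call as often as needed with different input
# print(ernie_segment(["百度 是 一家 公司"]))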
The ListReader code is as follows:
from collections import namedtuple

class ListReader(task_reader.SequenceLabelReader):
    def _read_tsv(self, list_input, quotechar=None):
        # interpret the "input file" argument as an in-memory list of strings
        Example = namedtuple('LR', 'text_a')
        exam = []
        for l in list_input:
            l = l.strip()
            # l = l.decode("utf-8")
            exam.append(Example(l))
        return exam
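If I read the base class correctly, data_generator passes its first argument straight to _read_tsv, so overriding _read_tsv this way lets data_generator consume an in-memory list instead of a file path. A hypothetical usage sketch:

lines = ["百度 是 一家 公司", "今天 天气 不错"]
examples = reader._read_tsv(lines)
# -> [LR(text_a='百度 是 一家 公司'), ...]; each input string becomes one Example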
    def _reseg_token_label(self, tokens, tokenizer):
        # re-tokenize each token into sub-tokens and emit one placeholder
        # label per sub-token, since inference needs no gold labels
        ret_tokens = []
        for token in tokens:
            sub_token = tokenizer.tokenize(token)
            if len(sub_token) == 0:
                continue
            ret_tokens.extend(sub_token)
        ret_labels = ['t'] * len(ret_tokens)
        assert len(ret_tokens) == len(ret_labels)
        return ret_tokens, ret_labels
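To check my understanding of the alignment between sub-tokens and labels, here is a toy sketch with a stand-in tokenizer (hypothetical; the real one is ERNIE's FullTokenizer):

class _StubTokenizer:
    def tokenize(self, token):
        # pretend each token splits into a head piece plus ##-pieces
        return [token[:1]] + ["##" + ch for ch in token[1:]]

tokens, labels = reader._reseg_token_label(["today", "ok"], _StubTokenizer())
# tokens: ['t', '##o', '##d', '##a', '##y', 'o', '##k']
# labels: ['t', 't', 't', 't', 't', 't', 't']  (same length, all placeholders)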
    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
        # tokens in text_a are separated by the \2 control character
        tokens = tokenization.convert_to_unicode(example.text_a).split(u"\2")
        tokens, labels = self._reseg_token_label(tokens, tokenizer)
        # reserve room for [CLS] and [SEP]
        if len(tokens) > max_seq_length - 2:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        # tokens to ids
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        position_ids = list(range(len(token_ids)))
        text_type_ids = [0] * len(token_ids)
        no_entity_id = len(self.label_map) - 1
        # the placeholder 't' is not in label_map, so it falls back to u"O"
        labels = [label if label in self.label_map else u"O" for label in labels]
        label_ids = [no_entity_id] + [
            self.label_map[label] for label in labels
        ] + [no_entity_id]
        Record = namedtuple(
            'Record',
            ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
        record = Record(
            token_ids=token_ids,
            text_type_ids=text_type_ids,
            position_ids=position_ids,
            label_ids=label_ids)
        return record
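And an end-to-end check of the conversion, assuming reader was built as above, the input uses the \2 separator, and the base reader keeps its tokenizer at reader.tokenizer (the actual id values depend on the vocab and label map):

example = reader._read_tsv(["百\2度"])[0]
record = reader._convert_example_to_record(
    example, max_seq_length=128, tokenizer=reader.tokenizer)
print(record.token_ids)      # [CLS]-id, id(百), id(度), [SEP]-id
print(record.position_ids)   # [0, 1, 2, 3]
print(record.text_type_ids)  # [0, 0, 0, 0], single-sentence input
print(record.label_ids)      # no_entity_id at both ends around the label ids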