提交 5e9eda4e 编写于 作者: S Steffy-zxf 提交者: Zeyu Chen

Update nlp reader to support predict phase(#43)

* Add the required lib 'chardet' in setup.py (#41)

* Update the nlp_reader to avoid feeding label_data when using predict

* Drop the lib "chardet"

* Drop the lib "chardet"

* Drop the lub "chardet"

* Update the nlp reader to avoid feeding label when predicting
上级 93bf920e
......@@ -18,15 +18,16 @@ from __future__ import print_function
import csv
import json
import numpy as np
import platform
import six
from collections import namedtuple
import paddle
import numpy as np
from paddlehub.reader import tokenization
from paddlehub.common.logger import logger
from paddlehub.dataset.dataset import InputExample
from .batching import pad_batch_data
import paddlehub as hub
......@@ -104,7 +105,11 @@ class BaseReader(object):
else:
tokens_b.pop()
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
def _convert_example_to_record(self,
example,
max_seq_length,
tokenizer,
phase=None):
"""Converts a single `Example` into a single `Record`."""
text_a = tokenization.convert_to_unicode(example.text_a)
......@@ -175,11 +180,24 @@ class BaseReader(object):
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id)
if phase != "predict":
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id)
else:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids)
return record
def _prepare_batch_data(self, examples, batch_size, phase=None):
......@@ -189,7 +207,7 @@ class BaseReader(object):
if phase == "train":
self.current_example = index
record = self._convert_example_to_record(example, self.max_seq_len,
self.tokenizer)
self.tokenizer, phase)
max_len = max(max_len, len(record.token_ids))
if self.in_tokens:
to_append = (len(batch_records) + 1) * max_len <= batch_size
......@@ -198,11 +216,11 @@ class BaseReader(object):
if to_append:
batch_records.append(record)
else:
yield self._pad_batch_records(batch_records)
yield self._pad_batch_records(batch_records, phase)
batch_records, max_len = [record], len(record.token_ids)
if batch_records:
yield self._pad_batch_records(batch_records)
yield self._pad_batch_records(batch_records, phase)
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
......@@ -212,20 +230,51 @@ class BaseReader(object):
)
return self.num_examples[phase]
def data_generator(self, batch_size=1, phase='train', shuffle=True):
def data_generator(self,
batch_size=1,
phase='train',
shuffle=True,
data=None):
if phase == 'train':
shuffle = True
examples = self.get_train_examples()
self.num_examples['train'] = len(examples)
elif phase == 'val' or phase == 'dev':
shuffle = False
examples = self.get_dev_examples()
self.num_examples['dev'] = len(examples)
elif phase == 'test':
shuffle = False
examples = self.get_test_examples()
self.num_examples['test'] = len(examples)
elif phase == 'predict':
shuffle = False
examples = []
seq_id = 0
for item in data:
# set label in order to run the program
label = "0"
if len(item) == 1:
item_i = InputExample(
guid=seq_id, text_a=item[0], label=label)
elif len(item) == 2:
item_i = InputExample(
guid=seq_id,
text_a=item[0],
text_b=item[1],
label=label)
else:
raise ValueError(
"The length of input_text is out of handling, which must be 1 or 2!"
)
examples.append(item_i)
seq_id += 1
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
"Unknown phase, which should be in ['train', 'dev', 'test', 'predict']."
)
def wrapper():
if shuffle:
......@@ -239,20 +288,11 @@ class BaseReader(object):
class ClassifyReader(BaseReader):
def _pad_batch_records(self, batch_records):
def _pad_batch_records(self, batch_records, phase=None):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
# if batch_records[0].qid:
# batch_qids = [record.qid for record in batch_records]
# batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
# else:
# batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids, input_mask = pad_batch_data(
batch_token_ids,
max_seq_len=self.max_seq_len,
......@@ -267,20 +307,29 @@ class ClassifyReader(BaseReader):
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
if phase != "predict":
batch_labels = [record.label_id for record in batch_records]
batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1, 1])
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_labels
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask
]
return return_list
class SequenceLabelReader(BaseReader):
def _pad_batch_records(self, batch_records):
def _pad_batch_records(self, batch_records, phase=None):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
batch_label_ids = [record.label_ids for record in batch_records]
# padding
padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
......@@ -297,65 +346,115 @@ class SequenceLabelReader(BaseReader):
batch_position_ids,
max_seq_len=self.max_seq_len,
pad_idx=self.pad_id)
padded_label_ids = pad_batch_data(
batch_label_ids,
max_seq_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
if phase != "predict":
batch_label_ids = [record.label_ids for record in batch_records]
padded_label_ids = pad_batch_data(
batch_label_ids,
max_seq_len=self.max_seq_len,
pad_idx=len(self.label_map) - 1)
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, padded_label_ids, batch_seq_lens
]
else:
return_list = [
padded_token_ids, padded_position_ids, padded_text_type_ids,
input_mask, batch_seq_lens
]
return return_list
def _reseg_token_label(self, tokens, labels, tokenizer):
if len(tokens) != len(labels):
raise ValueError("The length of tokens must be same with labels")
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
if len(ret_tokens) != len(labels):
raise ValueError("The length of ret_tokens can't match with labels")
return ret_tokens, ret_labels
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
labels = tokenization.convert_to_unicode(example.label).split(u"")
tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
def _reseg_token_label(self, tokens, tokenizer, phase, labels=None):
if phase != "predict":
if len(tokens) != len(labels):
raise ValueError(
"The length of tokens must be same with labels")
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
ret_labels.append(label)
if len(sub_token) < 2:
continue
sub_label = label
if label.startswith("B-"):
sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1))
if len(ret_tokens) != len(labels):
raise ValueError(
"The length of ret_tokens can't match with labels")
return ret_tokens, ret_labels
else:
ret_tokens = []
for token in tokens:
sub_token = tokenizer.tokenize(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
if len(sub_token) < 2:
continue
return ret_tokens
def _convert_example_to_record(self,
example,
max_seq_length,
tokenizer,
phase=None):
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
label_ids = [no_entity_id
] + [self.label_map[label]
for label in labels] + [no_entity_id]
if phase != "predict":
labels = tokenization.convert_to_unicode(example.label).split(u"")
tokens, labels = self._reseg_token_label(
tokens=tokens, labels=labels, tokenizer=tokenizer, phase=phase)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1
label_ids = [no_entity_id
] + [self.label_map[label]
for label in labels] + [no_entity_id]
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
else:
tokens = self._reseg_token_label(
tokens=tokens, tokenizer=tokenizer, phase=phase)
if len(tokens) > max_seq_length - 2:
tokens = tokens[0:(max_seq_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids)
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
)
Record = namedtuple(
'Record',
['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_ids=label_ids)
return record
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册