Commit 5e9eda4e authored by Steffy-zxf, committed by Zeyu Chen

Update nlp reader to support predict phase (#43)

* Add the required lib 'chardet' in setup.py (#41)

* Update the nlp_reader to avoid feeding label_data when using predict

* Drop the lib "chardet"

* Drop the lib "chardet"

* Drop the lib "chardet"

* Update the nlp reader to avoid feeding label when predicting
Parent 93bf920e
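A minimal usage sketch of the predict flow this commit enables. The `ClassifyReader` constructor arguments and the `gen()` call pattern are assumptions based on typical PaddleHub 1.x usage, not shown in this diff:

```python
# Hedged sketch, not part of the commit: driving ClassifyReader in predict phase.
reader = ClassifyReader(
    dataset=dataset,        # assumed: a hub dataset providing the label list
    vocab_path=vocab_path,  # assumed: vocab file of the pretrained model
    max_seq_len=128)

# Each item carries one or two texts; labels are no longer required.
data = [["text_a only"], ["text_a", "text_b"]]
gen = reader.data_generator(batch_size=2, phase='predict', data=data)
for batch in gen():        # data_generator returns the wrapper shown below
    pass                   # batch omits the label tensor in predict phase
```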
@@ -18,15 +18,16 @@ from __future__ import print_function

 import csv
 import json
-import numpy as np
 import platform
 import six
 from collections import namedtuple

 import paddle
+import numpy as np

 from paddlehub.reader import tokenization
 from paddlehub.common.logger import logger
+from paddlehub.dataset.dataset import InputExample

 from .batching import pad_batch_data
 import paddlehub as hub
@@ -104,7 +105,11 @@ class BaseReader(object):
             else:
                 tokens_b.pop()

-    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
+    def _convert_example_to_record(self,
+                                   example,
+                                   max_seq_length,
+                                   tokenizer,
+                                   phase=None):
         """Converts a single `Example` into a single `Record`."""
         text_a = tokenization.convert_to_unicode(example.text_a)
@@ -175,11 +180,24 @@ class BaseReader(object):
-        Record = namedtuple(
-            'Record',
-            ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
-
-        record = Record(
-            token_ids=token_ids,
-            text_type_ids=text_type_ids,
-            position_ids=position_ids,
-            label_id=label_id)
+        if phase != "predict":
+            Record = namedtuple(
+                'Record',
+                ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
+
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+                label_id=label_id)
+        else:
+            Record = namedtuple('Record',
+                                ['token_ids', 'text_type_ids', 'position_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids)

         return record

     def _prepare_batch_data(self, examples, batch_size, phase=None):
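For clarity, the two record layouts produced by the branch above differ only in the `label_id` field. A standalone sketch (the `TrainRecord`/`PredictRecord` names are illustrative only):

```python
from collections import namedtuple

# Mirrors the two namedtuples defined in the branch above.
TrainRecord = namedtuple(
    'Record', ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
PredictRecord = namedtuple(
    'Record', ['token_ids', 'text_type_ids', 'position_ids'])

r = PredictRecord(
    token_ids=[101, 2769, 102], text_type_ids=[0, 0, 0], position_ids=[0, 1, 2])
assert not hasattr(r, 'label_id')  # predict records carry no label field
```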
@@ -189,7 +207,7 @@ class BaseReader(object):
             if phase == "train":
                 self.current_example = index
             record = self._convert_example_to_record(example, self.max_seq_len,
-                                                     self.tokenizer)
+                                                     self.tokenizer, phase)
             max_len = max(max_len, len(record.token_ids))
             if self.in_tokens:
                 to_append = (len(batch_records) + 1) * max_len <= batch_size
@@ -198,11 +216,11 @@ class BaseReader(object):
             if to_append:
                 batch_records.append(record)
             else:
-                yield self._pad_batch_records(batch_records)
+                yield self._pad_batch_records(batch_records, phase)
                 batch_records, max_len = [record], len(record.token_ids)

         if batch_records:
-            yield self._pad_batch_records(batch_records)
+            yield self._pad_batch_records(batch_records, phase)

     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
@@ -212,20 +230,51 @@ class BaseReader(object):
             )
         return self.num_examples[phase]

-    def data_generator(self, batch_size=1, phase='train', shuffle=True):
+    def data_generator(self,
+                       batch_size=1,
+                       phase='train',
+                       shuffle=True,
+                       data=None):
         if phase == 'train':
+            shuffle = True
             examples = self.get_train_examples()
             self.num_examples['train'] = len(examples)
         elif phase == 'val' or phase == 'dev':
+            shuffle = False
             examples = self.get_dev_examples()
             self.num_examples['dev'] = len(examples)
         elif phase == 'test':
+            shuffle = False
             examples = self.get_test_examples()
             self.num_examples['test'] = len(examples)
+        elif phase == 'predict':
+            shuffle = False
+            examples = []
+            seq_id = 0
+
+            for item in data:
+                # set label in order to run the program
+                label = "0"
+                if len(item) == 1:
+                    item_i = InputExample(
+                        guid=seq_id, text_a=item[0], label=label)
+                elif len(item) == 2:
+                    item_i = InputExample(
+                        guid=seq_id,
+                        text_a=item[0],
+                        text_b=item[1],
+                        label=label)
+                else:
+                    raise ValueError(
+                        "The length of input_text is out of handling, which must be 1 or 2!"
+                    )
+                examples.append(item_i)
+                seq_id += 1
         else:
             raise ValueError(
-                "Unknown phase, which should be in ['train', 'dev', 'test'].")
+                "Unknown phase, which should be in ['train', 'dev', 'test', 'predict']."
+            )

         def wrapper():
             if shuffle:
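Each element of `data` must hold one text (single-sentence tasks) or two texts (pair tasks); anything else hits the ValueError above. For example:

```python
# Valid predict inputs for the loop above; the dummy label "0" is attached
# internally only so the program can run.
data = [
    ["只有一个句子"],               # text_a only
    ["第一个句子", "第二个句子"],   # text_a and text_b
]
# An item like ["a", "b", "c"] would raise the ValueError shown above.
```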
@@ -239,20 +288,11 @@ class BaseReader(object):

 class ClassifyReader(BaseReader):
-    def _pad_batch_records(self, batch_records):
+    def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
         batch_position_ids = [record.position_ids for record in batch_records]
-        batch_labels = [record.label_id for record in batch_records]
-        batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
-
-        # if batch_records[0].qid:
-        #     batch_qids = [record.qid for record in batch_records]
-        #     batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
-        # else:
-        #     batch_qids = np.array([]).astype("int64").reshape([-1, 1])
-
-        # padding
         padded_token_ids, input_mask = pad_batch_data(
             batch_token_ids,
             max_seq_len=self.max_seq_len,
@@ -267,20 +307,29 @@ class ClassifyReader(BaseReader):
             max_seq_len=self.max_seq_len,
             pad_idx=self.pad_id)

-        return_list = [
-            padded_token_ids, padded_position_ids, padded_text_type_ids,
-            input_mask, batch_labels
-        ]
+        if phase != "predict":
+            batch_labels = [record.label_id for record in batch_records]
+            batch_labels = np.array(batch_labels).astype("int64").reshape(
+                [-1, 1])
+
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, batch_labels
+            ]
+        else:
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask
+            ]

         return return_list
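Because predict batches drop the trailing `batch_labels` entry, consumers must unpack by phase. A sketch with a hypothetical helper (`unpack_classify_batch` is not part of this commit):

```python
def unpack_classify_batch(return_list, phase=None):
    # predict batches have 4 entries; train/dev/test have 5 (labels last)
    if phase != "predict":
        tokens, positions, text_types, mask, labels = return_list
    else:
        tokens, positions, text_types, mask = return_list
        labels = None
    return tokens, positions, text_types, mask, labels
```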

 class SequenceLabelReader(BaseReader):
-    def _pad_batch_records(self, batch_records):
+    def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
         batch_position_ids = [record.position_ids for record in batch_records]
-        batch_label_ids = [record.label_ids for record in batch_records]

         # padding
         padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
@@ -297,65 +346,115 @@ class SequenceLabelReader(BaseReader):
             batch_position_ids,
             max_seq_len=self.max_seq_len,
             pad_idx=self.pad_id)
-        padded_label_ids = pad_batch_data(
-            batch_label_ids,
-            max_seq_len=self.max_seq_len,
-            pad_idx=len(self.label_map) - 1)

-        return_list = [
-            padded_token_ids, padded_position_ids, padded_text_type_ids,
-            input_mask, padded_label_ids, batch_seq_lens
-        ]
+        if phase != "predict":
+            batch_label_ids = [record.label_ids for record in batch_records]
+            padded_label_ids = pad_batch_data(
+                batch_label_ids,
+                max_seq_len=self.max_seq_len,
+                pad_idx=len(self.label_map) - 1)
+
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, padded_label_ids, batch_seq_lens
+            ]
+        else:
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, batch_seq_lens
+            ]

         return return_list
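Note that `batch_seq_lens` is kept even in predict phase, so downstream code can trim padded positions from per-token outputs. A toy sketch:

```python
import numpy as np

# Toy per-token predictions for a batch padded to length 5.
pred = np.array([[3, 1, 1, 0, 0],
                 [3, 2, 0, 0, 0]])
seq_lens = np.array([3, 2])  # true lengths, as carried by batch_seq_lens

trimmed = [row[:n] for row, n in zip(pred, seq_lens)]
print(trimmed)  # [array([3, 1, 1]), array([3, 2])]
```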
-    def _reseg_token_label(self, tokens, labels, tokenizer):
-        if len(tokens) != len(labels):
-            raise ValueError("The length of tokens must be same with labels")
-        ret_tokens = []
-        ret_labels = []
-        for token, label in zip(tokens, labels):
-            sub_token = tokenizer.tokenize(token)
-            if len(sub_token) == 0:
-                continue
-            ret_tokens.extend(sub_token)
-            ret_labels.append(label)
-            if len(sub_token) < 2:
-                continue
-            sub_label = label
-            if label.startswith("B-"):
-                sub_label = "I-" + label[2:]
-            ret_labels.extend([sub_label] * (len(sub_token) - 1))
-
-        if len(ret_tokens) != len(labels):
-            raise ValueError("The length of ret_tokens can't match with labels")
-        return ret_tokens, ret_labels
+    def _reseg_token_label(self, tokens, tokenizer, phase, labels=None):
+        if phase != "predict":
+            if len(tokens) != len(labels):
+                raise ValueError(
+                    "The length of tokens must be same with labels")
+            ret_tokens = []
+            ret_labels = []
+            for token, label in zip(tokens, labels):
+                sub_token = tokenizer.tokenize(token)
+                if len(sub_token) == 0:
+                    continue
+                ret_tokens.extend(sub_token)
+                ret_labels.append(label)
+                if len(sub_token) < 2:
+                    continue
+                sub_label = label
+                if label.startswith("B-"):
+                    sub_label = "I-" + label[2:]
+                ret_labels.extend([sub_label] * (len(sub_token) - 1))

+            if len(ret_tokens) != len(labels):
+                raise ValueError(
+                    "The length of ret_tokens can't match with labels")
+            return ret_tokens, ret_labels
+        else:
+            ret_tokens = []
+            for token in tokens:
+                sub_token = tokenizer.tokenize(token)
+                if len(sub_token) == 0:
+                    continue
+                ret_tokens.extend(sub_token)
+                if len(sub_token) < 2:
+                    continue

+            return ret_tokens
-    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
-        tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
-        labels = tokenization.convert_to_unicode(example.label).split(u"")
-        tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
-
-        if len(tokens) > max_seq_length - 2:
-            tokens = tokens[0:(max_seq_length - 2)]
-            labels = labels[0:(max_seq_length - 2)]
-
-        tokens = ["[CLS]"] + tokens + ["[SEP]"]
-        token_ids = tokenizer.convert_tokens_to_ids(tokens)
-        position_ids = list(range(len(token_ids)))
-        text_type_ids = [0] * len(token_ids)
-        no_entity_id = len(self.label_map) - 1
-        label_ids = [no_entity_id
-                     ] + [self.label_map[label]
-                          for label in labels] + [no_entity_id]
-
-        Record = namedtuple(
-            'Record',
-            ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
-        record = Record(
-            token_ids=token_ids,
-            text_type_ids=text_type_ids,
-            position_ids=position_ids,
-            label_ids=label_ids)
+    def _convert_example_to_record(self,
+                                   example,
+                                   max_seq_length,
+                                   tokenizer,
+                                   phase=None):
+        tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
+
+        if phase != "predict":
+            labels = tokenization.convert_to_unicode(example.label).split(u"")
+            tokens, labels = self._reseg_token_label(
+                tokens=tokens, labels=labels, tokenizer=tokenizer, phase=phase)
+
+            if len(tokens) > max_seq_length - 2:
+                tokens = tokens[0:(max_seq_length - 2)]
+                labels = labels[0:(max_seq_length - 2)]
+
+            tokens = ["[CLS]"] + tokens + ["[SEP]"]
+            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            position_ids = list(range(len(token_ids)))
+            text_type_ids = [0] * len(token_ids)
+            no_entity_id = len(self.label_map) - 1
+            label_ids = [no_entity_id
+                         ] + [self.label_map[label]
+                              for label in labels] + [no_entity_id]
+
+            Record = namedtuple(
+                'Record',
+                ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+                label_ids=label_ids)
+        else:
+            tokens = self._reseg_token_label(
+                tokens=tokens, tokenizer=tokenizer, phase=phase)
+            if len(tokens) > max_seq_length - 2:
+                tokens = tokens[0:(max_seq_length - 2)]
+
+            tokens = ["[CLS]"] + tokens + ["[SEP]"]
+            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            position_ids = list(range(len(token_ids)))
+            text_type_ids = [0] * len(token_ids)
+
+            Record = namedtuple('Record',
+                                ['token_ids', 'text_type_ids', 'position_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+            )

         return record
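End to end, the sequence-labeling predict path mirrors the classification one. A hedged sketch (constructor arguments and the input text format are assumptions, as before):

```python
# Hedged sketch, not part of the commit: SequenceLabelReader in predict phase.
reader = SequenceLabelReader(
    dataset=dataset,        # assumed: a hub dataset providing the label map
    vocab_path=vocab_path,  # assumed: vocab file of the pretrained model
    max_seq_len=128)

data = [["奥巴马访问北京"]]  # raw text; no gold labels needed
gen = reader.data_generator(batch_size=1, phase='predict', data=data)
for batch in gen():
    # batch: token/position/text-type ids, input mask, and seq_lens
    pass
```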