From 5e9eda4e87fa2f44c7e5644c882f7ff11645b080 Mon Sep 17 00:00:00 2001
From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com>
Date: Wed, 29 May 2019 17:48:40 +0800
Subject: [PATCH] Update nlp reader to support predict phase (#43)

* Add the required lib 'chardet' in setup.py (#41)

* Update the nlp_reader to avoid feeding label_data when using predict

* Drop the lib "chardet"

* Drop the lib "chardet"

* Drop the lib "chardet"

* Update the nlp reader to avoid feeding label when predicting
---
 paddlehub/reader/nlp_reader.py | 261 +++++++++++++++++++++++----------
 1 file changed, 180 insertions(+), 81 deletions(-)

diff --git a/paddlehub/reader/nlp_reader.py b/paddlehub/reader/nlp_reader.py
index 0ecd43b2..5f723adb 100644
--- a/paddlehub/reader/nlp_reader.py
+++ b/paddlehub/reader/nlp_reader.py
@@ -18,15 +18,16 @@ from __future__ import print_function

 import csv
 import json
+import numpy as np
 import platform
 import six
 from collections import namedtuple

 import paddle
-import numpy as np

 from paddlehub.reader import tokenization
 from paddlehub.common.logger import logger
+from paddlehub.dataset.dataset import InputExample
 from .batching import pad_batch_data

 import paddlehub as hub
@@ -104,7 +105,11 @@ class BaseReader(object):
         else:
             tokens_b.pop()

-    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
+    def _convert_example_to_record(self,
+                                   example,
+                                   max_seq_length,
+                                   tokenizer,
+                                   phase=None):
         """Converts a single `Example` into a single `Record`."""

         text_a = tokenization.convert_to_unicode(example.text_a)
@@ -175,11 +180,24 @@
             'Record',
             ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])

-        record = Record(
-            token_ids=token_ids,
-            text_type_ids=text_type_ids,
-            position_ids=position_ids,
-            label_id=label_id)
+        if phase != "predict":
+            Record = namedtuple(
+                'Record',
+                ['token_ids', 'text_type_ids', 'position_ids', 'label_id'])
+
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+                label_id=label_id)
+        else:
+            Record = namedtuple('Record',
+                                ['token_ids', 'text_type_ids', 'position_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids)
+
         return record

     def _prepare_batch_data(self, examples, batch_size, phase=None):
@@ -189,7 +207,7 @@
             if phase == "train":
                 self.current_example = index
             record = self._convert_example_to_record(example, self.max_seq_len,
-                                                     self.tokenizer)
+                                                     self.tokenizer, phase)
             max_len = max(max_len, len(record.token_ids))
             if self.in_tokens:
                 to_append = (len(batch_records) + 1) * max_len <= batch_size
@@ -198,11 +216,11 @@
             if to_append:
                 batch_records.append(record)
             else:
-                yield self._pad_batch_records(batch_records)
+                yield self._pad_batch_records(batch_records, phase)
                 batch_records, max_len = [record], len(record.token_ids)

         if batch_records:
-            yield self._pad_batch_records(batch_records)
+            yield self._pad_batch_records(batch_records, phase)

     def get_num_examples(self, phase):
         """Get number of examples for train, dev or test."""
@@ -212,20 +230,51 @@
             )
         return self.num_examples[phase]

-    def data_generator(self, batch_size=1, phase='train', shuffle=True):
+    def data_generator(self,
+                       batch_size=1,
+                       phase='train',
+                       shuffle=True,
+                       data=None):
         if phase == 'train':
+            shuffle = True
             examples = self.get_train_examples()
             self.num_examples['train'] = len(examples)
         elif phase == 'val' or phase == 'dev':
+            shuffle = False
             examples = self.get_dev_examples()
             self.num_examples['dev'] = len(examples)
         elif phase == 'test':
+            shuffle = False
             examples = self.get_test_examples()
             self.num_examples['test'] = len(examples)
+        elif phase == 'predict':
+            shuffle = False
+            examples = []
+            seq_id = 0
+
+            for item in data:
+                # set label in order to run the program
+                label = "0"
+                if len(item) == 1:
+                    item_i = InputExample(
+                        guid=seq_id, text_a=item[0], label=label)
+                elif len(item) == 2:
+                    item_i = InputExample(
+                        guid=seq_id,
+                        text_a=item[0],
+                        text_b=item[1],
+                        label=label)
+                else:
+                    raise ValueError(
+                        "The length of input_text is out of handling, which must be 1 or 2!"
+                    )
+                examples.append(item_i)
+                seq_id += 1
         else:
             raise ValueError(
-                "Unknown phase, which should be in ['train', 'dev', 'test'].")
+                "Unknown phase, which should be in ['train', 'dev', 'test', 'predict']."
+            )

         def wrapper():
             if shuffle:
@@ -239,20 +288,11 @@


 class ClassifyReader(BaseReader):
-    def _pad_batch_records(self, batch_records):
+    def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
         batch_position_ids = [record.position_ids for record in batch_records]
-        batch_labels = [record.label_id for record in batch_records]
-        batch_labels = np.array(batch_labels).astype("int64").reshape([-1, 1])
-
-        # if batch_records[0].qid:
-        #     batch_qids = [record.qid for record in batch_records]
-        #     batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
-        # else:
-        #     batch_qids = np.array([]).astype("int64").reshape([-1, 1])

         # padding
         padded_token_ids, input_mask = pad_batch_data(
             batch_token_ids,
             max_seq_len=self.max_seq_len,
@@ -267,20 +307,29 @@
             max_seq_len=self.max_seq_len,
             pad_idx=self.pad_id)

-        return_list = [
-            padded_token_ids, padded_position_ids, padded_text_type_ids,
-            input_mask, batch_labels
-        ]
+        if phase != "predict":
+            batch_labels = [record.label_id for record in batch_records]
+            batch_labels = np.array(batch_labels).astype("int64").reshape(
+                [-1, 1])
+
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, batch_labels
+            ]
+        else:
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask
+            ]

         return return_list


 class SequenceLabelReader(BaseReader):
-    def _pad_batch_records(self, batch_records):
+    def _pad_batch_records(self, batch_records, phase=None):
         batch_token_ids = [record.token_ids for record in batch_records]
         batch_text_type_ids = [record.text_type_ids for record in batch_records]
         batch_position_ids = [record.position_ids for record in batch_records]
-        batch_label_ids = [record.label_ids for record in batch_records]

         # padding
         padded_token_ids, input_mask, batch_seq_lens = pad_batch_data(
@@ -297,65 +346,115 @@
             batch_position_ids,
             max_seq_len=self.max_seq_len,
             pad_idx=self.pad_id)
-        padded_label_ids = pad_batch_data(
-            batch_label_ids,
-            max_seq_len=self.max_seq_len,
-            pad_idx=len(self.label_map) - 1)

-        return_list = [
-            padded_token_ids, padded_position_ids, padded_text_type_ids,
-            input_mask, padded_label_ids, batch_seq_lens
-        ]
+        if phase != "predict":
+            batch_label_ids = [record.label_ids for record in batch_records]
+            padded_label_ids = pad_batch_data(
+                batch_label_ids,
+                max_seq_len=self.max_seq_len,
+                pad_idx=len(self.label_map) - 1)
+
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, padded_label_ids, batch_seq_lens
+            ]
+        else:
+            return_list = [
+                padded_token_ids, padded_position_ids, padded_text_type_ids,
+                input_mask, batch_seq_lens
+            ]

         return return_list

-    def _reseg_token_label(self, tokens, labels, tokenizer):
-        if len(tokens) != len(labels):
-            raise ValueError("The length of tokens must be same with labels")
-        ret_tokens = []
-        ret_labels = []
-        for token, label in zip(tokens, labels):
-            sub_token = tokenizer.tokenize(token)
-            if len(sub_token) == 0:
-                continue
-            ret_tokens.extend(sub_token)
-            ret_labels.append(label)
-            if len(sub_token) < 2:
-                continue
-            sub_label = label
-            if label.startswith("B-"):
-                sub_label = "I-" + label[2:]
-            ret_labels.extend([sub_label] * (len(sub_token) - 1))
-
-        if len(ret_tokens) != len(labels):
-            raise ValueError("The length of ret_tokens can't match with labels")
-        return ret_tokens, ret_labels
-
-    def _convert_example_to_record(self, example, max_seq_length, tokenizer):
-        tokens = tokenization.convert_to_unicode(example.text_a).split(u"")
-        labels = tokenization.convert_to_unicode(example.label).split(u"")
-        tokens, labels = self._reseg_token_label(tokens, labels, tokenizer)
+    def _reseg_token_label(self, tokens, tokenizer, phase, labels=None):
+        if phase != "predict":
+            if len(tokens) != len(labels):
+                raise ValueError(
+                    "The length of tokens must be same with labels")
+            ret_tokens = []
+            ret_labels = []
+            for token, label in zip(tokens, labels):
+                sub_token = tokenizer.tokenize(token)
+                if len(sub_token) == 0:
+                    continue
+                ret_tokens.extend(sub_token)
+                ret_labels.append(label)
+                if len(sub_token) < 2:
+                    continue
+                sub_label = label
+                if label.startswith("B-"):
+                    sub_label = "I-" + label[2:]
+                ret_labels.extend([sub_label] * (len(sub_token) - 1))
+
+            if len(ret_tokens) != len(labels):
+                raise ValueError(
+                    "The length of ret_tokens can't match with labels")
+            return ret_tokens, ret_labels
+        else:
+            ret_tokens = []
+            for token in tokens:
+                sub_token = tokenizer.tokenize(token)
+                if len(sub_token) == 0:
+                    continue
+                ret_tokens.extend(sub_token)
+                if len(sub_token) < 2:
+                    continue
+
+            return ret_tokens
+
+    def _convert_example_to_record(self,
+                                   example,
+                                   max_seq_length,
+                                   tokenizer,
+                                   phase=None):

-        if len(tokens) > max_seq_length - 2:
-            tokens = tokens[0:(max_seq_length - 2)]
-            labels = labels[0:(max_seq_length - 2)]
+        tokens = tokenization.convert_to_unicode(example.text_a).split(u"")

-        tokens = ["[CLS]"] + tokens + ["[SEP]"]
-        token_ids = tokenizer.convert_tokens_to_ids(tokens)
-        position_ids = list(range(len(token_ids)))
-        text_type_ids = [0] * len(token_ids)
-        no_entity_id = len(self.label_map) - 1
-        label_ids = [no_entity_id
-                     ] + [self.label_map[label]
-                          for label in labels] + [no_entity_id]
+        if phase != "predict":
+            labels = tokenization.convert_to_unicode(example.label).split(u"")
+            tokens, labels = self._reseg_token_label(
+                tokens=tokens, labels=labels, tokenizer=tokenizer, phase=phase)
+
+            if len(tokens) > max_seq_length - 2:
+                tokens = tokens[0:(max_seq_length - 2)]
+                labels = labels[0:(max_seq_length - 2)]
+
+            tokens = ["[CLS]"] + tokens + ["[SEP]"]
+            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            position_ids = list(range(len(token_ids)))
+            text_type_ids = [0] * len(token_ids)
+            no_entity_id = len(self.label_map) - 1
+            label_ids = [no_entity_id
+                         ] + [self.label_map[label]
+                              for label in labels] + [no_entity_id]
+
+            Record = namedtuple(
+                'Record',
+                ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+                label_ids=label_ids)
+        else:
+            tokens = self._reseg_token_label(
+                tokens=tokens, tokenizer=tokenizer, phase=phase)
+
+            if len(tokens) > max_seq_length - 2:
+                tokens = tokens[0:(max_seq_length - 2)]
+
+            tokens = ["[CLS]"] + tokens + ["[SEP]"]
+            token_ids = tokenizer.convert_tokens_to_ids(tokens)
+            position_ids = list(range(len(token_ids)))
+            text_type_ids = [0] * len(token_ids)
+
+            Record = namedtuple('Record',
+                                ['token_ids', 'text_type_ids', 'position_ids'])
+            record = Record(
+                token_ids=token_ids,
+                text_type_ids=text_type_ids,
+                position_ids=position_ids,
+            )

-        Record = namedtuple(
-            'Record',
-            ['token_ids', 'text_type_ids', 'position_ids', 'label_ids'])
-        record = Record(
-            token_ids=token_ids,
-            text_type_ids=text_type_ids,
-            position_ids=position_ids,
-            label_ids=label_ids)
         return record
-- 
GitLab
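
Usage note: the sketch below shows how the predict phase added by this patch
would be driven from user code. It is a minimal example under stated
assumptions -- a PaddleHub 1.x environment where the "ernie" module and the
ChnSentiCorp dataset are installed; the module name, dataset choice, sample
texts, batch_size, and max_seq_len are illustrative and not part of the patch.

    # Hypothetical driver for the new predict phase (assumes PaddleHub 1.x).
    import paddlehub as hub
    from paddlehub.reader.nlp_reader import ClassifyReader

    module = hub.Module(name="ernie")     # any BERT-style module with a vocab
    dataset = hub.dataset.ChnSentiCorp()  # the reader still expects a dataset
    reader = ClassifyReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=128)

    # Each item in `data` is [text_a] or [text_a, text_b]; the patched
    # data_generator wraps each one in an InputExample with a dummy label
    # "0", so no label data has to be fed when predicting.
    data = [["这家餐厅的菜很好吃"], ["前半句文本", "后半句文本"]]
    generator = reader.data_generator(
        batch_size=2, phase="predict", data=data, shuffle=False)

    for batch in generator():
        # With this patch the yielded batch carries padded token ids,
        # position ids, text type ids and the input mask, but no label
        # tensor; the exact nesting follows _prepare_batch_data.
        print(batch)

The dummy label "0" in the predict branch exists only so the shared
record-building path keeps working; _pad_batch_records then drops it again by
checking phase != "predict" before appending batch_labels to return_list.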