From d326596260605d2d650783cbe30100ba9892f186 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Mon, 26 Oct 2020 15:44:40 -0700 Subject: [PATCH] Add in IMDB dataset processor to TF-NLP. PiperOrigin-RevId: 339134320 --- official/nlp/data/classifier_data_lib.py | 55 ++++++++++++++++++--- official/nlp/data/create_finetuning_data.py | 8 +-- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/official/nlp/data/classifier_data_lib.py b/official/nlp/data/classifier_data_lib.py index ca53d8f23..3ec420916 100644 --- a/official/nlp/data/classifier_data_lib.py +++ b/official/nlp/data/classifier_data_lib.py @@ -214,6 +214,44 @@ class ColaProcessor(DataProcessor): return examples +class ImdbProcessor(DataProcessor): + """Processor for the IMDb dataset.""" + + def get_labels(self): + return ["neg", "pos"] + + def get_train_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, "train")) + + def get_dev_examples(self, data_dir): + return self._create_examples(os.path.join(data_dir, "test")) + + @staticmethod + def get_processor_name(): + """See base class.""" + return "IMDB" + + def _create_examples(self, data_dir): + """Creates examples.""" + examples = [] + for label in ["neg", "pos"]: + cur_dir = os.path.join(data_dir, label) + for filename in tf.io.gfile.listdir(cur_dir): + if not filename.endswith("txt"): + continue + + if len(examples) % 1000 == 0: + logging.info("Loading dev example %d", len(examples)) + + path = os.path.join(cur_dir, filename) + with tf.io.gfile.GFile(path, "r") as f: + text = f.read().strip().replace("<br />", " ") + examples.append( + InputExample( + guid="unused_id", text_a=text, text_b=None, label=label)) + return examples + + class MnliProcessor(DataProcessor): """Processor for the MultiNLI data set (GLUE version).""" @@ -1032,6 +1070,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, if len(tokens_a) > max_seq_length - 2: tokens_a = tokens_a[0:(max_seq_length - 2)] + seg_id_a = 0 + seg_id_b = 1 + seg_id_cls = 0 + seg_id_pad = 0 + # The convention in BERT is: # (a) For sequence pairs: # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] @@ -1053,19 +1096,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, tokens = [] segment_ids = [] tokens.append("[CLS]") - segment_ids.append(0) + segment_ids.append(seg_id_cls) for token in tokens_a: tokens.append(token) - segment_ids.append(0) + segment_ids.append(seg_id_a) tokens.append("[SEP]") - segment_ids.append(0) + segment_ids.append(seg_id_a) if tokens_b: for token in tokens_b: tokens.append(token) - segment_ids.append(1) + segment_ids.append(seg_id_b) tokens.append("[SEP]") - segment_ids.append(1) + segment_ids.append(seg_id_b) input_ids = tokenizer.convert_tokens_to_ids(tokens) @@ -1077,7 +1120,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, while len(input_ids) < max_seq_length: input_ids.append(0) input_mask.append(0) - segment_ids.append(0) + segment_ids.append(seg_id_pad) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length diff --git a/official/nlp/data/create_finetuning_data.py b/official/nlp/data/create_finetuning_data.py index 4b163576f..a82cf429c 100644 --- a/official/nlp/data/create_finetuning_data.py +++ b/official/nlp/data/create_finetuning_data.py @@ -47,9 +47,9 @@ flags.DEFINE_string( "for the task.") flags.DEFINE_enum("classification_task_name", "MNLI", - ["AX", "COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", - "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", - 
"XTREME-PAWS-X"], + ["AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", + "QQP", "RTE", "SST-2", "STS-B", "WNLI", "XNLI", + "XTREME-XNLI", "XTREME-PAWS-X"], "The name of the task to train BERT classifier. The " "difference between XTREME-XNLI and XNLI is: 1. the format " "of input tsv files; 2. the dev set for XTREME is english " @@ -182,6 +182,8 @@ def generate_classifier_dataset(): classifier_data_lib.AxProcessor, "cola": classifier_data_lib.ColaProcessor, + "imdb": + classifier_data_lib.ImdbProcessor, "mnli": functools.partial(classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type), -- GitLab