提交 d3265962 编写于 作者: A Allen Wang 提交者: A. Unique TensorFlower

Add in IMDB dataset processor to TF-NLP.

PiperOrigin-RevId: 339134320
上级 4dc945c0
...@@ -214,6 +214,44 @@ class ColaProcessor(DataProcessor): ...@@ -214,6 +214,44 @@ class ColaProcessor(DataProcessor):
return examples return examples
class ImdbProcessor(DataProcessor):
  """Processor for the IMDb sentiment classification dataset.

  Expects the standard aclImdb directory layout:
    <data_dir>/train/{neg,pos}/*.txt and <data_dir>/test/{neg,pos}/*.txt
  """

  def get_labels(self):
    """See base class."""
    return ["neg", "pos"]

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(os.path.join(data_dir, "train"))

  def get_dev_examples(self, data_dir):
    """See base class."""
    # IMDb ships no separate dev split; the test split is used for eval.
    return self._create_examples(os.path.join(data_dir, "test"))

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "IMDB"

  def _create_examples(self, data_dir):
    """Creates `InputExample`s from all review files under one split dir.

    Args:
      data_dir: Path to a split directory containing `neg/` and `pos/`
        subdirectories of plain-text review files.

    Returns:
      A list of `InputExample`s, one per review file, labeled by the
      subdirectory the file came from.
    """
    examples = []
    for label in ["neg", "pos"]:
      cur_dir = os.path.join(data_dir, label)
      for filename in tf.io.gfile.listdir(cur_dir):
        # NOTE(review): matches any name merely *ending* in "txt" (not
        # strictly ".txt") — kept as-is to mirror upstream behavior.
        if not filename.endswith("txt"):
          continue
        # Fix: this helper loads both the train and the dev split, so the
        # progress message must not claim every example is a dev example.
        if len(examples) % 1000 == 0:
          logging.info("Loading example %d", len(examples))
        path = os.path.join(cur_dir, filename)
        with tf.io.gfile.GFile(path, "r") as f:
          # IMDb reviews embed HTML line breaks; flatten them to spaces.
          text = f.read().strip().replace("<br />", " ")
        examples.append(
            InputExample(
                guid="unused_id", text_a=text, text_b=None, label=label))
    return examples
class MnliProcessor(DataProcessor): class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version).""" """Processor for the MultiNLI data set (GLUE version)."""
...@@ -1032,6 +1070,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -1032,6 +1070,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
if len(tokens_a) > max_seq_length - 2: if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)] tokens_a = tokens_a[0:(max_seq_length - 2)]
seg_id_a = 0
seg_id_b = 1
seg_id_cls = 0
seg_id_pad = 0
# The convention in BERT is: # The convention in BERT is:
# (a) For sequence pairs: # (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
...@@ -1053,19 +1096,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -1053,19 +1096,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens = [] tokens = []
segment_ids = [] segment_ids = []
tokens.append("[CLS]") tokens.append("[CLS]")
segment_ids.append(0) segment_ids.append(seg_id_cls)
for token in tokens_a: for token in tokens_a:
tokens.append(token) tokens.append(token)
segment_ids.append(0) segment_ids.append(seg_id_a)
tokens.append("[SEP]") tokens.append("[SEP]")
segment_ids.append(0) segment_ids.append(seg_id_a)
if tokens_b: if tokens_b:
for token in tokens_b: for token in tokens_b:
tokens.append(token) tokens.append(token)
segment_ids.append(1) segment_ids.append(seg_id_b)
tokens.append("[SEP]") tokens.append("[SEP]")
segment_ids.append(1) segment_ids.append(seg_id_b)
input_ids = tokenizer.convert_tokens_to_ids(tokens) input_ids = tokenizer.convert_tokens_to_ids(tokens)
...@@ -1077,7 +1120,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -1077,7 +1120,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
while len(input_ids) < max_seq_length: while len(input_ids) < max_seq_length:
input_ids.append(0) input_ids.append(0)
input_mask.append(0) input_mask.append(0)
segment_ids.append(0) segment_ids.append(seg_id_pad)
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length assert len(input_mask) == max_seq_length
......
...@@ -47,9 +47,9 @@ flags.DEFINE_string( ...@@ -47,9 +47,9 @@ flags.DEFINE_string(
"for the task.") "for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI", flags.DEFINE_enum("classification_task_name", "MNLI",
["AX", "COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE", ["AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI",
"SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "QQP", "RTE", "SST-2", "STS-B", "WNLI", "XNLI",
"XTREME-PAWS-X"], "XTREME-XNLI", "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The " "The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format " "difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english " "of input tsv files; 2. the dev set for XTREME is english "
...@@ -182,6 +182,8 @@ def generate_classifier_dataset(): ...@@ -182,6 +182,8 @@ def generate_classifier_dataset():
classifier_data_lib.AxProcessor, classifier_data_lib.AxProcessor,
"cola": "cola":
classifier_data_lib.ColaProcessor, classifier_data_lib.ColaProcessor,
"imdb":
classifier_data_lib.ImdbProcessor,
"mnli": "mnli":
functools.partial(classifier_data_lib.MnliProcessor, functools.partial(classifier_data_lib.MnliProcessor,
mnli_type=FLAGS.mnli_type), mnli_type=FLAGS.mnli_type),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册