提交 a19f2f8b 编写于 作者: Y Yuqi Li 提交者: A. Unique TensorFlower

Add SST-2 and QNLI processor to Bert classifier lib

PiperOrigin-RevId: 285967157
上级 74c97026
......@@ -290,6 +290,98 @@ class ColaProcessor(DataProcessor):
return examples
class SstProcessor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "SST-2"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
label = "0"
else:
text_a = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "QNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, 1)
if set_type == "test":
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = "entailment"
else:
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
......
......@@ -41,7 +41,7 @@ flags.DEFINE_string(
"for the task.")
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "XNLI"],
["COLA", "MNLI", "MRPC", "QNLI", "SST-2", "XNLI"],
"The name of the task to train BERT classifier.")
# BERT Squad task specific flags.
......@@ -102,6 +102,8 @@ def generate_classifier_dataset():
"cola": classifier_data_lib.ColaProcessor,
"mnli": classifier_data_lib.MnliProcessor,
"mrpc": classifier_data_lib.MrpcProcessor,
"qnli": classifier_data_lib.QnliProcessor,
"sst-2": classifier_data_lib.SstProcessor,
"xnli": classifier_data_lib.XnliProcessor,
}
task_name = FLAGS.classification_task_name.lower()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册