提交 675e80c3 作者: kinghuin 提交者: Steffy-zxf

Fix MNLI error (#105)

上级 bff1ab25
...@@ -45,7 +45,7 @@ class InputExample(object): ...@@ -45,7 +45,7 @@ class InputExample(object):
if self.text_b is None: if self.text_b is None:
return "text={}\tlabel={}".format(self.text_a, self.label) return "text={}\tlabel={}".format(self.text_a, self.label)
else: else:
return "text_a={}\ttext_b{},label={}".format( return "text_a={}\ttext_b={},label={}".format(
self.text_a, self.text_b, self.label) self.text_a, self.text_b, self.label)
......
...@@ -43,7 +43,7 @@ class GLUE(HubDataset): ...@@ -43,7 +43,7 @@ class GLUE(HubDataset):
]: ]:
raise Exception( raise Exception(
sub_dataset + sub_dataset +
"is not in GLUE benchmark. Please confirm the data set") " is not in GLUE benchmark. Please confirm the data set")
self.sub_dataset = sub_dataset self.sub_dataset = sub_dataset
self.dataset_dir = os.path.join(DATA_HOME, "glue_data") self.dataset_dir = os.path.join(DATA_HOME, "glue_data")
...@@ -120,14 +120,15 @@ class GLUE(HubDataset): ...@@ -120,14 +120,15 @@ class GLUE(HubDataset):
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
examples = [] examples = []
seq_id = 0 seq_id = 0
header = next(reader) # skip header if self.sub_dataset != 'CoLA' or wo_label:
header = next(reader) # skip header
if self.sub_dataset in [ if self.sub_dataset in [
'MRPC', 'MRPC',
]: ]:
if wo_label: if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2] label_index, text_a_index, text_b_index = [None, -2, -1]
else: else:
label_index, text_a_index, text_b_index = [0, -1, -2] label_index, text_a_index, text_b_index = [0, -2, -1]
elif self.sub_dataset in [ elif self.sub_dataset in [
'QNLI', 'QNLI',
]: ]:
...@@ -160,9 +161,9 @@ class GLUE(HubDataset): ...@@ -160,9 +161,9 @@ class GLUE(HubDataset):
'MNLI', 'MNLI',
]: ]:
if wo_label: if wo_label:
label_index, text_a_index, text_b_index = [None, -2, -1] label_index, text_a_index, text_b_index = [None, 8, 9]
else: else:
label_index, text_a_index, text_b_index = [-1, -4, -3] label_index, text_a_index, text_b_index = [-1, 8, 9]
elif self.sub_dataset in ['CoLA']: elif self.sub_dataset in ['CoLA']:
if wo_label: if wo_label:
label_index, text_a_index, text_b_index = [None, 1, None] label_index, text_a_index, text_b_index = [None, 1, None]
...@@ -170,9 +171,9 @@ class GLUE(HubDataset): ...@@ -170,9 +171,9 @@ class GLUE(HubDataset):
label_index, text_a_index, text_b_index = [1, 3, None] label_index, text_a_index, text_b_index = [1, 3, None]
elif self.sub_dataset in ['STS-B']: elif self.sub_dataset in ['STS-B']:
if wo_label: if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2] label_index, text_a_index, text_b_index = [None, -2, -1]
else: else:
label_index, text_a_index, text_b_index = [-1, -2, -3] label_index, text_a_index, text_b_index = [-1, -3, -2]
for line in reader: for line in reader:
try: try:
...@@ -191,26 +192,20 @@ class GLUE(HubDataset): ...@@ -191,26 +192,20 @@ class GLUE(HubDataset):
if __name__ == "__main__": if __name__ == "__main__":
ds = GLUE(sub_dataset='CoLA') for sub_dataset in [
total_len = 0 'CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2', 'STS-B'
max_len = 0 ]:
total_num = over_num = 0 print(sub_dataset)
overlen = [] ds = GLUE(sub_dataset=sub_dataset)
for e in ds.get_predict_examples(): for e in ds.get_train_examples()[:2]:
length = len(e.text_a.split()) + len( print(e)
e.text_b.split()) if e.text_b else len(e.text_a.split()) print()
total_len += length for e in ds.get_dev_examples()[:2]:
if length > max_len: print(e)
max_len = length print()
total_num += 1 for e in ds.get_test_examples()[:2]:
if length > 128: print(e)
over_num += 1 print()
overstr = ("\ntext_a: " + e.text_a + "\ntext_b:" + for e in ds.get_predict_examples()[:2]:
e.text_b) if e.text_b else e.text_a print(e)
overlen.append(overstr) print()
avg = total_len / total_num
for o in overlen[:2]:
print("The data length>128:{}".format(o))
print(
"The total number: {}\nThe avrage length: {}\nthe max length: {}\nthe number of data length > 128: {}"
.format(total_num, avg, max_len, over_num))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册