提交 675e80c3 编写于 作者: K kinghuin 提交者: Steffy-zxf

Fix MNLI error (#105)

上级 bff1ab25
......@@ -45,7 +45,7 @@ class InputExample(object):
if self.text_b is None:
return "text={}\tlabel={}".format(self.text_a, self.label)
else:
return "text_a={}\ttext_b{},label={}".format(
return "text_a={}\ttext_b={},label={}".format(
self.text_a, self.text_b, self.label)
......
......@@ -43,7 +43,7 @@ class GLUE(HubDataset):
]:
raise Exception(
sub_dataset +
"is not in GLUE benchmark. Please confirm the data set")
" is not in GLUE benchmark. Please confirm the data set")
self.sub_dataset = sub_dataset
self.dataset_dir = os.path.join(DATA_HOME, "glue_data")
......@@ -120,14 +120,15 @@ class GLUE(HubDataset):
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
examples = []
seq_id = 0
header = next(reader) # skip header
if self.sub_dataset != 'CoLA' or wo_label:
header = next(reader) # skip header
if self.sub_dataset in [
'MRPC',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
label_index, text_a_index, text_b_index = [None, -2, -1]
else:
label_index, text_a_index, text_b_index = [0, -1, -2]
label_index, text_a_index, text_b_index = [0, -2, -1]
elif self.sub_dataset in [
'QNLI',
]:
......@@ -160,9 +161,9 @@ class GLUE(HubDataset):
'MNLI',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, -2, -1]
label_index, text_a_index, text_b_index = [None, 8, 9]
else:
label_index, text_a_index, text_b_index = [-1, -4, -3]
label_index, text_a_index, text_b_index = [-1, 8, 9]
elif self.sub_dataset in ['CoLA']:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, None]
......@@ -170,9 +171,9 @@ class GLUE(HubDataset):
label_index, text_a_index, text_b_index = [1, 3, None]
elif self.sub_dataset in ['STS-B']:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
label_index, text_a_index, text_b_index = [None, -2, -1]
else:
label_index, text_a_index, text_b_index = [-1, -2, -3]
label_index, text_a_index, text_b_index = [-1, -3, -2]
for line in reader:
try:
......@@ -191,26 +192,20 @@ class GLUE(HubDataset):
if __name__ == "__main__":
ds = GLUE(sub_dataset='CoLA')
total_len = 0
max_len = 0
total_num = over_num = 0
overlen = []
for e in ds.get_predict_examples():
length = len(e.text_a.split()) + len(
e.text_b.split()) if e.text_b else len(e.text_a.split())
total_len += length
if length > max_len:
max_len = length
total_num += 1
if length > 128:
over_num += 1
overstr = ("\ntext_a: " + e.text_a + "\ntext_b:" +
e.text_b) if e.text_b else e.text_a
overlen.append(overstr)
avg = total_len / total_num
for o in overlen[:2]:
print("The data length>128:{}".format(o))
print(
"The total number: {}\nThe avrage length: {}\nthe max length: {}\nthe number of data length > 128: {}"
.format(total_num, avg, max_len, over_num))
for sub_dataset in [
'CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2', 'STS-B'
]:
print(sub_dataset)
ds = GLUE(sub_dataset=sub_dataset)
for e in ds.get_train_examples()[:2]:
print(e)
print()
for e in ds.get_dev_examples()[:2]:
print(e)
print()
for e in ds.get_test_examples()[:2]:
print(e)
print()
for e in ds.get_predict_examples()[:2]:
print(e)
print()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册