提交 bff1ab25 编写于 作者: K kinghuin 提交者: Steffy-zxf

Fix MNLI error (#103)

上级 4ab705ba
......@@ -93,6 +93,9 @@ class GLUE(HubDataset):
def get_test_examples(self):
return self.test_examples
def get_predict_examples(self):
return self.predict_examples
def get_labels(self):
"""See base class."""
if self.sub_dataset in ['MRPC', 'QQP', 'SST-2', 'CoLA']:
......@@ -157,9 +160,9 @@ class GLUE(HubDataset):
'MNLI',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
label_index, text_a_index, text_b_index = [None, -2, -1]
else:
label_index, text_a_index, text_b_index = [-1, -4, -5]
label_index, text_a_index, text_b_index = [-1, -4, -3]
elif self.sub_dataset in ['CoLA']:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, None]
......@@ -188,7 +191,26 @@ class GLUE(HubDataset):
if __name__ == "__main__":
ds = GLUE(sub_dataset='SST-2')
for e in ds.get_train_examples()[:3]:
print(e)
labels = set()
ds = GLUE(sub_dataset='CoLA')
total_len = 0
max_len = 0
total_num = over_num = 0
overlen = []
for e in ds.get_predict_examples():
length = len(e.text_a.split()) + len(
e.text_b.split()) if e.text_b else len(e.text_a.split())
total_len += length
if length > max_len:
max_len = length
total_num += 1
if length > 128:
over_num += 1
overstr = ("\ntext_a: " + e.text_a + "\ntext_b:" +
e.text_b) if e.text_b else e.text_a
overlen.append(overstr)
avg = total_len / total_num
for o in overlen[:2]:
print("The data length>128:{}".format(o))
print(
"The total number: {}\nThe avrage length: {}\nthe max length: {}\nthe number of data length > 128: {}"
.format(total_num, avg, max_len, over_num))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册