From 675e80c364f6b41540f4f508c47ad019f97673f9 Mon Sep 17 00:00:00 2001 From: kinghuin Date: Tue, 6 Aug 2019 12:05:14 +0800 Subject: [PATCH] Fix MNLI error (#105) --- paddlehub/dataset/dataset.py | 2 +- paddlehub/dataset/glue.py | 57 ++++++++++++++++-------------------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/paddlehub/dataset/dataset.py b/paddlehub/dataset/dataset.py index 53e7f947..ca45cd90 100644 --- a/paddlehub/dataset/dataset.py +++ b/paddlehub/dataset/dataset.py @@ -45,7 +45,7 @@ class InputExample(object): if self.text_b is None: return "text={}\tlabel={}".format(self.text_a, self.label) else: - return "text_a={}\ttext_b{},label={}".format( + return "text_a={}\ttext_b={},label={}".format( self.text_a, self.text_b, self.label) diff --git a/paddlehub/dataset/glue.py b/paddlehub/dataset/glue.py index f5806cb4..f8c0240c 100644 --- a/paddlehub/dataset/glue.py +++ b/paddlehub/dataset/glue.py @@ -43,7 +43,7 @@ class GLUE(HubDataset): ]: raise Exception( sub_dataset + - "is not in GLUE benchmark. Please confirm the data set") + " is not in GLUE benchmark. Please confirm the data set") self.sub_dataset = sub_dataset self.dataset_dir = os.path.join(DATA_HOME, "glue_data") @@ -120,14 +120,15 @@ class GLUE(HubDataset): reader = csv.reader(f, delimiter="\t", quotechar=quotechar) examples = [] seq_id = 0 - header = next(reader) # skip header + if self.sub_dataset != 'CoLA' or wo_label: + header = next(reader) # skip header if self.sub_dataset in [ 'MRPC', ]: if wo_label: - label_index, text_a_index, text_b_index = [None, -1, -2] + label_index, text_a_index, text_b_index = [None, -2, -1] else: - label_index, text_a_index, text_b_index = [0, -1, -2] + label_index, text_a_index, text_b_index = [0, -2, -1] elif self.sub_dataset in [ 'QNLI', ]: @@ -160,9 +161,9 @@ class GLUE(HubDataset): 'MNLI', ]: if wo_label: - label_index, text_a_index, text_b_index = [None, -2, -1] + label_index, text_a_index, text_b_index = [None, 8, 9] else: - label_index, text_a_index, text_b_index = [-1, -4, -3] + label_index, text_a_index, text_b_index = [-1, 8, 9] elif self.sub_dataset in ['CoLA']: if wo_label: label_index, text_a_index, text_b_index = [None, 1, None] @@ -170,9 +171,9 @@ class GLUE(HubDataset): label_index, text_a_index, text_b_index = [1, 3, None] elif self.sub_dataset in ['STS-B']: if wo_label: - label_index, text_a_index, text_b_index = [None, -1, -2] + label_index, text_a_index, text_b_index = [None, -2, -1] else: - label_index, text_a_index, text_b_index = [-1, -2, -3] + label_index, text_a_index, text_b_index = [-1, -3, -2] for line in reader: try: @@ -191,26 +192,20 @@ class GLUE(HubDataset): if __name__ == "__main__": - ds = GLUE(sub_dataset='CoLA') - total_len = 0 - max_len = 0 - total_num = over_num = 0 - overlen = [] - for e in ds.get_predict_examples(): - length = len(e.text_a.split()) + len( - e.text_b.split()) if e.text_b else len(e.text_a.split()) - total_len += length - if length > max_len: - max_len = length - total_num += 1 - if length > 128: - over_num += 1 - overstr = ("\ntext_a: " + e.text_a + "\ntext_b:" + - e.text_b) if e.text_b else e.text_a - overlen.append(overstr) - avg = total_len / total_num - for o in overlen[:2]: - print("The data length>128:{}".format(o)) - print( - "The total number: {}\nThe avrage length: {}\nthe max length: {}\nthe number of data length > 128: {}" - .format(total_num, avg, max_len, over_num)) + for sub_dataset in [ + 'CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2', 'STS-B' + ]: + print(sub_dataset) + ds = GLUE(sub_dataset=sub_dataset) + for e in ds.get_train_examples()[:2]: + print(e) + print() + for e in ds.get_dev_examples()[:2]: + print(e) + print() + for e in ds.get_test_examples()[:2]: + print(e) + print() + for e in ds.get_predict_examples()[:2]: + print(e) + print() -- GitLab