Commit 5edc2a6f authored by kinghuin, committed by Steffy-zxf

Support GLUE XNLI (#101)

* Support GLUE XNLI

* Fix description error

* fix yapf failed
Parent b857db71
@@ -23,7 +23,6 @@ import ast
import numpy as np
import os
import time
import paddle
import paddle.fluid as fluid
import paddlehub as hub
@@ -35,24 +34,56 @@ parser.add_argument("--batch_size", type=int, default=1, help="Total examp
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for finetuning; input should be True or False")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether to use pyreader to feed data.")
+parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The dataset to finetune and predict on.")
args = parser.parse_args()
# yapf: enable.

if __name__ == '__main__':
-    # loading Paddlehub ERNIE pretrained model
-    module = hub.Module(name="ernie")
-    inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
+    dataset = None
+    # Download the dataset and use ClassifyReader to read it
+    if args.dataset.lower() == "chnsenticorp":
+        dataset = hub.dataset.ChnSentiCorp()
+        module = hub.Module(name="ernie")
+    elif args.dataset.lower() == "nlpcc_dbqa":
+        dataset = hub.dataset.NLPCC_DBQA()
+        module = hub.Module(name="ernie")
+    elif args.dataset.lower() == "lcqmc":
+        dataset = hub.dataset.LCQMC()
+        module = hub.Module(name="ernie")
+    elif args.dataset.lower() == "mrpc":
+        dataset = hub.dataset.GLUE("MRPC")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "qqp":
+        dataset = hub.dataset.GLUE("QQP")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "sst-2":
+        dataset = hub.dataset.GLUE("SST-2")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "cola":
+        dataset = hub.dataset.GLUE("CoLA")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "qnli":
+        dataset = hub.dataset.GLUE("QNLI")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "rte":
+        dataset = hub.dataset.GLUE("RTE")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower() == "mnli":
+        dataset = hub.dataset.GLUE("MNLI")
+        module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
+    elif args.dataset.lower().startswith("xnli"):
+        dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
+        module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
+    else:
+        raise ValueError("%s dataset is not defined" % args.dataset)
-    # Sentence classification dataset reader
-    dataset = hub.dataset.ChnSentiCorp()
+    inputs, outputs, program = module.context(
+        trainable=True, max_seq_len=args.max_seq_len)

    reader = hub.reader.ClassifyReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len)

+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)

    # Construct transfer learning network
    # Use "pooled_output" for classification tasks on an entire sentence.
    # Use "sequence_output" for token-level output.
@@ -86,15 +117,7 @@ if __name__ == '__main__':
        config=config)

    # Data to be predicted
-    data = [
-        ["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"], ["交通方便;环境很好;服务态度很好 房间较小"],
-        [
-            "还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。"
-        ],
-        [
-            "前台接待太差,酒店有A B楼之分,本人check-in后,前台未告诉B楼在何处,并且B楼无明显指示;房间太小,根本不像4星级设施,下次不会再选择入住此店啦"
-        ], ["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"]
-    ]
+    data = [[d.text_a, d.text_b] for d in dataset.get_dev_examples()[:3]]

    index = 0
    run_states = cls_task.predict(data=data)
...
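The branch added above pairs each dataset with a compatible pretrained module: ERNIE for the Chinese datasets, uncased English BERT for the GLUE tasks, and multilingual cased BERT for XNLI, whose language code is taken from the last two characters of the --dataset argument. A minimal sketch of that pairing logic, with hypothetical names (load_dataset_and_module and GLUE_NAMES are illustrative, not part of this commit):

import paddlehub as hub

# Canonical GLUE sub-dataset names, keyed by the lower-cased CLI argument.
GLUE_NAMES = {"mrpc": "MRPC", "qqp": "QQP", "sst-2": "SST-2", "cola": "CoLA",
              "qnli": "QNLI", "rte": "RTE", "mnli": "MNLI"}

def load_dataset_and_module(name):
    # Hypothetical helper mirroring the if/elif chain above.
    name = name.lower()
    if name in ("chnsenticorp", "nlpcc_dbqa", "lcqmc"):
        # The Chinese datasets pair with ERNIE.
        dataset_cls = {"chnsenticorp": hub.dataset.ChnSentiCorp,
                       "nlpcc_dbqa": hub.dataset.NLPCC_DBQA,
                       "lcqmc": hub.dataset.LCQMC}[name]
        return dataset_cls(), hub.Module(name="ernie")
    if name in GLUE_NAMES:
        # English GLUE tasks pair with uncased English BERT.
        return (hub.dataset.GLUE(GLUE_NAMES[name]),
                hub.Module(name="bert_uncased_L-12_H-768_A-12"))
    if name.startswith("xnli"):
        # "xnli_zh"[-2:] -> "zh"; multilingual cased BERT covers all 15 languages.
        return (hub.dataset.XNLI(language=name[-2:]),
                hub.Module(name="bert_multi_cased_L-12_H-768_A-12"))
    raise ValueError("%s dataset is not defined" % name)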
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0

-# User can select chnsenticorp, nlpcc_dbqa, lcqmc for different task
+# User can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different tasks
DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}"
-# Recommending hyper parameters for difference task
-# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
-# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
-# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
+# Supported datasets: ChnSentiCorp NLPCC_DBQA LCQMC MRPC QQP SST-2
+#                     CoLA QNLI RTE MNLI XNLI
+# For XNLI, specify the language with an underscore suffix, e.g. xnli_zh:
+# ar: Arabic    bg: Bulgarian   de: German
+# el: Greek     en: English     es: Spanish
+# fr: French    hi: Hindi       ru: Russian
+# sw: Swahili   th: Thai        tr: Turkish
+# ur: Urdu      vi: Vietnamese  zh: Chinese (Simplified)

python -u text_classifier.py \
    --batch_size=24 \
...
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0

-CKPT_DIR="./ckpt_chnsenticorp"
-python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False
+# User can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different tasks
+DATASET="chnsenticorp"
+CKPT_DIR="./ckpt_${DATASET}"
+# Supported datasets: ChnSentiCorp NLPCC_DBQA LCQMC MRPC QQP SST-2
+#                     CoLA QNLI RTE MNLI XNLI
+# For XNLI, specify the language with an underscore suffix, e.g. xnli_zh:
+# ar: Arabic    bg: Bulgarian   de: German
+# el: Greek     en: English     es: Spanish
+# fr: French    hi: Hindi       ru: Russian
+# sw: Swahili   th: Thai        tr: Turkish
+# ur: Urdu      vi: Vietnamese  zh: Chinese (Simplified)
+python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET}
@@ -16,7 +16,6 @@
import argparse
import ast

import paddle.fluid as fluid
import paddlehub as hub
@@ -24,7 +23,7 @@ import paddlehub as hub
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether to use GPU for finetuning; input should be True or False")
-parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to model checkpoint", choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])
+parser.add_argument("--dataset", type=str, default="chnsenticorp", help="The dataset to finetune on.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
@@ -38,23 +37,46 @@ args = parser.parse_args()
# yapf: enable.

if __name__ == '__main__':
-    # Load Paddlehub ERNIE pretrained model
-    module = hub.Module(name="ernie")
-    # module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
-    inputs, outputs, program = module.context(
-        trainable=True, max_seq_len=args.max_seq_len)
-    # Download dataset and use ClassifyReader to read dataset
    dataset = None
+    # Download the dataset and use ClassifyReader to read it
    if args.dataset.lower() == "chnsenticorp":
        dataset = hub.dataset.ChnSentiCorp()
+        module = hub.Module(name="ernie")
    elif args.dataset.lower() == "nlpcc_dbqa":
        dataset = hub.dataset.NLPCC_DBQA()
+        module = hub.Module(name="ernie")
    elif args.dataset.lower() == "lcqmc":
        dataset = hub.dataset.LCQMC()
+        module = hub.Module(name="ernie")
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
    else:
        raise ValueError("%s dataset is not defined" % args.dataset)

+    inputs, outputs, program = module.context(
+        trainable=True, max_seq_len=args.max_seq_len)

    reader = hub.reader.ClassifyReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
...
@@ -20,6 +20,8 @@ from .msra_ner import MSRA_NER
from .nlpcc_dbqa import NLPCC_DBQA
from .lcqmc import LCQMC
from .toxic import Toxic
+from .xnli import XNLI
+from .glue import GLUE

# CV Dataset
from .dogcat import DogCatDataset as DogCat
...
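Exporting XNLI and GLUE here is what makes them reachable as hub.dataset.GLUE and hub.dataset.XNLI in the demo scripts above. A quick sketch of direct use (assuming the data archives can be downloaded):

import paddlehub as hub

mnli = hub.dataset.GLUE("MNLI")            # GLUE sub-dataset, selected by name
xnli_de = hub.dataset.XNLI(language="de")  # XNLI split, selected by language code
print(mnli.num_labels)        # 3: neutral / contradiction / entailment
print(xnli_de.get_labels())   # the same three-way NLI label set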
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import io

from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/glue_data.tar.gz"
class GLUE(HubDataset):
    """
    Please refer to
    https://gluebenchmark.com
    for more information
    """

    def __init__(self, sub_dataset='SST-2'):
        # sub_dataset: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B
        if sub_dataset not in [
                'CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2', 'STS-B'
        ]:
            raise Exception(
                sub_dataset +
                " is not in the GLUE benchmark. Please confirm the dataset name.")
        self.sub_dataset = sub_dataset
        self.dataset_dir = os.path.join(DATA_HOME, "glue_data")
        if not os.path.exists(self.dataset_dir):
            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(self.dataset_dir))

        self._load_train_examples()
        self._load_dev_examples()
        self._load_test_examples()
        self._load_predict_examples()
    def _load_train_examples(self):
        self.train_file = os.path.join(self.dataset_dir, self.sub_dataset,
                                       "train.tsv")
        self.train_examples = self._read_tsv(self.train_file)

    def _load_dev_examples(self):
        if self.sub_dataset == 'MNLI':
            self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
                                         "dev_matched.tsv")
        else:
            self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
                                         "dev.tsv")
        self.dev_examples = self._read_tsv(self.dev_file)

    def _load_test_examples(self):
        self.test_examples = []

    def _load_predict_examples(self):
        if self.sub_dataset == 'MNLI':
            self.predict_file = os.path.join(self.dataset_dir,
                                             self.sub_dataset,
                                             "test_matched.tsv")
        else:
            self.predict_file = os.path.join(self.dataset_dir,
                                             self.sub_dataset, "test.tsv")
        self.predict_examples = self._read_tsv(self.predict_file, wo_label=True)

    def get_train_examples(self):
        return self.train_examples

    def get_dev_examples(self):
        return self.dev_examples

    def get_test_examples(self):
        return self.test_examples
    def get_labels(self):
        """See base class."""
        if self.sub_dataset in ['MRPC', 'QQP', 'SST-2', 'CoLA']:
            return ["0", "1"]
        elif self.sub_dataset in ['QNLI', 'RTE']:
            return ['not_entailment', 'entailment']
        elif self.sub_dataset in ['MNLI']:
            return ["neutral", "contradiction", "entailment"]
        elif self.sub_dataset in ['STS-B']:
            raise Exception("No category labels for regression task STS-B")

    @property
    def num_labels(self):
        """
        Return the number of labels in the dataset.
        """
        return len(self.get_labels())
    def _read_tsv(self, input_file, quotechar=None, wo_label=False):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="UTF-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            examples = []
            seq_id = 0
            header = next(reader)  # skip header

            if self.sub_dataset in ['MRPC']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, -1, -2]
                else:
                    label_index, text_a_index, text_b_index = [0, -1, -2]
            elif self.sub_dataset in ['QNLI']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, 1, 2]
                else:
                    label_index, text_a_index, text_b_index = [3, 1, 2]
            elif self.sub_dataset in ['QQP']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, 1, 2]
                else:
                    label_index, text_a_index, text_b_index = [5, 3, 4]
            elif self.sub_dataset in ['RTE']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, 1, 2]
                else:
                    label_index, text_a_index, text_b_index = [3, 1, 2]
            elif self.sub_dataset in ['SST-2']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, 1, None]
                else:
                    label_index, text_a_index, text_b_index = [1, 0, None]
            elif self.sub_dataset in ['MNLI']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, -1, -2]
                else:
                    label_index, text_a_index, text_b_index = [-1, -4, -5]
            elif self.sub_dataset in ['CoLA']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, 1, None]
                else:
                    label_index, text_a_index, text_b_index = [1, 3, None]
            elif self.sub_dataset in ['STS-B']:
                if wo_label:
                    label_index, text_a_index, text_b_index = [None, -1, -2]
                else:
                    label_index, text_a_index, text_b_index = [-1, -2, -3]
            for line in reader:
                try:
                    example = InputExample(
                        guid=seq_id,
                        text_a=line[text_a_index],
                        text_b=line[text_b_index]
                        if text_b_index is not None else None,
                        label=line[label_index]
                        if label_index is not None else None)
                    seq_id += 1
                    examples.append(example)
                except:
                    print("[Discard Incorrect Data] " + "\t".join(line))
            return examples
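The per-sub-dataset index triples above record where each task keeps its label and sentences in a TSV row; negative indices count from the end of the row, None marks a column that does not exist, and wo_label=True (used for the unlabeled test files) drops the label. As a minimal illustration with the SST-2 labeled layout [1, 0, None] and a made-up row:

# Hypothetical SST-2 train.tsv row (sentence \t label); values are illustrative.
line = ["a gorgeous , witty , seductive movie", "1"]
label_index, text_a_index, text_b_index = 1, 0, None
text_a = line[text_a_index]  # the sentence
text_b = None                # single-sentence task, no second text
label = line[label_index]    # "1" -> positive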
if __name__ == "__main__":
ds = GLUE(sub_dataset='SST-2')
for e in ds.get_train_examples()[:3]:
print(e)
labels = set()
# default_downloader.download_file_and_uncompress(
# url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import os
import csv

from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/XNLI-lan.tar.gz"
class XNLI(HubDataset):
    """
    Please refer to
    https://arxiv.org/pdf/1809.05053.pdf
    for more information
    """

    def __init__(self, language='zh'):
        if language not in [
                "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw",
                "th", "tr", "ur", "vi", "zh"
        ]:
            raise Exception(language +
                            " is not in XNLI. Please confirm the language.")
        self.language = language
        self.dataset_dir = os.path.join(DATA_HOME, "XNLI-lan")
        if not os.path.exists(self.dataset_dir):
            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(self.dataset_dir))

        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()
    def _load_train_examples(self):
        self.train_file = os.path.join(self.dataset_dir, self.language,
                                       self.language + "_train.tsv")
        self.train_examples = self._read_tsv(self.train_file)

    def _load_dev_examples(self):
        self.dev_file = os.path.join(self.dataset_dir, self.language,
                                     self.language + "_dev.tsv")
        self.dev_examples = self._read_tsv(self.dev_file)

    def _load_test_examples(self):
        self.test_file = os.path.join(self.dataset_dir, self.language,
                                      self.language + "_test.tsv")
        self.test_examples = self._read_tsv(self.test_file)

    def get_train_examples(self):
        return self.train_examples

    def get_dev_examples(self):
        return self.dev_examples

    def get_test_examples(self):
        return self.test_examples

    def get_labels(self):
        """See base class."""
        return ["neutral", "contradiction", "entailment"]

    @property
    def num_labels(self):
        """
        Return the number of labels in the dataset.
        """
        return len(self.get_labels())
    def _read_tsv(self, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with io.open(input_file, "r", encoding="UTF-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            examples = []
            seq_id = 0
            header = next(reader)  # skip header
            for line in reader:
                example = InputExample(
                    guid=seq_id, label=line[2], text_a=line[0], text_b=line[1])
                seq_id += 1
                examples.append(example)
            return examples
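Unlike GLUE, every packaged XNLI TSV uses a single fixed layout: premise, hypothesis, label, which is what the hard-coded indices above assume. A made-up row illustrates the mapping:

# Hypothetical row from zh_train.tsv (text_a \t text_b \t label); values are illustrative.
line = ["他们下周到达。", "他们星期二到达。", "neutral"]
text_a, text_b, label = line[0], line[1], line[2]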
if __name__ == "__main__":
ds = XNLI()
for e in ds.get_train_examples()[:3]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))