Commit 5edc2a6f authored by kinghuin, committed by Steffy-zxf

Support GLUE XNLI (#101)

* Support GLUE XNLI

* Fix description error

* Fix yapf failure
Parent b857db71
......@@ -23,7 +23,6 @@ import ast
import numpy as np
import os
import time
import paddle
import paddle.fluid as fluid
import paddlehub as hub
......@@ -35,24 +34,56 @@ parser.add_argument("--batch_size", type=int, default=1, help="Total examp
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
dataset = None
# Download dataset and use ClassifyReader to read dataset
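# Each dataset below is paired with a suitable pretrained module: ERNIE for the
# Chinese datasets, the English BERT for the GLUE tasks, and the multilingual
# cased BERT for XNLI.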
if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "lcqmc":
dataset = hub.dataset.LCQMC()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
else:
raise ValueError("%s dataset is not defined" % args.dataset)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Sentence classification dataset reader
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
......@@ -86,15 +117,7 @@ if __name__ == '__main__':
config=config)
# Data to be predicted
data = [[d.text_a, d.text_b] for d in dataset.get_dev_examples()[:3]]
index = 0
run_states = cls_task.predict(data=data)
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
# Users can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different tasks
DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}"
# Recommended hyperparameters for different tasks
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
# Support ChnSentiCorp NLPCC_DBQA LCQMC MRPC QQP SST-2
# CoLA QNLI RTE MNLI XNLI
# For XNLI, specify the language suffix after an underscore, e.g. xnli_zh.
# ar: Arabic bg: Bulgarian de: German
# el: Greek en: English es: Spanish
# fr: French hi: Hindi ru: Russian
# sw: Swahili th: Thai tr: Turkish
# ur: Urdu vi: Vietnamese zh: Chinese (Simplified)
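# For example, to fine-tune on the Chinese XNLI split one could instead set
# (hypothetical values; pick hyperparameters to suit the dataset):
# DATASET="xnli_zh"
# CKPT_DIR="./ckpt_${DATASET}"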
python -u text_classifier.py \
--batch_size=24 \
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_chnsenticorp"
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False
# User can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different task
DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}"
# Support ChnSentiCorp NLPCC_DBQA LCQMC MRPC QQP SST-2
# CoLA QNLI RTE MNLI XNLI
# For XNLI, specify the language suffix after an underscore, e.g. xnli_zh.
# ar: Arabic bg: Bulgarian de: German
# el: Greek en: English es: Spanish
# fr: French hi: Hindi ru: Russian
# sw: Swahili th: Thai tr: Turkish
# ur: Urdu vi: Vietnamese zh: Chinese (Simplified)
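# For example, to predict with a checkpoint fine-tuned on Chinese XNLI
# (hypothetical values):
# DATASET="xnli_zh"
# CKPT_DIR="./ckpt_${DATASET}"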
python -u predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128 --use_gpu False --dataset=${DATASET}
......@@ -16,7 +16,6 @@
import argparse
import ast
import paddle.fluid as fluid
import paddlehub as hub
......@@ -24,7 +23,7 @@ import paddlehub as hub
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False, help="Whether use GPU for finetuning, input should be True or False")
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to model checkpoint", choices=["chnsenticorp", "nlpcc_dbqa", "lcqmc"])
parser.add_argument("--dataset", type=str, default="chnsenticorp", help="Directory to dataset")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
......@@ -38,23 +37,46 @@ args = parser.parse_args()
# yapf: enable.
if __name__ == '__main__':
dataset = None
# Download dataset and use ClassifyReader to read dataset
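# Pair each dataset with a matching pretrained module: ERNIE for the Chinese
# datasets, English BERT for the GLUE tasks, multilingual cased BERT for XNLI.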
if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "lcqmc":
dataset = hub.dataset.LCQMC()
module = hub.Module(name="ernie")
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qqp":
dataset = hub.dataset.GLUE("QQP")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "sst-2":
dataset = hub.dataset.GLUE("SST-2")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "cola":
dataset = hub.dataset.GLUE("CoLA")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "qnli":
dataset = hub.dataset.GLUE("QNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "rte":
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower() == "mnli":
dataset = hub.dataset.GLUE("MNLI")
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
else:
raise ValueError("%s dataset is not defined" % args.dataset)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
......
......@@ -20,6 +20,8 @@ from .msra_ner import MSRA_NER
from .nlpcc_dbqa import NLPCC_DBQA
from .lcqmc import LCQMC
from .toxic import Toxic
from .xnli import XNLI
from .glue import GLUE
# CV Dataset
from .dogcat import DogCatDataset as DogCat
......
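A minimal usage sketch for the two newly registered datasets, assuming they are exposed as hub.dataset.GLUE and hub.dataset.XNLI as imported above (constructing them downloads the data on first use):

import paddlehub as hub

# GLUE sub-dataset, e.g. MRPC; other options: CoLA, MNLI, QNLI, QQP, RTE, SST-2, STS-B
mrpc = hub.dataset.GLUE(sub_dataset="MRPC")
# XNLI in one of the 15 supported languages, e.g. Chinese
xnli_zh = hub.dataset.XNLI(language="zh")
print(mrpc.num_labels, xnli_zh.num_labels)
print(mrpc.get_dev_examples()[0].text_a)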
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import csv
import io
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/glue_data.tar.gz"
class GLUE(HubDataset):
"""
Please refer to
https://gluebenchmark.com
for more information
"""
def __init__(self, sub_dataset='SST-2'):
# sub_dataset : CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B
if sub_dataset not in [
'CoLA', 'MNLI', 'MRPC', 'QNLI', 'QQP', 'RTE', 'SST-2', 'STS-B'
]:
raise Exception(
sub_dataset +
" is not in the GLUE benchmark. Please check the dataset name")
self.sub_dataset = sub_dataset
self.dataset_dir = os.path.join(DATA_HOME, "glue_data")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_dev_examples()
self._load_test_examples()
self._load_predict_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, self.sub_dataset,
"train.tsv")
self.train_examples = self._read_tsv(self.train_file)
def _load_dev_examples(self):
if self.sub_dataset == 'MNLI':
self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
"dev_matched.tsv")
else:
self.dev_file = os.path.join(self.dataset_dir, self.sub_dataset,
"dev.tsv")
self.dev_examples = self._read_tsv(self.dev_file)
def _load_test_examples(self):
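# The GLUE test sets ship without public labels, so no labeled test examples
# are loaded here; the unlabeled test files are exposed as predict examples instead.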
self.test_examples = []
def _load_predict_examples(self):
if self.sub_dataset == 'MNLI':
self.predict_file = os.path.join(self.dataset_dir, self.sub_dataset,
"test_matched.tsv")
else:
self.predict_file = os.path.join(self.dataset_dir, self.sub_dataset,
"test.tsv")
self.predict_examples = self._read_tsv(self.predict_file, wo_label=True)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
"""See base class."""
if self.sub_dataset in ['MRPC', 'QQP', 'SST-2', 'CoLA']:
return ["0", "1"]
elif self.sub_dataset in ['QNLI', 'RTE']:
return ['not_entailment', 'entailment']
elif self.sub_dataset in ['MNLI']:
return ["neutral", "contradiction", "entailment"]
elif self.sub_dataset in ['STS-B']:
return Exception("No category labels for regreesion tasks")
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_tsv(self, input_file, quotechar=None, wo_label=False):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
examples = []
seq_id = 0
header = next(reader) # skip header
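# Column layout differs across the GLUE sub-datasets, so pick the indices of
# label / text_a / text_b per task; wo_label=True is used for files that ship
# without a label column (the predict files).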
if self.sub_dataset in [
'MRPC',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
else:
label_index, text_a_index, text_b_index = [0, -1, -2]
elif self.sub_dataset in [
'QNLI',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, 2]
else:
label_index, text_a_index, text_b_index = [3, 1, 2]
elif self.sub_dataset in [
'QQP',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, 2]
else:
label_index, text_a_index, text_b_index = [5, 3, 4]
elif self.sub_dataset in [
'RTE',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, 2]
else:
label_index, text_a_index, text_b_index = [3, 1, 2]
elif self.sub_dataset in [
'SST-2',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, None]
else:
label_index, text_a_index, text_b_index = [1, 0, None]
elif self.sub_dataset in [
'MNLI',
]:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
else:
label_index, text_a_index, text_b_index = [-1, -4, -5]
elif self.sub_dataset in ['CoLA']:
if wo_label:
label_index, text_a_index, text_b_index = [None, 1, None]
else:
label_index, text_a_index, text_b_index = [1, 3, None]
elif self.sub_dataset in ['STS-B']:
if wo_label:
label_index, text_a_index, text_b_index = [None, -1, -2]
else:
label_index, text_a_index, text_b_index = [-1, -2, -3]
for line in reader:
try:
example = InputExample(
guid=seq_id,
text_a=line[text_a_index],
text_b=line[text_b_index]
if text_b_index is not None else None,
label=line[label_index]
if label_index is not None else None)
seq_id += 1
examples.append(example)
except Exception:
print("[Discard Incorrect Data] " + "\t".join(line))
return examples
if __name__ == "__main__":
ds = GLUE(sub_dataset='SST-2')
for e in ds.get_train_examples()[:3]:
print(e)
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/XNLI-lan.tar.gz"
class XNLI(HubDataset):
"""
Please refer to
https://arxiv.org/pdf/1809.05053.pdf
for more information
"""
def __init__(self, language='zh'):
if language not in [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw",
"th", "tr", "ur", "vi", "zh"
]:
raise Exception(language +
" is not in XNLI. Please check the language code")
self.language = language
self.dataset_dir = os.path.join(DATA_HOME, "XNLI-lan")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, self.language,
self.language + "_train.tsv")
self.train_examples = self._read_tsv(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, self.language,
self.language + "_dev.tsv")
self.dev_examples = self._read_tsv(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, self.language,
self.language + "_test.tsv")
self.test_examples = self._read_tsv(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
"""See base class."""
return ["neutral", "contradiction", "entailment"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
examples = []
seq_id = 0
header = next(reader) # skip header
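# Each XNLI line holds the premise (text_a), the hypothesis (text_b) and the
# gold label, in that column order.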
for line in reader:
example = InputExample(
guid=seq_id, label=line[2], text_a=line[0], text_b=line[1])
seq_id += 1
examples.append(example)
return examples
if __name__ == "__main__":
ds = XNLI()
for e in ds.get_train_examples()[:3]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))