未验证 提交 ad4720ec 编写于 作者: S smallv0221 提交者: GitHub

Update datasets naming style (#5014)

* update lrscheduler

* minor fix

* add pre-commit

* minor fix

* Add __len__ to squad dataset

* minor fix

* Add dureader robust prototype

* dataset implement

* minor fix

* fix var name

* add dureader-yesno train script and dataset

* add readme and fix md5sum

* integrate dureader datasets

* change var names: segment to mode, root to data_file

* minor fix

* update var name
上级 26a0cd1e
......@@ -121,7 +121,7 @@ def do_train(args):
root=root,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
segment="train")
mode="train")
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True)
......@@ -146,7 +146,7 @@ def do_train(args):
root=root,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
segment="dev")
mode="dev")
dev_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False)
......
......@@ -120,7 +120,7 @@ def do_train(args):
version_2_with_negative=args.version_2_with_negative,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
segment="train")
mode="train")
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True)
......@@ -146,7 +146,7 @@ def do_train(args):
version_2_with_negative=args.version_2_with_negative,
max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length,
segment="dev")
mode="dev")
dev_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False)
......
......@@ -37,38 +37,38 @@ class ChnSentiCorp(TSVDataset):
URL = "https://bj.bcebos.com/paddlehub-dataset/chnsenticorp.tar.gz"
MD5 = "fbb3217aeac76a2840d2d5cd19688b07"
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SPLITS = {
'train': META_INFO(
os.path.join('chnsenticorp', 'train.tsv'),
'689360c4a4a9ce8d8719ed500ae80907', (1, 0), 1),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('chnsenticorp', 'dev.tsv'),
'05e4b02561c2a327833e05bbe8156cec', (1, 0), 1),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('chnsenticorp', 'test.tsv'),
'917dfc6fbce596bb01a91abaa6c86f9e', (1, 0), 1)
}
def __init__(self,
segment='train',
mode='train',
root=None,
return_all_fields=False,
**kwargs):
if return_all_fields:
segments = copy.deepcopy(self.__class__.SEGMENTS)
segment_info = list(segments[segment])
segment_info[2] = None
segments[segment] = self.SEGMENT_INFO(*segment_info)
self.SEGMENTS = segments
splits = copy.deepcopy(self.__class__.SPLITS)
mode_info = list(splits[mode])
mode_info[2] = None
splits[mode] = self.META_INFO(*mode_info)
self.SPLITS = splits
self._get_data(root, segment, **kwargs)
self._get_data(root, mode, **kwargs)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = DATA_HOME
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......
......@@ -40,17 +40,17 @@ def get_datasets(cls, *args, **kwargs):
.. code-block:: python
from paddlenlp.datasets import GlueQNLI
train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(['train', 'dev', 'test'])
train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(segment=['train', 'dev', 'test'])
train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(mode=['train', 'dev', 'test'])
train_dataset = GlueQNLI.get_datasets('train')
train_dataset = GlueQNLI.get_datasets(['train'])
train_dataset = GlueQNLI.get_datasets(segment='train')
train_dataset = GlueQNLI.get_datasets(mode='train')
"""
if not args and not kwargs:
try:
args = cls.SEGMENTS.keys()
args = cls.SPLITS.keys()
except:
raise AttributeError(
'Dataset must have SEGMENTS attridute to use get_dataset if configs is None.'
'Dataset must have SPLITS attridute to use get_dataset if configs is None.'
)
datasets = tuple(MapDatasetWrapper(cls(arg)) for arg in args)
......
......@@ -38,23 +38,23 @@ class DuReaderExample(object):
class DuReader(SQuAD):
SEGMENT_INFO = collections.namedtuple('SEGMENT_INFO', ('file', 'md5'))
META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
DATA_URL = 'https://dataset-bj.cdn.bcebos.com/dureader/dureader_preprocessed.zip'
SEGMENTS = {
'train': SEGMENT_INFO(
SPLITS = {
'train': META_INFO(
os.path.join('preprocessed', 'trainset', 'zhidao.train.json'),
None),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('preprocessed', 'devset', 'zhidao.dev.json'), None),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('preprocessed', 'testset', 'zhidao.test.json'), None)
}
def __init__(self,
tokenizer,
segment='train',
mode='train',
root=None,
doc_stride=128,
max_query_length=64,
......@@ -63,17 +63,17 @@ class DuReader(SQuAD):
super(DuReader, self).__init__(
tokenizer=tokenizer,
segment=segment,
mode=mode,
root=root,
doc_stride=doc_stride,
max_query_length=max_query_length,
max_seq_length=max_seq_length,
**kwargs)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'DuReader')
filename, data_hash = self.SEGMENTS[segment]
filename, data_hash = self.SPLITS[mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
......@@ -169,25 +169,25 @@ class DuReader(SQuAD):
class DuReaderRobust(SQuAD):
SEGMENT_INFO = collections.namedtuple('SEGMENT_INFO', ('file', 'md5'))
META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
DATA_URL = 'https://dataset-bj.cdn.bcebos.com/qianyan/dureader_robust-data.tar.gz'
SEGMENTS = {
'train': SEGMENT_INFO(
SPLITS = {
'train': META_INFO(
os.path.join('dureader_robust-data', 'train.json'),
'800a3dcb742f9fdf9b11e0a83433d4be'),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('dureader_robust-data', 'dev.json'),
'ae73cec081eaa28a735204c4898a2222'),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('dureader_robust-data', 'test.json'),
'e0e8aa5c7b6d11b6fc3935e29fc7746f')
}
def __init__(self,
tokenizer,
segment='train',
mode='train',
version_2_with_negative=True,
root=None,
doc_stride=128,
......@@ -197,7 +197,7 @@ class DuReaderRobust(SQuAD):
super(DuReaderRobust, self).__init__(
tokenizer=tokenizer,
segment=segment,
mode=mode,
version_2_with_negative=False,
root=root,
doc_stride=doc_stride,
......@@ -205,10 +205,10 @@ class DuReaderRobust(SQuAD):
max_seq_length=max_seq_length,
**kwargs)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'DuReader')
filename, data_hash = self.SEGMENTS[segment]
filename, data_hash = self.SPLITS[mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
......@@ -226,38 +226,38 @@ class DuReaderRobust(SQuAD):
class DuReaderYesNo(Dataset):
SEGMENT_INFO = collections.namedtuple('SEGMENT_INFO', ('file', 'md5'))
META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
DATA_URL = 'https://dataset-bj.cdn.bcebos.com/qianyan/dureader_yesno-data.tar.gz'
SEGMENTS = {
'train': SEGMENT_INFO(
SPLITS = {
'train': META_INFO(
os.path.join('dureader_yesno-data', 'train.json'),
'c469a0ef3f975cfd705e3553ddb27cc1'),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('dureader_yesno-data', 'dev.json'),
'c38544f8b5a7b567492314e3232057b5'),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('dureader_yesno-data', 'test.json'),
'1c7a1a3ea5b8992eeaeea017fdc2d55f')
}
def __init__(self, segment='train', root=None, **kwargs):
def __init__(self, mode='train', root=None, **kwargs):
self._get_data(root, segment, **kwargs)
self._get_data(root, mode, **kwargs)
self._transform_func = None
if segment == 'train':
if mode == 'train':
self.is_training = True
else:
self.is_training = False
self._read()
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'DuReader')
filename, data_hash = self.SEGMENTS[segment]
filename, data_hash = self.SPLITS[mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
......
......@@ -41,30 +41,30 @@ __all__ = [
class _GlueDataset(TSVDataset):
URL = None
MD5 = None
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {} # mode: file, md5, field_indices, num_discard_samples
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SPLITS = {} # mode: file, md5, field_indices, num_discard_samples
def __init__(self,
segment='train',
mode='train',
root=None,
return_all_fields=False,
**kwargs):
if return_all_fields:
# self.SEGMENTS = copy.deepcopy(self.__class__.SEGMENTS)
# self.SEGMENTS[segment].field_indices = segments
segments = copy.deepcopy(self.__class__.SEGMENTS)
segment_info = list(segments[segment])
segment_info[2] = None
segments[segment] = self.SEGMENT_INFO(*segment_info)
self.SEGMENTS = segments
# self.SPLITS = copy.deepcopy(self.__class__.SPLITS)
# self.SPLITS[mode].field_indices = splits
splits = copy.deepcopy(self.__class__.SPLITS)
mode_info = list(splits[mode])
mode_info[2] = None
splits[mode] = self.META_INFO(*mode_info)
self.SPLITS = splits
self._get_data(root, segment, **kwargs)
self._get_data(root, mode, **kwargs)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'glue')
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......@@ -110,14 +110,14 @@ class GlueCoLA(_GlueDataset):
"""
URL = "https://dataset.bj.bcebos.com/glue/CoLA.zip"
MD5 = 'b178a7c2f397b0433c39c7caf50a3543'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('CoLA', 'train.tsv'),
'c79d4693b8681800338aa044bf9e797b', (3, 1), 0),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('CoLA', 'dev.tsv'), 'c5475ccefc9e7ca0917294b8bbda783c',
(3, 1), 0),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('CoLA', 'test.tsv'),
'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1)
}
......@@ -156,14 +156,14 @@ class GlueSST2(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/SST.zip'
MD5 = '9f81648d4199384278b86e315dac217c'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('SST-2', 'train.tsv'),
'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('SST-2', 'dev.tsv'),
'268856b487b2a31a28c0a93daaff7288', (0, 1), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('SST-2', 'test.tsv'),
'3230e4efec76488b87877a56ae49675a', (1, ), 1)
}
......@@ -209,22 +209,22 @@ class GlueMRPC(_GlueDataset):
TEST_DATA_URL = 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt'
TEST_DATA_MD5 = 'e437fdddb92535b820fe8852e2df8a49'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('MRPC', 'train.tsv'),
'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('MRPC', 'dev.tsv'), '185958e46ba556b38c6a7cc63f3a2135',
(3, 4, 0), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('MRPC', 'test.tsv'),
'4825dab4b4832f81455719660b608de5', (3, 4), 1)
}
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'glue')
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......@@ -280,7 +280,7 @@ class GlueMRPC(_GlueDataset):
test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
(idx, id1, id2, s1, s2))
root = default_root
super(GlueMRPC, self)._get_data(root, segment, **kwargs)
super(GlueMRPC, self)._get_data(root, mode, **kwargs)
def get_labels(self):
"""
......@@ -315,14 +315,14 @@ class GlueSTSB(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/STS.zip'
MD5 = 'd573676be38f1a075a5702b90ceab3de'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('STS-B', 'train.tsv'),
'4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('STS-B', 'dev.tsv'),
'5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('STS-B', 'test.tsv'),
'339b5817e414d19d9bb5f593dd94249c', (7, 8), 1)
}
......@@ -365,19 +365,19 @@ class GlueQQP(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/QQP.zip'
MD5 = '884bf26e39c783d757acc510a2a516ef'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('QQP', 'train.tsv'),
'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('QQP', 'dev.tsv'), 'cff6a448d1580132367c22fc449ec214',
(3, 4, 5), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('QQP', 'test.tsv'), '73de726db186b1b08f071364b2bb96d0',
(1, 2), 1)
}
def __init__(self, segment='train', root=None, return_all_fields=False):
def __init__(self, mode='train', root=None, return_all_fields=False):
# QQP may include broken samples
super(GlueQQP, self).__init__(
segment, root, return_all_fields, allow_missing=True)
......@@ -421,20 +421,20 @@ class GlueMNLI(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/MNLI.zip'
MD5 = 'e343b4bdf53f927436d0792203b9b9ff'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('MNLI', 'train.tsv'),
'220192295e23b6705f3545168272c740', (8, 9, 11), 1),
'dev_matched': _GlueDataset.SEGMENT_INFO(
'dev_matched': _GlueDataset.META_INFO(
os.path.join('MNLI', 'dev_matched.tsv'),
'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1),
'dev_mismatched': _GlueDataset.SEGMENT_INFO(
'dev_mismatched': _GlueDataset.META_INFO(
os.path.join('MNLI', 'dev_mismatched.tsv'),
'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1),
'test_matched': _GlueDataset.SEGMENT_INFO(
'test_matched': _GlueDataset.META_INFO(
os.path.join('MNLI', 'test_matched.tsv'),
'33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1),
'test_mismatched': _GlueDataset.SEGMENT_INFO(
'test_mismatched': _GlueDataset.META_INFO(
os.path.join('MNLI', 'test_mismatched.tsv'),
'7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1),
}
......@@ -481,14 +481,14 @@ class GlueQNLI(_GlueDataset):
"""
URL = 'https://dataset.bj.bcebos.com/glue/QNLI.zip'
MD5 = 'b4efd6554440de1712e9b54e14760e82'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('QNLI', 'train.tsv'),
'5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('QNLI', 'dev.tsv'), '1e81e211959605f144ba6c0ad7dc948b',
(1, 2, 3), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('QNLI', 'test.tsv'),
'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1)
}
......@@ -531,14 +531,14 @@ class GlueRTE(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/RTE.zip'
MD5 = 'bef554d0cafd4ab6743488101c638539'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('RTE', 'train.tsv'),
'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('RTE', 'dev.tsv'), '973cb4178d4534cf745a01c309d4a66c',
(1, 2, 3), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('RTE', 'test.tsv'), '6041008f3f3e48704f57ce1b88ad2e74',
(1, 2), 1)
}
......@@ -582,14 +582,14 @@ class GlueWNLI(_GlueDataset):
URL = 'https://dataset.bj.bcebos.com/glue/WNLI.zip'
MD5 = 'a1b4bd2861017d302d29e42139657a42'
SEGMENTS = {
'train': _GlueDataset.SEGMENT_INFO(
SPLITS = {
'train': _GlueDataset.META_INFO(
os.path.join('WNLI', 'train.tsv'),
'5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1),
'dev': _GlueDataset.SEGMENT_INFO(
'dev': _GlueDataset.META_INFO(
os.path.join('WNLI', 'dev.tsv'), 'a79a6dd5d71287bcad6824c892e517ee',
(1, 2, 3), 1),
'test': _GlueDataset.SEGMENT_INFO(
'test': _GlueDataset.META_INFO(
os.path.join('WNLI', 'test.tsv'),
'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1)
}
......
......@@ -38,38 +38,38 @@ class LCQMC(TSVDataset):
URL = "https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz"
MD5 = "62a7ba36f786a82ae59bbde0b0a9af0c"
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SPLITS = {
'train': META_INFO(
os.path.join('lcqmc', 'train.tsv'),
'2193c022439b038ac12c0ae918b211a1', (0, 1, 2), 1),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('lcqmc', 'dev.tsv'),
'c5dcba253cb4105d914964fd8b3c0e94', (0, 1, 2), 1),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('lcqmc', 'test.tsv'),
'8f4b71e15e67696cc9e112a459ec42bd', (0, 1, 2), 1)
}
def __init__(self,
segment='train',
mode='train',
root=None,
return_all_fields=False,
**kwargs):
if return_all_fields:
segments = copy.deepcopy(self.__class__.SEGMENTS)
segment_info = list(segments[segment])
segment_info[2] = None
segments[segment] = self.SEGMENT_INFO(*segment_info)
self.SEGMENTS = segments
splits = copy.deepcopy(self.__class__.SPLITS)
mode_info = list(splits[mode])
mode_info[2] = None
splits[mode] = self.META_INFO(*mode_info)
self.SPLITS = splits
self._get_data(root, segment, **kwargs)
self._get_data(root, mode, **kwargs)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = DATA_HOME
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......
......@@ -29,30 +29,30 @@ __all__ = ['MSRA_NER']
class MSRA_NER(TSVDataset):
URL = "https://bj.bcebos.com/paddlehub-dataset/msra_ner.tar.gz"
MD5 = None
SEGMENT_INFO = collections.namedtuple(
'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SEGMENTS = {
'train': SEGMENT_INFO(
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SPLITS = {
'train': META_INFO(
os.path.join('msra_ner', 'train.tsv'),
'67d3c93a37daba60ef43c03271f119d7',
(0, 1),
1, ),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('msra_ner', 'dev.tsv'),
'ec772f3ba914bca5269f6e785bb3375d',
(0, 1),
1, ),
'test': SEGMENT_INFO(
'test': META_INFO(
os.path.join('msra_ner', 'test.tsv'),
'2f27ae68b5f61d6553ffa28bb577c8a7',
(0, 1),
1, ),
}
def __init__(self, segment='train', root=None, **kwargs):
def __init__(self, mode='train', root=None, **kwargs):
default_root = os.path.join(DATA_HOME, 'msra')
filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
segment]
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......
......@@ -66,7 +66,7 @@ class InputFeatures(object):
class SQuAD(Dataset):
SEGMENT_INFO = collections.namedtuple('SEGMENT_INFO', ('file', 'md5'))
META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
DEV_DATA_URL_V2 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v2.0.json'
TRAIN_DATA_URL_V2 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v2.0.json'
......@@ -74,20 +74,20 @@ class SQuAD(Dataset):
DEV_DATA_URL_V1 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v1.1.json'
TRAIN_DATA_URL_V1 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v1.1.json'
SEGMENTS = {
SPLITS = {
'1.1': {
'train': SEGMENT_INFO(
'train': META_INFO(
os.path.join('v1', 'train-v1.1.json'),
'981b29407e0affa3b1b156f72073b945'),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('v1', 'dev-v1.1.json'),
'3e85deb501d4e538b6bc56f786231552')
},
'2.0': {
'train': SEGMENT_INFO(
'train': META_INFO(
os.path.join('v2', 'train-v2.0.json'),
'62108c273c268d70893182d5cf8df740'),
'dev': SEGMENT_INFO(
'dev': META_INFO(
os.path.join('v2', 'dev-v2.0.json'),
'246adae8b7002f8679c027697b0b7cf8')
}
......@@ -95,7 +95,7 @@ class SQuAD(Dataset):
def __init__(self,
tokenizer,
segment='train',
mode='train',
version_2_with_negative=True,
root=None,
doc_stride=128,
......@@ -104,7 +104,7 @@ class SQuAD(Dataset):
**kwargs):
self.version_2_with_negative = version_2_with_negative
self._get_data(root, segment, **kwargs)
self._get_data(root, mode, **kwargs)
self.tokenizer = tokenizer
self.doc_stride = doc_stride
self.max_query_length = max_query_length
......@@ -112,7 +112,7 @@ class SQuAD(Dataset):
self._transform_func = None
if segment == 'train':
if mode == 'train':
self.is_training = True
else:
self.is_training = False
......@@ -126,12 +126,12 @@ class SQuAD(Dataset):
max_query_length=self.max_query_length,
max_seq_length=self.max_seq_length)
def _get_data(self, root, segment, **kwargs):
def _get_data(self, root, mode, **kwargs):
default_root = os.path.join(DATA_HOME, 'SQuAD')
if self.version_2_with_negative:
filename, data_hash = self.SEGMENTS['2.0'][segment]
filename, data_hash = self.SPLITS['2.0'][mode]
else:
filename, data_hash = self.SEGMENTS['1.1'][segment]
filename, data_hash = self.SPLITS['1.1'][mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
......@@ -141,7 +141,7 @@ class SQuAD(Dataset):
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename, self.__class__.__name__, default_root))
if segment == 'train':
if mode == 'train':
if self.version_2_with_negative:
fullname = get_path_from_url(
self.TRAIN_DATA_URL_V2,
......@@ -150,7 +150,7 @@ class SQuAD(Dataset):
fullname = get_path_from_url(
self.TRAIN_DATA_URL_V1,
os.path.join(default_root, 'v1'))
elif segment == 'dev':
elif mode == 'dev':
if self.version_2_with_negative:
fullname = get_path_from_url(
self.DEV_DATA_URL_V2, os.path.join(default_root, 'v2'))
......
......@@ -142,7 +142,7 @@ class TranslationDataset(paddle.io.Dataset):
if not os.path.exists(root):
os.makedirs(root)
print("IWSLT will be downloaded at ", root)
get_path_from_url(self.URL, root)
get_path_from_url(cls.URL, root)
print("Downloaded success......")
else:
filename_list = [
......@@ -155,7 +155,7 @@ class TranslationDataset(paddle.io.Dataset):
if not os.path.exists(file_path):
print(
"The dataset is incomplete and will be re-downloaded.")
get_path_from_url(self.URL, root)
get_path_from_url(cls.URL, root)
print("Downloaded success......")
break
return data_dir
......@@ -245,13 +245,13 @@ class IWSLT15(TranslationDataset):
eos_token = '</s>'
dataset_dirname = "iwslt15.en-vi"
def __init__(self, segment='train', root=None, transform_func=None):
def __init__(self, mode='train', root=None, transform_func=None):
# Input check
segment_select = ('train', 'dev', 'test')
if segment not in segment_select:
if mode not in segment_select:
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(segment))
format(mode))
if transform_func is not None:
if len(transform_func) != 2:
raise ValueError("`transform_func` must have length of two for"
......@@ -259,7 +259,7 @@ class IWSLT15(TranslationDataset):
# Download data
data_path = IWSLT15.get_data(root)
dataset = setup_datasets(self.train_filenames, self.valid_filenames,
self.test_filenames, [segment], data_path)[0]
self.test_filenames, [mode], data_path)[0]
self.data = dataset.data
if transform_func is not None:
self.data = [(transform_func[0](data[0]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册