Unverified commit 1a72a903 authored by: K Kaipeng Deng, committed by: GitHub

Add map style dataset (#26004)

* add map_style dataset. test=develop
Parent 644dfd7d
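For context, a map-style dataset is an indexable dataset that only needs `__getitem__` and `__len__`. The snippet below is a minimal illustrative sketch (the `RandomDataset` name and the random data are hypothetical; it assumes only the `paddle.io.Dataset` base class used throughout this change) of how such a dataset is defined and indexed:

    import numpy as np
    from paddle.io import Dataset

    class RandomDataset(Dataset):
        # a toy map-style dataset: every sample is addressable by an index
        def __init__(self, num_samples=16):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            image = np.random.random([784]).astype('float32')
            label = np.random.randint(0, 10, [1]).astype('int64')
            return image, label

        def __len__(self):
            return self.num_samples

    dataset = RandomDataset()
    print(len(dataset))        # 16
    image, label = dataset[0]  # samples are fetched by index
    print(image.shape, label.shape)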
......@@ -15,11 +15,41 @@
from . import folder
from . import mnist
from . import flowers
from . import cifar
from . import voc2012
from . import conll05
from . import imdb
from . import imikolov
from . import movielens
from . import movie_reviews
from . import uci_housing
from . import wmt14
from . import wmt16
from .folder import *
from .mnist import *
from .flowers import *
from .cifar import *
from .voc2012 import *
from .conll05 import *
from .imdb import *
from .imikolov import *
from .movielens import *
from .movie_reviews import *
from .uci_housing import *
from .wmt14 import *
from .wmt16 import *
__all__ = folder.__all__ \
+ mnist.__all__ \
+ flowers.__all__
+ mnist.__all__ \
+ flowers.__all__ \
+ cifar.__all__ \
+ voc2012.__all__ \
+ conll05.__all__ \
+ imdb.__all__ \
+ imikolov.__all__ \
+ movielens.__all__ \
+ movie_reviews.__all__ \
+ uci_housing.__all__ \
+ wmt14.__all__ \
+ wmt16.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import tarfile
import numpy as np
import six
from six.moves import cPickle as pickle
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ['Cifar10', 'Cifar100']
URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
MODE_FLAG_MAP = {
'train10': 'data_batch',
'test10': 'test_batch',
'train100': 'train',
'test100': 'test'
}
class Cifar10(Dataset):
"""
Implementation of `Cifar-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
dataset, which has 10 categories.
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
transform(callable): transform to perform on image, None for no transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of cifar-10 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Cifar10
from paddle.incubate.hapi.vision.transforms import Normalize
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = paddle.nn.Linear(3072, 10, act='softmax')
def forward(self, image, label):
image = paddle.reshape(image, (1, -1))
return self.fc(image), label
paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
cifar10 = Cifar10(mode='train', transform=normalize)
for i in range(10):
image, label = cifar10[i]
image = paddle.to_tensor(image)
label = paddle.to_tensor(label)
model = SimpleNet()
image, label = model(image, label)
print(image.numpy().shape, label.numpy().shape)
"""
def __init__(self,
data_file=None,
mode='train',
transform=None,
download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self._init_url_md5_flag()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, self.data_url, self.data_md5, 'cifar', download)
self.transform = transform
# read dataset into memory
self._load_data()
def _init_url_md5_flag(self):
self.data_url = CIFAR10_URL
self.data_md5 = CIFAR10_MD5
self.flag = MODE_FLAG_MAP[self.mode + '10']
def _load_data(self):
self.data = []
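# each batch file in the tar is a pickled dict: raw uint8 image rows under
# b'data' and class ids under b'labels' (CIFAR-10) or b'fine_labels' (CIFAR-100)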
with tarfile.open(self.data_file, mode='r') as f:
names = (each_item.name for each_item in f
if self.flag in each_item.name)
for name in names:
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(f.extractfile(name), encoding='bytes')
data = batch[six.b('data')]
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
self.data.append((sample, label))
def __getitem__(self, idx):
image, label = self.data[idx]
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.data)
class Cifar100(Cifar10):
"""
Implementation of `Cifar-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
dataset, which has 100 categories.
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
transform(callable): transform to perform on image, None for no transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of cifar-100 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Cifar100
from paddle.incubate.hapi.vision.transforms import Normalize
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = paddle.nn.Linear(3072, 100, act='softmax')
def forward(self, image, label):
image = paddle.reshape(image, (1, -1))
return self.fc(image), label
paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
cifar100 = Cifar100(mode='train', transform=normalize)
for i in range(10):
image, label = cifar100[i]
image = paddle.to_tensor(image)
label = paddle.to_tensor(label)
model = SimpleNet()
image, label = model(image, label)
print(image.numpy().shape, label.numpy().shape)
"""
def __init__(self,
data_file=None,
mode='train',
transform=None,
download=True):
super(Cifar100, self).__init__(data_file, mode, transform, download)
def _init_url_md5_flag(self):
self.data_url = CIFAR100_URL
self.data_md5 = CIFAR100_MD5
self.flag = MODE_FLAG_MAP[self.mode + '100']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gzip
import tarfile
import numpy as np
import six
from six.moves import cPickle as pickle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download
__all__ = ['Conll05st']
DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc'
WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
UNK_IDX = 0
class Conll05st(Dataset):
"""
Implementation of `Conll05st <https://www.cs.upc.edu/~srlconll/soft.html>`_
test dataset.
Note: only the test dataset of Conll05st is public, so only automatic
download of the test dataset is supported.
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
word_dict_file(str): path to word dictionary file, can be set None if
:attr:`download` is True. Default None
verb_dict_file(str): path to verb dictionary file, can be set None if
:attr:`download` is True. Default None
target_dict_file(str): path to target dictionary file, can be set None if
:attr:`download` is True. Default None
emb_file(str): path to embedding dictionary file, only used for
:code:`get_embedding`, can be set None if :attr:`download` is
True. Default None
download(bool): whether to download dataset automatically if
:attr:`data_file`, :attr:`word_dict_file`, :attr:`verb_dict_file`
or :attr:`target_dict_file` is not set. Default True
Returns:
Dataset: instance of conll05st dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Conll05st
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, pred_idx, mark, label):
return paddle.sum(pred_idx), paddle.sum(mark), paddle.sum(label)
paddle.disable_static()
conll05st = Conll05st()
for i in range(10):
pred_idx, mark, label= conll05st[i][-3:]
pred_idx = paddle.to_tensor(pred_idx)
mark = paddle.to_tensor(mark)
label = paddle.to_tensor(label)
model = SimpleNet()
pred_idx, mark, label= model(pred_idx, mark, label)
print(pred_idx.numpy(), mark.numpy(), label.numpy())
"""
def __init__(self,
data_file=None,
word_dict_file=None,
verb_dict_file=None,
target_dict_file=None,
emb_file=None,
download=True):
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, DATA_URL, DATA_MD5, 'conll05st', download)
self.word_dict_file = word_dict_file
if self.word_dict_file is None:
assert download, "word_dict_file is not set and downloading automatically is disabled"
self.word_dict_file = _check_exists_and_download(
word_dict_file, WORDDICT_URL, WORDDICT_MD5, 'conll05st',
download)
self.verb_dict_file = verb_dict_file
if self.verb_dict_file is None:
assert download, "verb_dict_file is not set and downloading automatically is disabled"
self.verb_dict_file = _check_exists_and_download(
verb_dict_file, VERBDICT_URL, VERBDICT_MD5, 'conll05st',
download)
self.target_dict_file = target_dict_file
if self.target_dict_file is None:
assert download, "target_dict_file is not set and downloading automatically is disabled"
self.target_dict_file = _check_exists_and_download(
target_dict_file, TRGDICT_URL, TRGDICT_MD5, 'conll05st',
download)
self.emb_file = emb_file
if self.emb_file is None:
assert download, "emb_file is not set and downloading automatically is disabled"
self.emb_file = _check_exists_and_download(
emb_file, EMB_URL, EMB_MD5, 'conll05st', download)
self.word_dict = self._load_dict(self.word_dict_file)
self.predicate_dict = self._load_dict(self.verb_dict_file)
self.label_dict = self._load_label_dict(self.target_dict_file)
# read dataset into memory
self._load_anno()
def _load_label_dict(self, filename):
d = dict()
tag_dict = set()
with open(filename, 'r') as f:
for i, line in enumerate(f):
line = line.strip()
if line.startswith("B-"):
tag_dict.add(line[2:])
elif line.startswith("I-"):
tag_dict.add(line[2:])
index = 0
for tag in tag_dict:
d["B-" + tag] = index
index += 1
d["I-" + tag] = index
index += 1
d["O"] = index
return d
def _load_dict(self, filename):
d = dict()
with open(filename, 'r') as f:
for i, line in enumerate(f):
d[line.strip()] = i
return d
def _load_anno(self):
tf = tarfile.open(self.data_file)
wf = tf.extractfile(
"conll05st-release/test.wsj/words/test.wsj.words.gz")
pf = tf.extractfile(
"conll05st-release/test.wsj/props/test.wsj.props.gz")
self.sentences = []
self.predicates = []
self.labels = []
with gzip.GzipFile(fileobj=wf) as words_file, gzip.GzipFile(
fileobj=pf) as props_file:
sentences = []
labels = []
one_seg = []
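# each line of the props file holds the predicate lemma (or '-') plus one
# bracketed-label column per predicate; a blank line ends a sentence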
for word, label in zip(words_file, props_file):
word = cpt.to_text(word.strip())
label = cpt.to_text(label.strip().split())
if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])):
a_kind_label = [x[i] for x in one_seg]
labels.append(a_kind_label)
if len(labels) >= 1:
verb_list = []
for x in labels[0]:
if x != '-':
verb_list.append(x)
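# convert bracketed span labels such as '(A0*', '*' and '*)' into a
# B-/I-/O tag sequence for each predicate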
for i, lbl in enumerate(labels[1:]):
cur_tag = 'O'
is_in_bracket = False
lbl_seq = []
verb_word = ''
for l in lbl:
if l == '*' and is_in_bracket == False:
lbl_seq.append('O')
elif l == '*' and is_in_bracket == True:
lbl_seq.append('I-' + cur_tag)
elif l == '*)':
lbl_seq.append('I-' + cur_tag)
is_in_bracket = False
elif l.find('(') != -1 and l.find(')') != -1:
cur_tag = l[1:l.find('*')]
lbl_seq.append('B-' + cur_tag)
is_in_bracket = False
elif l.find('(') != -1 and l.find(')') == -1:
cur_tag = l[1:l.find('*')]
lbl_seq.append('B-' + cur_tag)
is_in_bracket = True
else:
raise RuntimeError('Unexpected label: %s' %
l)
self.sentences.append(sentences)
self.predicates.append(verb_list[i])
self.labels.append(lbl_seq)
sentences = []
labels = []
one_seg = []
else:
sentences.append(word)
one_seg.append(label)
pf.close()
wf.close()
tf.close()
def __getitem__(self, idx):
sentence = self.sentences[idx]
predicate = self.predicates[idx]
labels = self.labels[idx]
sen_len = len(sentence)
verb_index = labels.index('B-V')
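# build context features around the predicate ('B-V') position: the two
# words before and after it, plus a 0/1 mark for this five-word window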
mark = [0] * len(labels)
if verb_index > 0:
mark[verb_index - 1] = 1
ctx_n1 = sentence[verb_index - 1]
else:
ctx_n1 = 'bos'
if verb_index > 1:
mark[verb_index - 2] = 1
ctx_n2 = sentence[verb_index - 2]
else:
ctx_n2 = 'bos'
mark[verb_index] = 1
ctx_0 = sentence[verb_index]
if verb_index < len(labels) - 1:
mark[verb_index + 1] = 1
ctx_p1 = sentence[verb_index + 1]
else:
ctx_p1 = 'eos'
if verb_index < len(labels) - 2:
mark[verb_index + 2] = 1
ctx_p2 = sentence[verb_index + 2]
else:
ctx_p2 = 'eos'
word_idx = [self.word_dict.get(w, UNK_IDX) for w in sentence]
ctx_n2_idx = [self.word_dict.get(ctx_n2, UNK_IDX)] * sen_len
ctx_n1_idx = [self.word_dict.get(ctx_n1, UNK_IDX)] * sen_len
ctx_0_idx = [self.word_dict.get(ctx_0, UNK_IDX)] * sen_len
ctx_p1_idx = [self.word_dict.get(ctx_p1, UNK_IDX)] * sen_len
ctx_p2_idx = [self.word_dict.get(ctx_p2, UNK_IDX)] * sen_len
pred_idx = [self.predicate_dict.get(predicate)] * sen_len
label_idx = [self.label_dict.get(w) for w in labels]
return (np.array(word_idx), np.array(ctx_n2_idx), np.array(ctx_n1_idx),
np.array(ctx_0_idx), np.array(ctx_p1_idx), np.array(ctx_p2_idx),
np.array(pred_idx), np.array(mark), np.array(label_idx))
def __len__(self):
return len(self.sentences)
def get_dict(self):
"""
Get the word, verb and label dictionaries of the CoNLL-2005 dataset.
"""
return self.word_dict, self.predicate_dict, self.label_dict
def get_embedding(self):
return self.emb_file
......@@ -36,12 +36,13 @@ SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But there is more test data than
# train data, so we exchange the train data and test data.
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': "valid"}
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}
class Flowers(Dataset):
"""
Implement of flowers dataset
Implementation of `Flowers <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
dataset
Args:
data_file(str): path to data file, can be set None if
......@@ -51,9 +52,9 @@ class Flowers(Dataset):
setid_file(str): path to subset index file, can be set
None if :attr:`download` is True. Default None
mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
True
transform(callable): transform to perform on image, None for no transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Examples:
......@@ -82,19 +83,19 @@ class Flowers(Dataset):
self.data_file = data_file
if self.data_file is None:
assert download, "data_file not set and auto download disabled"
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, DATA_URL, DATA_MD5, 'flowers', download)
self.label_file = label_file
if self.label_file is None:
assert download, "label_file not set and auto download disabled"
assert download, "label_file is not set and downloading automatically is disabled"
self.label_file = _check_exists_and_download(
label_file, LABEL_URL, LABEL_MD5, 'flowers', download)
self.setid_file = setid_file
if self.setid_file is None:
assert download, "setid_file not set and auto download disabled"
assert download, "setid_file is not set and downloading automatically is disabled"
self.setid_file = _check_exists_and_download(
setid_file, SETID_URL, SETID_MD5, 'flowers', download)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import re
import six
import string
import tarfile
import numpy as np
import collections
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ['Imdb']
URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
class Imdb(Dataset):
"""
Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
cutoff(int): cutoff number for building word dictionary. Default 150.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of IMDB dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Imdb
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, doc, label):
return paddle.sum(doc), label
paddle.disable_static()
imdb = Imdb(mode='train')
for i in range(10):
doc, label = imdb[i]
doc = paddle.to_tensor(doc)
label = paddle.to_tensor(label)
model = SimpleNet()
doc, label = model(doc, label)
print(doc.numpy().shape, label.numpy().shape)
"""
def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(data_file, URL, MD5,
'imdb', download)
# Build a word dictionary from the corpus
self.word_idx = self._build_word_dict(cutoff)
# read dataset into memory
self._load_anno()
def _build_word_dict(self, cutoff):
word_freq = collections.defaultdict(int)
pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
for doc in self._tokenize(pattern):
for word in doc:
word_freq[word] += 1
# Not sure if we should prune less-frequent words here.
word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words)
return word_idx
def _tokenize(self, pattern):
data = []
with tarfile.open(self.data_file) as tarf:
tf = tarf.next()
while tf is not None:
if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization.
data.append(
tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
.translate(None, six.b(string.punctuation)).lower(
).split())
tf = tarf.next()
return data
def _load_anno(self):
pos_pattern = re.compile(r"aclImdb/{}/pos/.*\.txt$".format(self.mode))
neg_pattern = re.compile(r"aclImdb/{}/neg/.*\.txt$".format(self.mode))
UNK = self.word_idx['<unk>']
self.docs = []
self.labels = []
for doc in self._tokenize(pos_pattern):
self.docs.append([self.word_idx.get(w, UNK) for w in doc])
self.labels.append(0)
for doc in self._tokenize(neg_pattern):
self.docs.append([self.word_idx.get(w, UNK) for w in doc])
self.labels.append(1)
def __getitem__(self, idx):
return (np.array(self.docs[idx]), np.array([self.labels[idx]]))
def __len__(self):
return len(self.docs)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import six
import tarfile
import numpy as np
import collections
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ['Imikolov']
URL = 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
class Imikolov(Dataset):
"""
Implementation of imikolov dataset.
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
data_type(str): 'NGRAM' or 'SEQ'. Default 'NGRAM'.
window_size(int): sliding window size for 'NGRAM' data. Default -1.
mode(str): 'train' or 'test' mode. Default 'train'.
min_word_freq(int): minimum word frequency for building word dictionary. Default 50.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of imikolov dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Imikolov
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src, trg):
return paddle.sum(src), paddle.sum(trg)
paddle.disable_static()
imikolov = Imikolov(mode='train', data_type='SEQ', window_size=2)
for i in range(10):
src, trg = imikolov[i]
src = paddle.to_tensor(src)
trg = paddle.to_tensor(trg)
model = SimpleNet()
src, trg = model(src, trg)
print(src.numpy().shape, trg.numpy().shape)
"""
def __init__(self,
data_file=None,
data_type='NGRAM',
window_size=-1,
mode='train',
min_word_freq=50,
download=True):
assert data_type.upper() in ['NGRAM', 'SEQ'], \
"data type should be 'NGRAM' or 'SEQ', but got {}".format(data_type)
self.data_type = data_type.upper()
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.window_size = window_size
self.min_word_freq = min_word_freq
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically disabled"
self.data_file = _check_exists_and_download(data_file, URL, MD5,
'imikolov', download)
# Build a word dictionary from the corpus
self.word_idx = self._build_word_dict(min_word_freq)
# read dataset into memory
self._load_anno()
def word_count(self, f, word_freq=None):
if word_freq is None:
word_freq = collections.defaultdict(int)
for l in f:
for w in l.strip().split():
word_freq[w] += 1
word_freq['<s>'] += 1
word_freq['<e>'] += 1
return word_freq
def _build_word_dict(self, cutoff):
train_filename = './simple-examples/data/ptb.train.txt'
test_filename = './simple-examples/data/ptb.valid.txt'
with tarfile.open(self.data_file) as tf:
trainf = tf.extractfile(train_filename)
testf = tf.extractfile(test_filename)
word_freq = self.word_count(testf, self.word_count(trainf))
if '<unk>' in word_freq:
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
word_freq = [
x for x in six.iteritems(word_freq) if x[1] > self.min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words)
return word_idx
def _load_anno(self):
self.data = []
with tarfile.open(self.data_file) as tf:
filename = './simple-examples/data/ptb.{}.txt'.format(self.mode)
f = tf.extractfile(filename)
UNK = self.word_idx['<unk>']
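# 'NGRAM' mode yields fixed-length windows of word ids over each line;
# 'SEQ' mode yields (source, target) id sequences shifted by one token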
for l in f:
if self.data_type == 'NGRAM':
assert self.window_size > -1, 'Invalid gram length'
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= self.window_size:
l = [self.word_idx.get(w, UNK) for w in l]
for i in six.moves.range(self.window_size, len(l) + 1):
self.data.append(tuple(l[i - self.window_size:i]))
elif self.data_type == 'SEQ':
l = l.strip().split()
l = [self.word_idx.get(w, UNK) for w in l]
src_seq = [self.word_idx['<s>']] + l
trg_seq = l + [self.word_idx['<e>']]
if self.window_size > 0 and len(src_seq) > self.window_size:
continue
self.data.append((src_seq, trg_seq))
else:
assert False, 'Unknown data type'
def __getitem__(self, idx):
return tuple([np.array(d) for d in self.data[idx]])
def __len__(self):
return len(self.data)
......@@ -38,7 +38,7 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
class MNIST(Dataset):
"""
Implement of MNIST dataset
Implementation of `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset
Args:
image_path(str): path to image file, can be set None if
......@@ -48,9 +48,8 @@ class MNIST(Dataset):
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether auto download mnist dataset if
:attr:`image_path`/:attr:`label_path` unset. Default
True
download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True
Returns:
Dataset: MNIST Dataset.
......@@ -82,7 +81,7 @@ class MNIST(Dataset):
self.chw_format = chw_format
self.image_path = image_path
if self.image_path is None:
assert download, "image_path not set and auto download disabled"
assert download, "image_path is not set and downloading automatically is disabled"
image_url = TRAIN_IMAGE_URL if mode == 'train' else TEST_IMAGE_URL
image_md5 = TRAIN_IMAGE_MD5 if mode == 'train' else TEST_IMAGE_MD5
self.image_path = _check_exists_and_download(
......@@ -90,9 +89,9 @@ class MNIST(Dataset):
self.label_path = label_path
if self.label_path is None:
assert download, "label_path not set and auto download disabled"
label_url = TRAIN_LABEL_URL if mode == 'train' else TEST_LABEL_URL
label_md5 = TRAIN_LABEL_MD5 if mode == 'train' else TEST_LABEL_MD5
assert download, "label_path is not set and downloading automatically is disabled"
label_url = TRAIN_LABEL_URL if self.mode == 'train' else TEST_LABEL_URL
label_md5 = TRAIN_LABEL_MD5 if self.mode == 'train' else TEST_LABEL_MD5
self.label_path = _check_exists_and_download(
label_path, label_url, label_md5, 'mnist', download)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import six
import numpy as np
import collections
import nltk
from nltk.corpus import movie_reviews
import zipfile
from functools import cmp_to_key
from itertools import chain
import paddle
from paddle.io import Dataset
__all__ = ['MovieReviews']
URL = "https://corpora.bj.bcebos.com/movie_reviews%2Fmovie_reviews.zip"
MD5 = '155de2b77c6834dd8eea7cbe88e93acb'
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000
class MovieReviews(Dataset):
"""
Implementation of `NLTK movie reviews <http://www.nltk.org/nltk_data/>`_ dataset.
Args:
mode(str): 'train' or 'test' mode. Default 'train'.
Returns:
Dataset: instance of movie reviews dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import MovieReviews
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, word, category):
return paddle.sum(word), category
paddle.disable_static()
movie_reviews = MovieReviews(mode='train')
for i in range(10):
word_list, category = movie_reviews[i]
word_list = paddle.to_tensor(word_list)
category = paddle.to_tensor(category)
model = SimpleNet()
word_list, category = model(word_list, category)
print(word_list.numpy().shape, category.numpy())
"""
def __init__(self, mode='train'):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self._download_data_if_not_yet()
# read dataset into memory
self._load_sentiment_data()
def _get_word_dict(self):
"""
Sort words by the frequency with which they occur in the samples
:return:
words_freq_sorted
"""
words_freq_sorted = list()
word_freq_dict = collections.defaultdict(int)
for category in movie_reviews.categories():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = list(six.iteritems(word_freq_dict))
words_sort_list.sort(key=cmp_to_key(lambda a, b: b[1] - a[1]))
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index))
return words_freq_sorted
def _sort_files(self):
"""
Sort the sample files so that positive and negative samples are read alternately
:return:
files_list
"""
files_list = list()
neg_file_list = movie_reviews.fileids('neg')
pos_file_list = movie_reviews.fileids('pos')
files_list = list(
chain.from_iterable(list(zip(neg_file_list, pos_file_list))))
return files_list
def _load_sentiment_data(self):
"""
Load the data set
:return:
data_set
"""
self.data = []
words_ids = dict(self._get_word_dict())
for sample_file in self._sort_files():
words_list = list()
category = 0 if 'neg' in sample_file else 1
for word in movie_reviews.words(sample_file):
words_list.append(words_ids[word.lower()])
self.data.append((words_list, category))
def _download_data_if_not_yet(self):
"""
Download the dataset if it has not been downloaded yet.
"""
try:
# download and extract movie_reviews.zip
paddle.dataset.common.download(
URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip')
path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora')
filename = os.path.join(path, 'movie_reviews.zip')
zip_file = zipfile.ZipFile(filename)
zip_file.extractall(path)
zip_file.close()
# make sure that nltk can find the data
if paddle.dataset.common.DATA_HOME not in nltk.data.path:
nltk.data.path.append(paddle.dataset.common.DATA_HOME)
movie_reviews.categories()
except LookupError:
print("Downloading movie_reviews data set, please wait.....")
nltk.download(
'movie_reviews', download_dir=paddle.dataset.common.DATA_HOME)
print("Download data set success.....")
print("Path is " + nltk.data.find('corpora/movie_reviews').path)
def __getitem__(self, idx):
if self.mode == 'test':
idx += NUM_TRAINING_INSTANCES
data = self.data[idx]
return np.array(data[0]), np.array(data[1])
def __len__(self):
if self.mode == 'train':
return NUM_TRAINING_INSTANCES
else:
return NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import zipfile
import re
import random
import functools
import six
import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download
__all__ = ['Movielens']
age_table = [1, 18, 25, 35, 45, 50, 56]
URL = 'https://dataset.bj.bcebos.com/movielens%2Fml-1m.zip'
MD5 = 'c4d9eecfca2ab87c1945afe126590906'
class MovieInfo(object):
"""
Movie id, title and categories information are stored in MovieInfo.
"""
def __init__(self, index, categories, title):
self.index = int(index)
self.categories = categories
self.title = title
def value(self, categories_dict, movie_title_dict):
"""
Get information from a movie.
"""
return [[self.index], [categories_dict[c] for c in self.categories],
[movie_title_dict[w.lower()] for w in self.title.split()]]
def __str__(self):
return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
self.index, self.title, self.categories)
def __repr__(self):
return self.__str__()
class UserInfo(object):
"""
User id, gender, age, and job information are stored in UserInfo.
"""
def __init__(self, index, gender, age, job_id):
self.index = int(index)
self.is_male = gender == 'M'
self.age = age_table.index(int(age))
self.job_id = int(job_id)
def value(self):
"""
Get information from a user.
"""
return [[self.index], [0 if self.is_male else 1], [self.age],
[self.job_id]]
def __str__(self):
return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
self.index, "M"
if self.is_male else "F", age_table[self.age], self.job_id)
def __repr__(self):
return str(self)
class Movielens(Dataset):
"""
Implementation of `Movielens 1-M <https://grouplens.org/datasets/movielens/1m/>`_ dataset.
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
test_ratio(float): split ratio for test sample. Default 0.1.
rand_seed(int): random seed. Default 0.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of Movielens 1-M dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import Movielens
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, category, title, rating):
return paddle.sum(category), paddle.sum(title), paddle.sum(rating)
paddle.disable_static()
movielens = Movielens(mode='train')
for i in range(10):
category, title, rating = movielens[i][-3:]
category = paddle.to_tensor(category)
title = paddle.to_tensor(title)
rating = paddle.to_tensor(rating)
model = SimpleNet()
category, title, rating = model(category, title, rating)
print(category.numpy().shape, title.numpy().shape, rating.numpy().shape)
"""
def __init__(self,
data_file=None,
mode='train',
test_ratio=0.1,
rand_seed=0,
download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(data_file, URL, MD5,
'sentiment', download)
self.test_ratio = test_ratio
self.rand_seed = rand_seed
np.random.seed(rand_seed)
self._load_meta_info()
self._load_data()
def _load_meta_info(self):
pattern = re.compile(r'^(.*)\((\d+)\)$')
self.movie_info = dict()
self.movie_title_dict = dict()
self.categories_dict = dict()
self.user_info = dict()
with zipfile.ZipFile(self.data_file) as package:
for info in package.infolist():
assert isinstance(info, zipfile.ZipInfo)
title_word_set = set()
categories_set = set()
with package.open('ml-1m/movies.dat') as movie_file:
for i, line in enumerate(movie_file):
line = cpt.to_text(line, encoding='latin')
movie_id, title, categories = line.strip().split('::')
categories = categories.split('|')
for c in categories:
categories_set.add(c)
title = pattern.match(title).group(1)
self.movie_info[int(movie_id)] = MovieInfo(
index=movie_id, categories=categories, title=title)
for w in title.split():
title_word_set.add(w.lower())
for i, w in enumerate(title_word_set):
self.movie_title_dict[w] = i
for i, c in enumerate(categories_set):
self.categories_dict[c] = i
with package.open('ml-1m/users.dat') as user_file:
for line in user_file:
line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::")
self.user_info[int(uid)] = UserInfo(
index=uid, gender=gender, age=age, job_id=job)
def _load_data(self):
self.data = []
is_test = self.mode == 'test'
with zipfile.ZipFile(self.data_file) as package:
with package.open('ml-1m/ratings.dat') as rating:
for line in rating:
line = cpt.to_text(line, encoding='latin')
if (np.random.random() < self.test_ratio) == is_test:
uid, mov_id, rating, _ = line.strip().split("::")
uid = int(uid)
mov_id = int(mov_id)
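# rescale the original 1-5 star rating to the range [-3, 5]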
rating = float(rating) * 2 - 5.0
mov = self.movie_info[mov_id]
usr = self.user_info[uid]
self.data.append(usr.value() + \
mov.value(self.categories_dict, self.movie_title_dict) + \
[[rating]])
def __getitem__(self, idx):
data = self.data[idx]
return tuple([np.array(d) for d in data])
def __len__(self):
return len(self.data)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import six
import numpy as np
import paddle.dataset.common
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ["UCIHousing"]
URL = 'http://paddlemodels.bj.bcebos.com/uci_housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'
feature_names = [
'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
'PTRATIO', 'B', 'LSTAT'
]
class UCIHousing(Dataset):
"""
Implementation of `UCI housing <https://archive.ics.uci.edu/ml/datasets/Housing>`_
dataset
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of UCI housing dataset.
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import UCIHousing
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, feature, target):
return paddle.sum(feature), target
paddle.disable_static()
uci_housing = UCIHousing(mode='train')
for i in range(10):
feature, target = uci_housing[i]
feature = paddle.to_tensor(feature)
target = paddle.to_tensor(target)
model = SimpleNet()
feature, target = model(feature, target)
print(feature.numpy().shape, target.numpy())
"""
def __init__(self, data_file=None, mode='train', download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(data_file, URL, MD5,
'uci_housing', download)
# read dataset into memory
self._load_data()
def _load_data(self, feature_num=14, ratio=0.8):
data = np.fromfile(self.data_file, sep=' ')
data = data.reshape(data.shape[0] // feature_num, feature_num)
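# normalize each input feature by subtracting its column mean and dividing
# by its value range, then split 80% of the rows for train and 20% for test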
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
axis=0) / data.shape[0]
for i in six.moves.range(feature_num - 1):
data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
offset = int(data.shape[0] * ratio)
if self.mode == 'train':
self.data = data[:offset]
elif self.mode == 'test':
self.data = data[offset:]
def __getitem__(self, idx):
data = self.data[idx]
return np.array(data[:-1]), np.array(data[-1:])
def __len__(self):
return len(self.data)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import io
import tarfile
import numpy as np
from PIL import Image
from paddle.io import Dataset
from .utils import _check_exists_and_download
__all__ = ["VOC2012"]
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
VOCtrainval_11-May-2012.tar'
VOC_MD5 = '131da710f39b47a43fdfa256cbc11976'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'
CACHE_DIR = 'voc2012'
MODE_FLAG_MAP = {'train': 'trainval', 'test': 'train', 'valid': "val"}
class VOC2012(Dataset):
"""
Implementation of `VOC2012 <http://host.robots.ox.ac.uk/pascal/VOC/voc2012/>`_ dataset
Args:
data_file(str): path to data file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
transform(callable): transform to perform on image, None for no transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import VOC2012
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, image, label):
return paddle.sum(image), label
paddle.disable_static()
voc2012 = VOC2012(mode='train')
for i in range(10):
image, label= voc2012[i]
image = paddle.cast(paddle.to_tensor(image), 'float32')
label = paddle.to_tensor(label)
model = SimpleNet()
image, label= model(image, label)
print(image.numpy().shape, label.numpy().shape)
"""
def __init__(self,
data_file=None,
mode='train',
transform=None,
download=True):
assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode)
self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, VOC_URL, VOC_MD5, CACHE_DIR, download)
self.transform = transform
# read dataset into memory
self._load_anno()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
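# keep the tar file open and index its members by name so image and label
# files can be read lazily in __getitem__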
for ele in self.data_tar.getmembers():
self.name2mem[ele.name] = ele
set_file = SET_FILE.format(self.flag)
sets = self.data_tar.extractfile(self.name2mem[set_file])
self.data = []
self.labels = []
for line in sets:
line = line.strip()
data = DATA_FILE.format(line.decode('utf-8'))
label = LABEL_FILE.format(line.decode('utf-8'))
self.data.append(data)
self.labels.append(label)
def __getitem__(self, idx):
data_file = self.data[idx]
label_file = self.labels[idx]
data = self.data_tar.extractfile(self.name2mem[data_file]).read()
label = self.data_tar.extractfile(self.name2mem[label_file]).read()
data = Image.open(io.BytesIO(data))
label = Image.open(io.BytesIO(label))
data = np.array(data)
label = np.array(label)
if self.transform is not None:
data = self.transform(data)
return data, label
def __len__(self):
return len(self.data)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import six
import tarfile
import numpy as np
import gzip
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download
__all__ = ['WMT14']
URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz')
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# This is a small dataset for testing. The original data is too large and
# will be added later.
URL_TRAIN = ('http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
START = "<s>"
END = "<e>"
UNK = "<unk>"
UNK_IDX = 2
class WMT14(Dataset):
"""
Implementation of `WMT14 <http://www.statmt.org/wmt14/>`_ test dataset.
The original WMT14 dataset is too large, so a small subset is provided.
This module will download the dataset from
http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'test' or 'gen'. Default 'train'
dict_size(int): word dictionary size. Default -1.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT14 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import WMT14
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src_ids, trg_ids, trg_ids_next):
return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
paddle.disable_static()
wmt14 = WMT14(mode='train', dict_size=50)
for i in range(10):
src_ids, trg_ids, trg_ids_next = wmt14[i]
src_ids = paddle.to_tensor(src_ids)
trg_ids = paddle.to_tensor(trg_ids)
trg_ids_next = paddle.to_tensor(trg_ids_next)
model = SimpleNet()
src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())
"""
def __init__(self,
data_file=None,
mode='train',
dict_size=-1,
download=True):
assert mode.lower() in ['train', 'test', 'gen'], \
"mode should be 'train', 'test' or 'gen', but got {}".format(mode)
self.mode = mode.lower()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, URL_TRAIN, MD5_TRAIN, 'wmt14', download)
# read dataset into memory
assert dict_size > 0, "dict_size should be set as a positive number"
self.dict_size = dict_size
self._load_data()
def _load_data(self):
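# build the source/target word dicts from the first dict_size lines of the
# *.src.dict / *.trg.dict files in the tar, then convert each parallel
# sentence pair into id sequences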
def __to_dict(fd, size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
out_dict[cpt.to_text(line.strip())] = line_count
else:
break
return out_dict
self.src_ids = []
self.trg_ids = []
self.trg_ids_next = []
with tarfile.open(self.data_file, mode='r') as f:
names = [
each_item.name for each_item in f
if each_item.name.endswith("src.dict")
]
assert len(names) == 1
self.src_dict = __to_dict(f.extractfile(names[0]), self.dict_size)
names = [
each_item.name for each_item in f
if each_item.name.endswith("trg.dict")
]
assert len(names) == 1
self.trg_dict = __to_dict(f.extractfile(names[0]), self.dict_size)
file_name = "{}/{}".format(self.mode, self.mode)
names = [
each_item.name for each_item in f
if each_item.name.endswith(file_name)
]
for name in names:
for line in f.extractfile(name):
line = cpt.to_text(line)
line_split = line.strip().split('\t')
if len(line_split) != 2:
continue
src_seq = line_split[0] # one source sequence
src_words = src_seq.split()
src_ids = [
self.src_dict.get(w, UNK_IDX)
for w in [START] + src_words + [END]
]
trg_seq = line_split[1] # one target sequence
trg_words = trg_seq.split()
trg_ids = [self.trg_dict.get(w, UNK_IDX) for w in trg_words]
# remove sequence pairs whose length > 80
if len(src_ids) > 80 or len(trg_ids) > 80:
continue
trg_ids_next = trg_ids + [self.trg_dict[END]]
trg_ids = [self.trg_dict[START]] + trg_ids
self.src_ids.append(src_ids)
self.trg_ids.append(trg_ids)
self.trg_ids_next.append(trg_ids_next)
def __getitem__(self, idx):
return (np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]),
np.array(self.trg_ids_next[idx]))
def __len__(self):
return len(self.src_ids)
def get_dict(self, reverse=False):
src_dict, trg_dict = self.src_dict, self.trg_dict
if reverse:
src_dict = {v: k for k, v in six.iteritems(src_dict)}
trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
return src_dict, trg_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
from __future__ import print_function
import os
import six
import tarfile
import numpy as np
from collections import defaultdict
import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download
__all__ = ['WMT16']
DATA_URL = ("http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz")
DATA_MD5 = "0c38be43600334966403524a40dcd81e"
TOTAL_EN_WORDS = 11250
TOTAL_DE_WORDS = 19220
START_MARK = "<s>"
END_MARK = "<e>"
UNK_MARK = "<unk>"
class WMT16(Dataset):
"""
Implementation of `WMT16 <http://www.statmt.org/wmt16/>`_ test dataset.
ACL2016 Multimodal Machine Translation. Please see this website for more
details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
.. code-block:: text
@article{elliott-EtAl:2016:VL16,
author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
title = {Multi30K: Multilingual English-German Image Descriptions},
booktitle = {Proceedings of the 6th Workshop on Vision and Language},
year = {2016},
pages = {70--74}
}
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'test' or 'val'. Default 'train'
src_dict_size(int): word dictionary size for source language word. Default -1.
trg_dict_size(int): word dictionary size for target language word. Default -1.
lang(str): source language, 'en' or 'de'. Default 'en'.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT16 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import WMT16
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src_ids, trg_ids, trg_ids_next):
return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
paddle.disable_static()
wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)
for i in range(10):
src_ids, trg_ids, trg_ids_next = wmt16[i]
src_ids = paddle.to_tensor(src_ids)
trg_ids = paddle.to_tensor(trg_ids)
trg_ids_next = paddle.to_tensor(trg_ids_next)
model = SimpleNet()
src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())
"""
def __init__(self,
data_file=None,
mode='train',
src_dict_size=-1,
trg_dict_size=-1,
lang='en',
download=True):
assert mode.lower() in ['train', 'test', 'val'], \
"mode should be 'train', 'test' or 'val', but got {}".format(mode)
self.mode = mode.lower()
self.data_file = data_file
if self.data_file is None:
assert download, "data_file is not set and downloading automatically is disabled"
self.data_file = _check_exists_and_download(
data_file, DATA_URL, DATA_MD5, 'wmt16', download)
self.lang = lang
assert src_dict_size > 0, "src_dict_size should be set as a positive number"
assert trg_dict_size > 0, "trg_dict_size should be set as a positive number"
self.src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if lang == "en"
else TOTAL_DE_WORDS))
self.trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if lang == "en"
else TOTAL_EN_WORDS))
# load source and target word dict
self.src_dict = self._load_dict(lang, src_dict_size)
self.trg_dict = self._load_dict("de" if lang == "en" else "en",
trg_dict_size)
# load data
self.data = self._load_data()
def _load_dict(self, lang, dict_size, reverse=False):
dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or (
len(open(dict_path, "rb").readlines()) != dict_size):
self._build_dict(dict_path, dict_size, lang)
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = cpt.to_text(line.strip())
else:
word_dict[cpt.to_text(line.strip())] = idx
return word_dict
def _build_dict(self, dict_path, dict_size, lang):
word_dict = defaultdict(int)
with tarfile.open(self.data_file, mode="r") as f:
for line in f.extractfile("wmt16/train"):
line = cpt.to_text(line)
line_split = line.strip().split("\t")
if len(line_split) != 2: continue
sen = line_split[0] if self.lang == "en" else line_split[1]
for w in sen.split():
word_dict[w] += 1
with open(dict_path, "wb") as fout:
fout.write(
cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)))
for idx, word in enumerate(
sorted(
six.iteritems(word_dict),
key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
fout.write(cpt.to_bytes(word[0]))
fout.write(cpt.to_bytes('\n'))
def _load_data(self):
# the index for start mark, end mark, and unk are the same in source
# language and target language. Here uses the source language
# dictionary to determine their indices.
start_id = self.src_dict[START_MARK]
end_id = self.src_dict[END_MARK]
unk_id = self.src_dict[UNK_MARK]
src_col = 0 if self.lang == "en" else 1
trg_col = 1 - src_col
self.src_ids = []
self.trg_ids = []
self.trg_ids_next = []
with tarfile.open(self.data_file, mode="r") as f:
for line in f.extractfile("wmt16/{}".format(self.mode)):
line = cpt.to_text(line)
line_split = line.strip().split("\t")
if len(line_split) != 2:
continue
src_words = line_split[src_col].split()
src_ids = [start_id] + [
self.src_dict.get(w, unk_id) for w in src_words
] + [end_id]
trg_words = line_split[trg_col].split()
trg_ids = [self.trg_dict.get(w, unk_id) for w in trg_words]
trg_ids_next = trg_ids + [end_id]
trg_ids = [start_id] + trg_ids
self.src_ids.append(src_ids)
self.trg_ids.append(trg_ids)
self.trg_ids_next.append(trg_ids_next)
def __getitem__(self, idx):
return (np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]),
np.array(self.trg_ids_next[idx]))
def __len__(self):
return len(self.src_ids)
def get_dict(self, lang, reverse=False):
"""
return the word dictionary for the specified language.
Args:
lang(string): A string indicating which language is the source
language. Available options are: "en" for English
and "de" for Germany.
reverse(bool): If reverse is set to False, the returned python
dictionary will use word as key and use index as value.
If reverse is set to True, the returned python
dictionary will use index as key and word as value.
Returns:
dict: The word dictionary for the specific language.
"""
dict_size = self.src_dict_size if lang == self.lang else self.trg_dict_size
dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
assert os.path.exists(dict_path), (
"Word dictionary does not exist. "
"Please invoke paddle.dataset.wmt16.train/test/validation first "
"to build the dictionary.")
return self._load_dict(lang, dict_size, reverse)
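# Illustrative usage sketch (not part of this change), assuming the WMT16
# data file is available locally or can be downloaded:
#
#     wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50, lang='en')
#     src_ids, trg_ids, trg_ids_next = wmt16[0]
#     en_dict = wmt16.get_dict('en')                   # word -> index
#     id_to_word = wmt16.get_dict('en', reverse=True)  # index -> word
#     print(len(wmt16), src_ids.shape, len(en_dict))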
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestCifar10Train(unittest.TestCase):
def test_main(self):
cifar = Cifar10(mode='train')
self.assertTrue(len(cifar) == 50000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(0 <= int(label) <= 9)
class TestCifar10Test(unittest.TestCase):
def test_main(self):
cifar = Cifar10(mode='test')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(0 <= int(label) <= 9)
class TestCifar100Train(unittest.TestCase):
def test_main(self):
cifar = Cifar100(mode='train')
self.assertTrue(len(cifar) == 50000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(0 <= int(label) <= 99)
class TestCifar100Test(unittest.TestCase):
def test_main(self):
cifar = Cifar100(mode='test')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(0 <= int(label) <= 99)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestConll05st(unittest.TestCase):
def test_main(self):
conll05st = Conll05st()
self.assertTrue(len(conll05st) == 5267)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 5267)
sample = conll05st[idx]
self.assertTrue(len(sample) == 9)
for s in sample:
self.assertTrue(len(s.shape) == 1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestImdbTrain(unittest.TestCase):
def test_main(self):
imdb = Imdb(mode='train')
self.assertTrue(len(imdb) == 25000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 25000)
data, label = imdb[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(int(label) in [0, 1])
class TestImdbTest(unittest.TestCase):
def test_main(self):
imdb = Imdb(mode='test')
self.assertTrue(len(imdb) == 25000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 25000)
data, label = imdb[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(int(label) in [0, 1])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestImikolovTrain(unittest.TestCase):
def test_main(self):
imikolov = Imikolov(mode='train', data_type='NGRAM', window_size=2)
self.assertTrue(len(imikolov) == 929589)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 929589)
data = imikolov[idx]
self.assertTrue(len(data) == 2)
class TestImikolovTest(unittest.TestCase):
def test_main(self):
imikolov = Imikolov(mode='test', data_type='NGRAM', window_size=2)
self.assertTrue(len(imikolov) == 82430)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 82430)
data = imikolov[idx]
self.assertTrue(len(data) == 2)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestMovieReviewsTrain(unittest.TestCase):
def test_main(self):
movie_reviews = MovieReviews(mode='train')
self.assertTrue(len(movie_reviews) == 1600)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1600)
data = movie_reviews[idx]
self.assertTrue(len(data) == 2)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(int(data[1]) in [0, 1])
class TestMovieReviewsTest(unittest.TestCase):
def test_main(self):
movie_reviews = MovieReviews(mode='test')
self.assertTrue(len(movie_reviews) == 400)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 400)
data = movie_reviews[idx]
self.assertTrue(len(data) == 2)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(int(data[1]) in [0, 1])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestMovielensTrain(unittest.TestCase):
def test_main(self):
movielens = Movielens(mode='train')
# movielens dataset random split train/test
# not check dataset length here
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 900000)
data = movielens[idx]
self.assertTrue(len(data) == 8)
for i, d in enumerate(data):
self.assertTrue(len(d.shape) == 1)
if i not in [5, 6]:
self.assertTrue(d.shape[0] == 1)
class TestMovielensTest(unittest.TestCase):
def test_main(self):
movielens = Movielens(mode='test')
# movielens dataset random split train/test
# not check dataset length here
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 100000)
data = movielens[idx]
self.assertTrue(len(data) == 8)
for i, d in enumerate(data):
self.assertTrue(len(d.shape) == 1)
if i not in [5, 6]:
self.assertTrue(d.shape[0] == 1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestUCIHousingTrain(unittest.TestCase):
def test_main(self):
uci_housing = UCIHousing(mode='train')
self.assertTrue(len(uci_housing) == 404)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 404)
data = uci_housing[idx]
self.assertTrue(len(data) == 2)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(data[0].shape[0] == 13)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(data[1].shape[0] == 1)
class TestUCIHousingTest(unittest.TestCase):
def test_main(self):
uci_housing = UCIHousing(mode='test')
self.assertTrue(len(uci_housing) == 102)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 102)
data = uci_housing[idx]
self.assertTrue(len(data) == 2)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(data[0].shape[0] == 13)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(data[1].shape[0] == 1)
class TestWMT14Train(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='train', dict_size=50)
self.assertTrue(len(wmt14) == 191155)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 191155)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT14Test(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='test', dict_size=50)
self.assertTrue(len(wmt14) == 5957)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 5957)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT14Gen(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='gen', dict_size=50)
self.assertTrue(len(wmt14) == 3001)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 3001)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import voc2012, VOC2012
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
# VOC2012 is too large for unittest to download, stub a small dataset here
voc2012.VOC_URL = 'https://paddlemodels.bj.bcebos.com/voc2012_stub/VOCtrainval_11-May-2012.tar'
voc2012.VOC_MD5 = '34cb1fe5bdc139a5454b25b16118fff8'
class TestVOC2012Train(unittest.TestCase):
def test_main(self):
voc2012 = VOC2012(mode='train')
self.assertTrue(len(voc2012) == 3)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 3)
image, label = voc2012[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
class TestVOC2012Valid(unittest.TestCase):
def test_main(self):
voc2012 = VOC2012(mode='valid')
self.assertTrue(len(voc2012) == 1)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1)
image, label = voc2012[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
class TestVOC2012Test(unittest.TestCase):
def test_main(self):
voc2012 = VOC2012(mode='test')
self.assertTrue(len(voc2012) == 2)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 2)
image, label = voc2012[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2
from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
class TestWMT14Train(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='train', dict_size=50)
self.assertTrue(len(wmt14) == 191155)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 191155)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT14Test(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='test', dict_size=50)
self.assertTrue(len(wmt14) == 5957)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 5957)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT14Gen(unittest.TestCase):
def test_main(self):
wmt14 = WMT14(mode='gen', dict_size=50)
self.assertTrue(len(wmt14) == 3001)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 3001)
data = wmt14[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT16Train(unittest.TestCase):
def test_main(self):
wmt16 = WMT16(
mode='train', src_dict_size=50, trg_dict_size=50, lang='en')
self.assertTrue(len(wmt16) == 29000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 29000)
data = wmt16[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT16Test(unittest.TestCase):
def test_main(self):
wmt16 = WMT16(
mode='test', src_dict_size=50, trg_dict_size=50, lang='en')
self.assertTrue(len(wmt16) == 1000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1000)
data = wmt16[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
class TestWMT16Val(unittest.TestCase):
def test_main(self):
wmt16 = WMT16(mode='val', src_dict_size=50, trg_dict_size=50, lang='en')
self.assertTrue(len(wmt16) == 1014)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1014)
data = wmt16[idx]
self.assertTrue(len(data) == 3)
self.assertTrue(len(data[0].shape) == 1)
self.assertTrue(len(data[1].shape) == 1)
self.assertTrue(len(data[2].shape) == 1)
if __name__ == '__main__':
unittest.main()