Unverified commit 1a72a903, authored Aug 21, 2020 by Kaipeng Deng, committed by GitHub on Aug 21, 2020.
Add map style dataset (#26004)
* add map_style dataset. test=develop
Parent: 644dfd7d
Showing 22 changed files with 2564 additions and 18 deletions (+2564, −18).
Changed files:

- python/paddle/incubate/hapi/datasets/__init__.py (+32, −2)
- python/paddle/incubate/hapi/datasets/cifar.py (+207, −0)
- python/paddle/incubate/hapi/datasets/conll05.py (+297, −0)
- python/paddle/incubate/hapi/datasets/flowers.py (+9, −8)
- python/paddle/incubate/hapi/datasets/imdb.py (+144, −0)
- python/paddle/incubate/hapi/datasets/imikolov.py (+171, −0)
- python/paddle/incubate/hapi/datasets/mnist.py (+7, −8)
- python/paddle/incubate/hapi/datasets/movie_reviews.py (+173, −0)
- python/paddle/incubate/hapi/datasets/movielens.py (+219, −0)
- python/paddle/incubate/hapi/datasets/uci_housing.py (+110, −0)
- python/paddle/incubate/hapi/datasets/voc2012.py (+133, −0)
- python/paddle/incubate/hapi/datasets/wmt14.py (+179, −0)
- python/paddle/incubate/hapi/datasets/wmt16.py (+244, −0)
- python/paddle/incubate/hapi/tests/test_dataset_cifar.py (+83, −0)
- python/paddle/incubate/hapi/tests/test_dataset_conll05.py (+41, −0)
- python/paddle/incubate/hapi/tests/test_dataset_imdb.py (+55, −0)
- python/paddle/incubate/hapi/tests/test_dataset_imikolov.py (+51, −0)
- python/paddle/incubate/hapi/tests/test_dataset_movie_reviews.py (+55, −0)
- python/paddle/incubate/hapi/tests/test_dataset_movielens.py (+61, −0)
- python/paddle/incubate/hapi/tests/test_dataset_uci_housing.py (+104, −0)
- python/paddle/incubate/hapi/tests/test_dataset_voc.py (+70, −0)
- python/paddle/incubate/hapi/tests/test_dataset_wmt.py (+119, −0)
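Every dataset added by this commit follows the same map-style protocol: a `paddle.io.Dataset` subclass that loads its samples into memory and then exposes random access through `__getitem__` and a size through `__len__`. As a minimal sketch of that protocol (the `RandomDataset` class below is illustrative only, not part of the commit):

```python
import numpy as np
from paddle.io import Dataset


class RandomDataset(Dataset):
    """Minimal map-style dataset: random features and labels, indexable by idx."""

    def __init__(self, num_samples=100):
        # materialize all samples up front, as the datasets in this commit do
        self.samples = [(np.random.rand(3072).astype('float32'),
                         np.random.randint(0, 10)) for _ in range(num_samples)]

    def __getitem__(self, idx):
        # random access by integer index is what makes the dataset "map style"
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)
```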
python/paddle/incubate/hapi/datasets/__init__.py

```diff
@@ -15,11 +15,41 @@
 from . import folder
 from . import mnist
 from . import flowers
+from . import cifar
+from . import voc2012
+from . import conll05
+from . import imdb
+from . import imikolov
+from . import movielens
+from . import movie_reviews
+from . import uci_housing
+from . import wmt14
+from . import wmt16

 from .folder import *
 from .mnist import *
 from .flowers import *
+from .cifar import *
+from .voc2012 import *
+from .conll05 import *
+from .imdb import *
+from .imikolov import *
+from .movielens import *
+from .movie_reviews import *
+from .uci_housing import *
+from .wmt14 import *
+from .wmt16 import *

 __all__ = folder.__all__ \
-          + mnist.__all__ \
-          + flowers.__all__
+        + mnist.__all__ \
+        + flowers.__all__ \
+        + cifar.__all__ \
+        + voc2012.__all__ \
+        + conll05.__all__ \
+        + imdb.__all__ \
+        + imikolov.__all__ \
+        + movielens.__all__ \
+        + movie_reviews.__all__ \
+        + uci_housing.__all__ \
+        + wmt14.__all__ \
+        + wmt16.__all__
```
python/paddle/incubate/hapi/datasets/cifar.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import tarfile
import numpy as np
import six
from six.moves import cPickle as pickle

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Cifar10', 'Cifar100']

URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'

MODE_FLAG_MAP = {
    'train10': 'data_batch',
    'test10': 'test_batch',
    'train100': 'train',
    'test100': 'test'
}


class Cifar10(Dataset):
    """
    Implementation of `Cifar-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 10 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of cifar-10 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Cifar10
            from paddle.incubate.hapi.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = paddle.nn.Linear(3072, 10, act='softmax')

                def forward(self, image, label):
                    image = paddle.reshape(image, (3, -1))
                    return self.fc(image), label

            paddle.disable_static()

            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
            cifar10 = Cifar10(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar10[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()
        self._init_url_md5_flag()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, self.data_url, self.data_md5, 'cifar', download)

        self.transform = transform

        # read dataset into memory
        self._load_data()

    def _init_url_md5_flag(self):
        self.data_url = CIFAR10_URL
        self.data_md5 = CIFAR10_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '10']

    def _load_data(self):
        self.data = []
        with tarfile.open(self.data_file, mode='r') as f:
            names = (each_item.name for each_item in f
                     if self.flag in each_item.name)

            for name in names:
                if six.PY2:
                    batch = pickle.load(f.extractfile(name))
                else:
                    batch = pickle.load(f.extractfile(name), encoding='bytes')

                data = batch[six.b('data')]
                labels = batch.get(
                    six.b('labels'), batch.get(six.b('fine_labels'), None))
                assert labels is not None
                for sample, label in six.moves.zip(data, labels):
                    self.data.append((sample, label))

    def __getitem__(self, idx):
        image, label = self.data[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.data)


class Cifar100(Cifar10):
    """
    Implementation of `Cifar-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 100 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of cifar-100 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Cifar100
            from paddle.incubate.hapi.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = paddle.nn.Linear(3072, 100, act='softmax')

                def forward(self, image, label):
                    image = paddle.reshape(image, (3, -1))
                    return self.fc(image), label

            paddle.disable_static()

            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
            cifar100 = Cifar100(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar100[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        super(Cifar100, self).__init__(data_file, mode, transform, download)

    def _init_url_md5_flag(self):
        self.data_url = CIFAR100_URL
        self.data_md5 = CIFAR100_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '100']
```
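Note how `_init_url_md5_flag` composes the `MODE_FLAG_MAP` key from the mode plus a class suffix ('10' or '100'), so `Cifar10(mode='test')` reads the 'test_batch' members of the tar while `Cifar100(mode='test')` reads the 'test' member. A hedged usage sketch of feeding the dataset to a loader, assuming the `paddle.io.DataLoader` API of the 2.0 beta (argument names and defaults changed across versions):

```python
import paddle
from paddle.io import DataLoader
from paddle.incubate.hapi.datasets import Cifar10

paddle.disable_static()
cifar10 = Cifar10(mode='test')  # selects tar members whose names contain 'test_batch'
loader = DataLoader(cifar10, batch_size=32, shuffle=True, return_list=True)
for images, labels in loader:
    print(images.shape, labels.shape)
    break
```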
python/paddle/incubate/hapi/datasets/conll05.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import gzip
import tarfile
import numpy as np
import six
from six.moves import cPickle as pickle

from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['Conll05st']

DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc'
WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'

UNK_IDX = 0


class Conll05st(Dataset):
    """
    Implementation of `Conll05st <https://www.cs.upc.edu/~srlconll/soft.html>`_
    test dataset.

    Note: only the test dataset can be downloaded automatically, since only
    the test split of Conll05st is public.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        word_dict_file(str): path to word dictionary file, can be set None if
            :attr:`download` is True. Default None
        verb_dict_file(str): path to verb dictionary file, can be set None if
            :attr:`download` is True. Default None
        target_dict_file(str): path to target dictionary file, can be set None if
            :attr:`download` is True. Default None
        emb_file(str): path to embedding dictionary file, only used for
            :code:`get_embedding`, can be set None if :attr:`download` is
            True. Default None
        download(bool): whether to download dataset automatically if
            :attr:`data_file`, :attr:`word_dict_file`, :attr:`verb_dict_file` or
            :attr:`target_dict_file` is not set. Default True

    Returns:
        Dataset: instance of conll05st dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Conll05st

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, pred_idx, mark, label):
                    return paddle.sum(pred_idx), paddle.sum(mark), paddle.sum(label)

            paddle.disable_static()

            conll05st = Conll05st()

            for i in range(10):
                pred_idx, mark, label = conll05st[i][-3:]
                pred_idx = paddle.to_tensor(pred_idx)
                mark = paddle.to_tensor(mark)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                pred_idx, mark, label = model(pred_idx, mark, label)
                print(pred_idx.numpy(), mark.numpy(), label.numpy())

    """

    def __init__(self,
                 data_file=None,
                 word_dict_file=None,
                 verb_dict_file=None,
                 target_dict_file=None,
                 emb_file=None,
                 download=True):
        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, DATA_URL, DATA_MD5, 'conll05st', download)

        self.word_dict_file = word_dict_file
        if self.word_dict_file is None:
            assert download, "word_dict_file is not set and downloading automatically is disabled"
            self.word_dict_file = _check_exists_and_download(
                word_dict_file, WORDDICT_URL, WORDDICT_MD5, 'conll05st',
                download)

        self.verb_dict_file = verb_dict_file
        if self.verb_dict_file is None:
            assert download, "verb_dict_file is not set and downloading automatically is disabled"
            self.verb_dict_file = _check_exists_and_download(
                verb_dict_file, VERBDICT_URL, VERBDICT_MD5, 'conll05st',
                download)

        self.target_dict_file = target_dict_file
        if self.target_dict_file is None:
            assert download, "target_dict_file is not set and downloading automatically is disabled"
            self.target_dict_file = _check_exists_and_download(
                target_dict_file, TRGDICT_URL, TRGDICT_MD5, 'conll05st',
                download)

        # store emb_file so get_embedding() can return it
        self.emb_file = emb_file
        if self.emb_file is None:
            assert download, "emb_file is not set and downloading automatically is disabled"
            self.emb_file = _check_exists_and_download(
                emb_file, EMB_URL, EMB_MD5, 'conll05st', download)

        self.word_dict = self._load_dict(self.word_dict_file)
        self.predicate_dict = self._load_dict(self.verb_dict_file)
        self.label_dict = self._load_label_dict(self.target_dict_file)

        # read dataset into memory
        self._load_anno()

    def _load_label_dict(self, filename):
        d = dict()
        tag_dict = set()
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if line.startswith("B-"):
                    tag_dict.add(line[2:])
                elif line.startswith("I-"):
                    tag_dict.add(line[2:])
            index = 0
            for tag in tag_dict:
                d["B-" + tag] = index
                index += 1
                d["I-" + tag] = index
                index += 1
            d["O"] = index
        return d

    def _load_dict(self, filename):
        d = dict()
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                d[line.strip()] = i
        return d

    def _load_anno(self):
        tf = tarfile.open(self.data_file)
        wf = tf.extractfile(
            "conll05st-release/test.wsj/words/test.wsj.words.gz")
        pf = tf.extractfile(
            "conll05st-release/test.wsj/props/test.wsj.props.gz")
        self.sentences = []
        self.predicates = []
        self.labels = []
        with gzip.GzipFile(fileobj=wf) as words_file, gzip.GzipFile(
                fileobj=pf) as props_file:
            sentences = []
            labels = []
            one_seg = []
            for word, label in zip(words_file, props_file):
                word = cpt.to_text(word.strip())
                label = cpt.to_text(label.strip().split())

                if len(label) == 0:  # end of sentence
                    for i in range(len(one_seg[0])):
                        a_kind_lable = [x[i] for x in one_seg]
                        labels.append(a_kind_lable)

                    if len(labels) >= 1:
                        verb_list = []
                        for x in labels[0]:
                            if x != '-':
                                verb_list.append(x)

                        for i, lbl in enumerate(labels[1:]):
                            cur_tag = 'O'
                            is_in_bracket = False
                            lbl_seq = []
                            verb_word = ''
                            for l in lbl:
                                if l == '*' and is_in_bracket == False:
                                    lbl_seq.append('O')
                                elif l == '*' and is_in_bracket == True:
                                    lbl_seq.append('I-' + cur_tag)
                                elif l == '*)':
                                    lbl_seq.append('I-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') != -1:
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = False
                                elif l.find('(') != -1 and l.find(')') == -1:
                                    cur_tag = l[1:l.find('*')]
                                    lbl_seq.append('B-' + cur_tag)
                                    is_in_bracket = True
                                else:
                                    raise RuntimeError('Unexpected label: %s' %
                                                       l)

                            self.sentences.append(sentences)
                            self.predicates.append(verb_list[i])
                            self.labels.append(lbl_seq)

                    sentences = []
                    labels = []
                    one_seg = []
                else:
                    sentences.append(word)
                    one_seg.append(label)

        pf.close()
        wf.close()
        tf.close()

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        predicate = self.predicates[idx]
        labels = self.labels[idx]

        sen_len = len(sentence)

        verb_index = labels.index('B-V')
        mark = [0] * len(labels)
        if verb_index > 0:
            mark[verb_index - 1] = 1
            ctx_n1 = sentence[verb_index - 1]
        else:
            ctx_n1 = 'bos'

        if verb_index > 1:
            mark[verb_index - 2] = 1
            ctx_n2 = sentence[verb_index - 2]
        else:
            ctx_n2 = 'bos'

        mark[verb_index] = 1
        ctx_0 = sentence[verb_index]

        if verb_index < len(labels) - 1:
            mark[verb_index + 1] = 1
            ctx_p1 = sentence[verb_index + 1]
        else:
            ctx_p1 = 'eos'

        if verb_index < len(labels) - 2:
            mark[verb_index + 2] = 1
            ctx_p2 = sentence[verb_index + 2]
        else:
            ctx_p2 = 'eos'

        word_idx = [self.word_dict.get(w, UNK_IDX) for w in sentence]

        ctx_n2_idx = [self.word_dict.get(ctx_n2, UNK_IDX)] * sen_len
        ctx_n1_idx = [self.word_dict.get(ctx_n1, UNK_IDX)] * sen_len
        ctx_0_idx = [self.word_dict.get(ctx_0, UNK_IDX)] * sen_len
        ctx_p1_idx = [self.word_dict.get(ctx_p1, UNK_IDX)] * sen_len
        ctx_p2_idx = [self.word_dict.get(ctx_p2, UNK_IDX)] * sen_len

        pred_idx = [self.predicate_dict.get(predicate)] * sen_len
        label_idx = [self.label_dict.get(w) for w in labels]

        return (np.array(word_idx), np.array(ctx_n2_idx), np.array(ctx_n1_idx),
                np.array(ctx_0_idx), np.array(ctx_p1_idx),
                np.array(ctx_p2_idx), np.array(pred_idx), np.array(mark),
                np.array(label_idx))

    def __len__(self):
        return len(self.sentences)

    def get_dict(self):
        """
        Get the word, verb and label dictionaries of the CoNLL-05 corpus.
        """
        return self.word_dict, self.predicate_dict, self.label_dict

    def get_embedding(self):
        return self.emb_file
```
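The heart of `_load_anno` is the conversion of CoNLL-05 bracket notation into BIO tags: `(TAG*` opens a span, `*)` closes it, and a bare `*` is inside or outside a span depending on state. The following standalone sketch mirrors that state machine in isolation (illustrative only, not part of the commit):

```python
def props_to_bio(column):
    """Convert one CoNLL-05 props column to BIO tags (mirrors _load_anno)."""
    cur_tag, in_bracket, out = 'O', False, []
    for l in column:
        if l == '*':
            out.append('I-' + cur_tag if in_bracket else 'O')
        elif l == '*)':                       # close the open span
            out.append('I-' + cur_tag)
            in_bracket = False
        elif '(' in l and ')' in l:           # one-token span, e.g. (V*)
            cur_tag = l[1:l.find('*')]
            out.append('B-' + cur_tag)
            in_bracket = False
        elif '(' in l:                        # multi-token span opens
            cur_tag = l[1:l.find('*')]
            out.append('B-' + cur_tag)
            in_bracket = True
        else:
            raise RuntimeError('Unexpected label: %s' % l)
    return out


print(props_to_bio(['(A0*', '*', '*)', '(V*)', '*']))
# ['B-A0', 'I-A0', 'I-A0', 'B-V', 'O']
```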
python/paddle/incubate/hapi/datasets/flowers.py

```diff
@@ -36,12 +36,13 @@ SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
 # In official 'readme', tstid is the flag of test data
 # and trnid is the flag of train data. But test data is more than train data.
 # So we exchange the train data and test data.
-MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': "valid"}
+MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}

 class Flowers(Dataset):
     """
-    Implement of flowers dataset
+    Implementation of `Flowers <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
+    dataset

     Args:
         data_file(str): path to data file, can be set None if
@@ -51,9 +52,9 @@ class Flowers(Dataset):
         setid_file(str): path to subset index file, can be set
             None if :attr:`download` is True. Default None
         mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
-        download(bool): whether auto download mnist dataset if
-            :attr:`image_path`/:attr:`label_path` unset. Default
-            True
+        transform(callable): transform to perform on image, None for no transform.
+        download(bool): whether to download dataset automatically if
+            :attr:`data_file` is not set. Default True

     Examples:
@@ -82,19 +83,19 @@ class Flowers(Dataset):
         self.data_file = data_file
         if self.data_file is None:
-            assert download, "data_file not set and auto download disabled"
+            assert download, "data_file is not set and downloading automatically is disabled"
             self.data_file = _check_exists_and_download(
                 data_file, DATA_URL, DATA_MD5, 'flowers', download)

         self.label_file = label_file
         if self.label_file is None:
-            assert download, "label_file not set and auto download disabled"
+            assert download, "label_file is not set and downloading automatically is disabled"
             self.label_file = _check_exists_and_download(
                 label_file, LABEL_URL, LABEL_MD5, 'flowers', download)

         self.setid_file = setid_file
         if self.setid_file is None:
-            assert download, "setid_file not set and auto download disabled"
+            assert download, "setid_file is not set and downloading automatically is disabled"
             self.setid_file = _check_exists_and_download(
                 setid_file, SETID_URL, SETID_MD5, 'flowers', download)
```
python/paddle/incubate/hapi/datasets/imdb.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import re
import six
import string
import tarfile
import numpy as np
import collections

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Imdb']

URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'


class Imdb(Dataset):
    """
    Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        cutoff(int): cutoff number for building word dictionary. Default 150.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of IMDB dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Imdb

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, doc, label):
                    return paddle.sum(doc), label

            paddle.disable_static()

            imdb = Imdb(mode='train')

            for i in range(10):
                doc, label = imdb[i]
                doc = paddle.to_tensor(doc)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                doc, label = model(doc, label)
                print(doc.numpy().shape, label.numpy().shape)

    """

    def __init__(self, data_file=None, mode='train', cutoff=150,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'imdb', download)

        # Build a word dictionary from the corpus
        self.word_idx = self._build_work_dict(cutoff)

        # read dataset into memory
        self._load_anno()

    def _build_work_dict(self, cutoff):
        word_freq = collections.defaultdict(int)
        pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
        for doc in self._tokenize(pattern):
            for word in doc:
                word_freq[word] += 1

        # Not sure if we should prune less-frequent words here.
        word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
        dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
        words, _ = list(zip(*dictionary))
        word_idx = dict(list(zip(words, six.moves.range(len(words)))))
        word_idx['<unk>'] = len(words)
        return word_idx

    def _tokenize(self, pattern):
        data = []
        with tarfile.open(self.data_file) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # newline and punctuations removal and ad-hoc tokenization.
                    data.append(
                        tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                        .translate(None, six.b(string.punctuation)).lower()
                        .split())
                tf = tarf.next()
        return data

    def _load_anno(self):
        pos_pattern = re.compile(r"aclImdb/{}/pos/.*\.txt$".format(self.mode))
        neg_pattern = re.compile(r"aclImdb/{}/neg/.*\.txt$".format(self.mode))

        UNK = self.word_idx['<unk>']

        self.docs = []
        self.labels = []
        for doc in self._tokenize(pos_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(0)
        for doc in self._tokenize(neg_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(1)

    def __getitem__(self, idx):
        return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

    def __len__(self):
        return len(self.docs)
```
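In `_build_work_dict`, only words occurring strictly more than `cutoff` times enter the vocabulary; entries are ranked by descending frequency with alphabetical tie-breaking, and `<unk>` takes the last index. An illustrative rerun of that logic on toy data (not part of the commit):

```python
import collections

import six

docs = [['great', 'movie'], ['bad', 'movie'], ['great', 'great']]
cutoff = 1

word_freq = collections.defaultdict(int)
for doc in docs:
    for w in doc:
        word_freq[w] += 1

# keep words with frequency > cutoff, rank by (-frequency, word)
kept = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
ranked = sorted(kept, key=lambda x: (-x[1], x[0]))
word_idx = {w: i for i, (w, _) in enumerate(ranked)}
word_idx['<unk>'] = len(word_idx)
print(word_idx)  # {'great': 0, 'movie': 1, '<unk>': 2} -- 'bad' pruned
```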
python/paddle/incubate/hapi/datasets/imikolov.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import tarfile
import numpy as np
import collections

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Imikolov']

URL = 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


class Imikolov(Dataset):
    """
    Implementation of imikolov dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        data_type(str): 'NGRAM' or 'SEQ'. Default 'NGRAM'.
        window_size(int): sliding window size for 'NGRAM' data. Default -1.
        mode(str): 'train' or 'test' mode. Default 'train'.
        min_word_freq(int): minimum word frequency for building the word
            dictionary. Default 50.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of imikolov dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Imikolov

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, src, trg):
                    return paddle.sum(src), paddle.sum(trg)

            paddle.disable_static()

            imikolov = Imikolov(mode='train', data_type='SEQ', window_size=2)

            for i in range(10):
                src, trg = imikolov[i]
                src = paddle.to_tensor(src)
                trg = paddle.to_tensor(trg)

                model = SimpleNet()
                src, trg = model(src, trg)
                print(src.numpy().shape, trg.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 data_type='NGRAM',
                 window_size=-1,
                 mode='train',
                 min_word_freq=50,
                 download=True):
        assert data_type.upper() in ['NGRAM', 'SEQ'], \
            "data type should be 'NGRAM' or 'SEQ', but got {}".format(data_type)
        self.data_type = data_type.upper()

        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.window_size = window_size
        self.min_word_freq = min_word_freq

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'imikolov', download)

        # Build a word dictionary from the corpus
        self.word_idx = self._build_work_dict(min_word_freq)

        # read dataset into memory
        self._load_anno()

    def word_count(self, f, word_freq=None):
        if word_freq is None:
            word_freq = collections.defaultdict(int)

        for l in f:
            for w in l.strip().split():
                word_freq[w] += 1
            word_freq['<s>'] += 1
            word_freq['<e>'] += 1

        return word_freq

    def _build_work_dict(self, cutoff):
        train_filename = './simple-examples/data/ptb.train.txt'
        test_filename = './simple-examples/data/ptb.valid.txt'
        with tarfile.open(self.data_file) as tf:
            trainf = tf.extractfile(train_filename)
            testf = tf.extractfile(test_filename)
            word_freq = self.word_count(testf, self.word_count(trainf))
            if '<unk>' in word_freq:
                # remove <unk> for now, since we will set it as last index
                del word_freq['<unk>']

            word_freq = [
                x for x in six.iteritems(word_freq)
                if x[1] > self.min_word_freq
            ]

            word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
            words, _ = list(zip(*word_freq_sorted))
            word_idx = dict(list(zip(words, six.moves.range(len(words)))))
            word_idx['<unk>'] = len(words)

        return word_idx

    def _load_anno(self):
        self.data = []
        with tarfile.open(self.data_file) as tf:
            filename = './simple-examples/data/ptb.{}.txt'.format(self.mode)
            f = tf.extractfile(filename)

            UNK = self.word_idx['<unk>']
            for l in f:
                if self.data_type == 'NGRAM':
                    assert self.window_size > -1, 'Invalid gram length'
                    l = ['<s>'] + l.strip().split() + ['<e>']
                    if len(l) >= self.window_size:
                        l = [self.word_idx.get(w, UNK) for w in l]
                        for i in six.moves.range(self.window_size,
                                                 len(l) + 1):
                            self.data.append(tuple(l[i - self.window_size:i]))
                elif self.data_type == 'SEQ':
                    l = l.strip().split()
                    l = [self.word_idx.get(w, UNK) for w in l]
                    src_seq = [self.word_idx['<s>']] + l
                    trg_seq = l + [self.word_idx['<e>']]
                    if self.window_size > 0 and len(
                            src_seq) > self.window_size:
                        continue
                    self.data.append((src_seq, trg_seq))
                else:
                    assert False, 'Unknown data type'

    def __getitem__(self, idx):
        return tuple([np.array(d) for d in self.data[idx]])

    def __len__(self):
        return len(self.data)
```
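In NGRAM mode, `_load_anno` pads each sentence with `<s>`/`<e>` markers and cuts it into sliding windows of `window_size` token ids; in SEQ mode, it emits source/target pairs shifted by one. A toy rerun of the NGRAM windowing (illustrative only, not part of the commit):

```python
# sentence already padded with the start/end markers, as _load_anno does
sentence = ['<s>', 'the', 'cat', 'sat', '<e>']
window_size = 3

# every contiguous window of window_size tokens becomes one sample
windows = [tuple(sentence[i - window_size:i])
           for i in range(window_size, len(sentence) + 1)]
print(windows)
# [('<s>', 'the', 'cat'), ('the', 'cat', 'sat'), ('cat', 'sat', '<e>')]
```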
python/paddle/incubate/hapi/datasets/mnist.py

```diff
@@ -38,7 +38,7 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
 class MNIST(Dataset):
     """
-    Implement of MNIST dataset
+    Implementation of `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset

     Args:
         image_path(str): path to image file, can be set None if
@@ -48,9 +48,8 @@ class MNIST(Dataset):
         chw_format(bool): If set True, the output shape is [1, 28, 28],
             otherwise, output shape is [1, 784]. Default True.
         mode(str): 'train' or 'test' mode. Default 'train'.
-        download(bool): whether auto download mnist dataset if
-            :attr:`image_path`/:attr:`label_path` unset. Default
-            True
+        download(bool): whether to download dataset automatically if
+            :attr:`image_path` :attr:`label_path` is not set. Default True

     Returns:
         Dataset: MNIST Dataset.
@@ -82,7 +81,7 @@ class MNIST(Dataset):
         self.chw_format = chw_format
         self.image_path = image_path
         if self.image_path is None:
-            assert download, "image_path not set and auto download disabled"
+            assert download, "image_path is not set and downloading automatically is disabled"
             image_url = TRAIN_IMAGE_URL if mode == 'train' else TEST_IMAGE_URL
             image_md5 = TRAIN_IMAGE_MD5 if mode == 'train' else TEST_IMAGE_MD5
             self.image_path = _check_exists_and_download(
@@ -90,9 +89,9 @@ class MNIST(Dataset):
         self.label_path = label_path
         if self.label_path is None:
-            assert download, "label_path not set and auto download disabled"
-            label_url = TRAIN_LABEL_URL if mode == 'train' else TEST_LABEL_URL
-            label_md5 = TRAIN_LABEL_MD5 if mode == 'train' else TEST_LABEL_MD5
+            assert download, "label_path is not set and downloading automatically is disabled"
+            label_url = TRAIN_LABEL_URL if self.mode == 'train' else TEST_LABEL_URL
+            label_md5 = TRAIN_LABEL_MD5 if self.mode == 'train' else TEST_LABEL_MD5
             self.label_path = _check_exists_and_download(
                 label_path, label_url, label_md5, 'mnist', download)
```
python/paddle/incubate/hapi/datasets/movie_reviews.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import six
import numpy as np
import collections
import nltk
from nltk.corpus import movie_reviews
import zipfile
from functools import cmp_to_key
from itertools import chain

import paddle
import paddle.dataset.common
from paddle.io import Dataset

__all__ = ['MovieReviews']

URL = "https://corpora.bj.bcebos.com/movie_reviews%2Fmovie_reviews.zip"
MD5 = '155de2b77c6834dd8eea7cbe88e93acb'

NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000


class MovieReviews(Dataset):
    """
    Implementation of `NLTK movie reviews <http://www.nltk.org/nltk_data/>`_ dataset.

    Args:
        mode(str): 'train' or 'test' mode. Default 'train'.

    Returns:
        Dataset: instance of movie reviews dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import MovieReviews

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, word, category):
                    return paddle.sum(word), category

            paddle.disable_static()

            movie_reviews = MovieReviews(mode='train')

            for i in range(10):
                word_list, category = movie_reviews[i]
                word_list = paddle.to_tensor(word_list)
                category = paddle.to_tensor(category)

                model = SimpleNet()
                word_list, category = model(word_list, category)
                print(word_list.numpy().shape, category.numpy())

    """

    def __init__(self, mode='train'):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self._download_data_if_not_yet()

        # read dataset into memory
        self._load_sentiment_data()

    def _get_word_dict(self):
        """
        Sort the words by the frequency with which they occur in the samples.
        :return: words_freq_sorted
        """
        words_freq_sorted = list()
        word_freq_dict = collections.defaultdict(int)

        for category in movie_reviews.categories():
            for field in movie_reviews.fileids(category):
                for words in movie_reviews.words(field):
                    word_freq_dict[words] += 1
        words_sort_list = list(six.iteritems(word_freq_dict))
        words_sort_list.sort(key=cmp_to_key(lambda a, b: b[1] - a[1]))
        for index, word in enumerate(words_sort_list):
            words_freq_sorted.append((word[0], index))
        return words_freq_sorted

    def _sort_files(self):
        """
        Sort the samples so negative and positive files are interleaved.
        :return: files_list
        """
        files_list = list()
        neg_file_list = movie_reviews.fileids('neg')
        pos_file_list = movie_reviews.fileids('pos')
        files_list = list(
            chain.from_iterable(list(zip(neg_file_list, pos_file_list))))
        return files_list

    def _load_sentiment_data(self):
        """
        Load the dataset.
        :return: data_set
        """
        self.data = []
        words_ids = dict(self._get_word_dict())
        for sample_file in self._sort_files():
            words_list = list()
            category = 0 if 'neg' in sample_file else 1
            for word in movie_reviews.words(sample_file):
                words_list.append(words_ids[word.lower()])
            self.data.append((words_list, category))

    def _download_data_if_not_yet(self):
        """
        Download the dataset if it has not been downloaded yet.
        """
        try:
            # download and extract movie_reviews.zip
            paddle.dataset.common.download(
                URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip')
            path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora')
            filename = os.path.join(path, 'movie_reviews.zip')
            zip_file = zipfile.ZipFile(filename)
            zip_file.extractall(path)
            zip_file.close()
            # make sure that nltk can find the data
            if paddle.dataset.common.DATA_HOME not in nltk.data.path:
                nltk.data.path.append(paddle.dataset.common.DATA_HOME)
            movie_reviews.categories()
        except LookupError:
            print("Downloading movie_reviews data set, please wait.....")
            nltk.download(
                'movie_reviews', download_dir=paddle.dataset.common.DATA_HOME)
            print("Download data set success.....")
            print("Path is " + nltk.data.find('corpora/movie_reviews').path)

    def __getitem__(self, idx):
        if self.mode == 'test':
            idx += NUM_TRAINING_INSTANCES
        data = self.data[idx]
        return np.array(data[0]), np.array(data[1])

    def __len__(self):
        if self.mode == 'train':
            return NUM_TRAINING_INSTANCES
        else:
            return NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
```
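The `_sort_files` helper interleaves negative and positive review files so that the flat train/test split by index (first 1600 samples train, rest test) stays class-balanced. The same interleaving in isolation (illustrative only, not part of the commit):

```python
from itertools import chain

neg = ['neg/a.txt', 'neg/b.txt']
pos = ['pos/a.txt', 'pos/b.txt']

# zip pairs one neg with one pos; chain flattens the pairs back into one list
print(list(chain.from_iterable(zip(neg, pos))))
# ['neg/a.txt', 'pos/a.txt', 'neg/b.txt', 'pos/b.txt']
```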
python/paddle/incubate/hapi/datasets/movielens.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import zipfile
import re
import random
import functools
import six

import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['Movielens']

age_table = [1, 18, 25, 35, 45, 50, 56]

URL = 'https://dataset.bj.bcebos.com/movielens%2Fml-1m.zip'
MD5 = 'c4d9eecfca2ab87c1945afe126590906'


class MovieInfo(object):
    """
    Movie id, title and categories information are stored in MovieInfo.
    """

    def __init__(self, index, categories, title):
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self, categories_dict, movie_title_dict):
        """
        Get information from a movie.
        """
        return [
            [self.index], [categories_dict[c] for c in self.categories],
            [movie_title_dict[w.lower()] for w in self.title.split()]
        ]

    def __str__(self):
        return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
            self.index, self.title, self.categories)

    def __repr__(self):
        return self.__str__()


class UserInfo(object):
    """
    User id, gender, age, and job information are stored in UserInfo.
    """

    def __init__(self, index, gender, age, job_id):
        self.index = int(index)
        self.is_male = gender == 'M'
        self.age = age_table.index(int(age))
        self.job_id = int(job_id)

    def value(self):
        """
        Get information from a user.
        """
        return [[self.index], [0 if self.is_male else 1], [self.age],
                [self.job_id]]

    def __str__(self):
        return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
            self.index, "M" if self.is_male else "F", age_table[self.age],
            self.job_id)

    def __repr__(self):
        return str(self)


class Movielens(Dataset):
    """
    Implementation of `Movielens 1-M <https://grouplens.org/datasets/movielens/1m/>`_ dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        test_ratio(float): split ratio for test sample. Default 0.1.
        rand_seed(int): random seed. Default 0.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of Movielens 1-M dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Movielens

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, category, title, rating):
                    return paddle.sum(category), paddle.sum(title), paddle.sum(rating)

            paddle.disable_static()

            movielens = Movielens(mode='train')

            for i in range(10):
                category, title, rating = movielens[i][-3:]
                category = paddle.to_tensor(category)
                title = paddle.to_tensor(title)
                rating = paddle.to_tensor(rating)

                model = SimpleNet()
                category, title, rating = model(category, title, rating)
                print(category.numpy().shape, title.numpy().shape, rating.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 test_ratio=0.1,
                 rand_seed=0,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'sentiment', download)

        self.test_ratio = test_ratio
        self.rand_seed = rand_seed

        np.random.seed(rand_seed)
        self._load_meta_info()
        self._load_data()

    def _load_meta_info(self):
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        self.movie_info = dict()
        self.movie_title_dict = dict()
        self.categories_dict = dict()
        self.user_info = dict()
        with zipfile.ZipFile(self.data_file) as package:
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)
            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                for i, line in enumerate(movie_file):
                    line = cpt.to_text(line, encoding='latin')
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    title = pattern.match(title).group(1)
                    self.movie_info[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            for i, w in enumerate(title_word_set):
                self.movie_title_dict[w] = i

            for i, c in enumerate(categories_set):
                self.categories_dict[c] = i

            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    line = cpt.to_text(line, encoding='latin')
                    uid, gender, age, job, _ = line.strip().split("::")
                    self.user_info[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job)

    def _load_data(self):
        self.data = []
        is_test = self.mode == 'test'
        with zipfile.ZipFile(self.data_file) as package:
            with package.open('ml-1m/ratings.dat') as rating:
                for line in rating:
                    line = cpt.to_text(line, encoding='latin')
                    if (np.random.random() < self.test_ratio) == is_test:
                        uid, mov_id, rating, _ = line.strip().split("::")
                        uid = int(uid)
                        mov_id = int(mov_id)
                        rating = float(rating) * 2 - 5.0

                        mov = self.movie_info[mov_id]
                        usr = self.user_info[uid]
                        self.data.append(usr.value() + \
                                mov.value(self.categories_dict,
                                          self.movie_title_dict) + [[rating]])

    def __getitem__(self, idx):
        data = self.data[idx]
        return tuple([np.array(d) for d in data])

    def __len__(self):
        return len(self.data)
```
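In `_load_data`, the raw 1-5 star rating is rescaled with `r * 2 - 5`, mapping the five star levels onto the symmetric-looking range [-3, 5]. A one-liner check of that mapping (illustrative only, not part of the commit):

```python
# Movielens._load_data rescales stars as r * 2 - 5
for stars in [1, 2, 3, 4, 5]:
    print(stars, float(stars) * 2 - 5.0)
# 1 -> -3.0, 2 -> -1.0, 3 -> 1.0, 4 -> 3.0, 5 -> 5.0
```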
python/paddle/incubate/hapi/datasets/uci_housing.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import numpy as np

import paddle.dataset.common
from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ["UCIHousing"]

URL = 'http://paddlemodels.bj.bcebos.com/uci_housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'

feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
    'PTRATIO', 'B', 'LSTAT'
]


class UCIHousing(Dataset):
    """
    Implementation of `UCI housing <https://archive.ics.uci.edu/ml/datasets/Housing>`_
    dataset

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of UCI housing dataset.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import UCIHousing

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, feature, target):
                    return paddle.sum(feature), target

            paddle.disable_static()

            uci_housing = UCIHousing(mode='train')

            for i in range(10):
                feature, target = uci_housing[i]
                feature = paddle.to_tensor(feature)
                target = paddle.to_tensor(target)

                model = SimpleNet()
                feature, target = model(feature, target)
                print(feature.numpy().shape, target.numpy())

    """

    def __init__(self, data_file=None, mode='train', download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, URL, MD5, 'uci_housing', download)

        # read dataset into memory
        self._load_data()

    def _load_data(self, feature_num=14, ratio=0.8):
        data = np.fromfile(self.data_file, sep=' ')
        data = data.reshape(data.shape[0] // feature_num, feature_num)
        maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), \
                                   data.sum(axis=0) / data.shape[0]
        for i in six.moves.range(feature_num - 1):
            data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
        offset = int(data.shape[0] * ratio)
        if self.mode == 'train':
            self.data = data[:offset]
        elif self.mode == 'test':
            self.data = data[offset:]

    def __getitem__(self, idx):
        data = self.data[idx]
        return np.array(data[:-1]), np.array(data[-1:])

    def __len__(self):
        return len(self.data)
```
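`_load_data` normalizes each of the 13 input features (all columns except the last, the target price) with mean-centering divided by the column range: (x − mean) / (max − min). A small numeric check of that formula (illustrative only, not part of the commit):

```python
import numpy as np

col = np.array([1.0, 2.0, 3.0, 4.0])
# mean is 2.5, range (max - min) is 3.0
print((col - col.mean()) / (col.max() - col.min()))
# [-0.5        -0.16666667  0.16666667  0.5       ]
```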
python/paddle/incubate/hapi/datasets/voc2012.py (new file, 0 → 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import io
import tarfile
import numpy as np
from PIL import Image

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ["VOC2012"]

VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
VOCtrainval_11-May-2012.tar'

VOC_MD5 = '131da710f39b47a43fdfa256cbc11976'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'

CACHE_DIR = 'voc2012'

MODE_FLAG_MAP = {'train': 'trainval', 'test': 'train', 'valid': 'val'}


class VOC2012(Dataset):
    """
    Implementation of `VOC2012 <http://host.robots.ox.ac.uk/pascal/VOC/voc2012/>`_ dataset

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import VOC2012

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, image, label):
                    return paddle.sum(image), label

            paddle.disable_static()

            voc2012 = VOC2012(mode='train')

            for i in range(10):
                image, label = voc2012[i]
                image = paddle.cast(paddle.to_tensor(image), 'float32')
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'valid', 'test'], \
            "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
        self.flag = MODE_FLAG_MAP[mode.lower()]

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, VOC_URL, VOC_MD5, CACHE_DIR, download)
        self.transform = transform

        # read dataset into memory
        self._load_anno()

    def _load_anno(self):
        self.name2mem = {}
        self.data_tar = tarfile.open(self.data_file)
        for ele in self.data_tar.getmembers():
            self.name2mem[ele.name] = ele

        set_file = SET_FILE.format(self.flag)
        sets = self.data_tar.extractfile(self.name2mem[set_file])

        self.data = []
        self.labels = []

        for line in sets:
            line = line.strip()
            data = DATA_FILE.format(line.decode('utf-8'))
            label = LABEL_FILE.format(line.decode('utf-8'))
            self.data.append(data)
            self.labels.append(label)

    def __getitem__(self, idx):
        data_file = self.data[idx]
        label_file = self.labels[idx]

        data = self.data_tar.extractfile(self.name2mem[data_file]).read()
        label = self.data_tar.extractfile(self.name2mem[label_file]).read()
        data = Image.open(io.BytesIO(data))
        label = Image.open(io.BytesIO(label))
        data = np.array(data)
        label = np.array(label)
        if self.transform is not None:
            data = self.transform(data)
        return data, label

    def __len__(self):
        return len(self.data)
```
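`_load_anno` builds a `name2mem` index over the tar once, so that `__getitem__` can pull individual JPEG/PNG members out of the open archive without rescanning it on every access. The same pattern in isolation, with a hypothetical tar path and member name (illustrative only, not part of the commit):

```python
import io
import tarfile

from PIL import Image

# hypothetical archive path -- substitute a real VOC tar on disk
with tarfile.open('VOCtrainval_11-May-2012.tar') as tar:
    name2mem = {m.name: m for m in tar.getmembers()}  # one pass over the index
    member = name2mem['VOCdevkit/VOC2012/JPEGImages/2007_000027.jpg']
    image = Image.open(io.BytesIO(tar.extractfile(member).read()))
    print(image.size)
```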
python/paddle/incubate/hapi/datasets/wmt14.py
0 → 100644
浏览文件 @
1a72a903
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
tarfile
import
numpy
as
np
import
gzip
from
paddle.io
import
Dataset
import
paddle.compat
as
cpt
from
.utils
import
_check_exists_and_download
__all__
=
[
'WMT14'
]
URL_DEV_TEST
=
(
'http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz'
)
MD5_DEV_TEST
=
'7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and
# will be add later.
URL_TRAIN
=
(
'http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz'
)
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
START
=
"<s>"
END
=
"<e>"
UNK
=
"<unk>"
UNK_IDX
=
2
class
WMT14
(
Dataset
):
"""
Implementation of `WMT14 <http://www.statmt.org/wmt14/>`_ test dataset.
The original WMT14 dataset is too large and a small set of data for set is
provided. This module will download dataset from
http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'test' or 'gen'. Default 'train'
dict_size(int): word dictionary size. Default -1.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT14 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import WMT14
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src_ids, trg_ids, trg_ids_next):
return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
paddle.disable_static()
wmt14 = WMT14(mode='train', dict_size=50)
for i in range(10):
src_ids, trg_ids, trg_ids_next = wmt14[i]
src_ids = paddle.to_tensor(src_ids)
trg_ids = paddle.to_tensor(trg_ids)
trg_ids_next = paddle.to_tensor(trg_ids_next)
model = SimpleNet()
src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())
"""
def
__init__
(
self
,
data_file
=
None
,
mode
=
'train'
,
dict_size
=-
1
,
download
=
True
):
assert
mode
.
lower
()
in
[
'train'
,
'test'
,
'gen'
],
\
"mode should be 'train', 'test' or 'gen', but got {}"
.
format
(
mode
)
self
.
mode
=
mode
.
lower
()
self
.
data_file
=
data_file
if
self
.
data_file
is
None
:
assert
download
,
"data_file is not set and downloading automatically is disabled"
self
.
data_file
=
_check_exists_and_download
(
data_file
,
URL_TRAIN
,
MD5_TRAIN
,
'wmt14'
,
download
)
# read dataset into memory
assert
dict_size
>
0
,
"dict_size should be set as positive number"
self
.
dict_size
=
dict_size
self
.
_load_data
()
def
_load_data
(
self
):
def
__to_dict
(
fd
,
size
):
out_dict
=
dict
()
for
line_count
,
line
in
enumerate
(
fd
):
if
line_count
<
size
:
out_dict
[
cpt
.
to_text
(
line
.
strip
())]
=
line_count
else
:
break
return
out_dict
self
.
src_ids
=
[]
self
.
trg_ids
=
[]
self
.
trg_ids_next
=
[]
with
tarfile
.
open
(
self
.
data_file
,
mode
=
'r'
)
as
f
:
names
=
[
each_item
.
name
for
each_item
in
f
if
each_item
.
name
.
endswith
(
"src.dict"
)
]
assert
len
(
names
)
==
1
self
.
src_dict
=
__to_dict
(
f
.
extractfile
(
names
[
0
]),
self
.
dict_size
)
names
=
[
each_item
.
name
for
each_item
in
f
if
each_item
.
name
.
endswith
(
"trg.dict"
)
]
assert
len
(
names
)
==
1
self
.
trg_dict
=
__to_dict
(
f
.
extractfile
(
names
[
0
]),
self
.
dict_size
)
file_name
=
"{}/{}"
.
format
(
self
.
mode
,
self
.
mode
)
names
=
[
each_item
.
name
for
each_item
in
f
if
each_item
.
name
.
endswith
(
file_name
)
]
for
name
in
names
:
for
line
in
f
.
extractfile
(
name
):
line
=
cpt
.
to_text
(
line
)
line_split
=
line
.
strip
().
split
(
'
\t
'
)
if
len
(
line_split
)
!=
2
:
continue
src_seq
=
line_split
[
0
]
# one source sequence
src_words
=
src_seq
.
split
()
src_ids
=
[
self
.
src_dict
.
get
(
w
,
UNK_IDX
)
for
w
in
[
START
]
+
src_words
+
[
END
]
]
trg_seq
=
line_split
[
1
]
# one target sequence
trg_words
=
trg_seq
.
split
()
trg_ids
=
[
self
.
trg_dict
.
get
(
w
,
UNK_IDX
)
for
w
in
trg_words
]
# remove sequence whose length > 80 in training mode
if
len
(
src_ids
)
>
80
or
len
(
trg_ids
)
>
80
:
continue
trg_ids_next
=
trg_ids
+
[
self
.
trg_dict
[
END
]]
trg_ids
=
[
self
.
trg_dict
[
START
]]
+
trg_ids
self
.
src_ids
.
append
(
src_ids
)
self
.
trg_ids
.
append
(
trg_ids
)
self
.
trg_ids_next
.
append
(
trg_ids_next
)
def
__getitem__
(
self
,
idx
):
return
(
np
.
array
(
self
.
src_ids
[
idx
]),
np
.
array
(
self
.
trg_ids
[
idx
]),
np
.
array
(
self
.
trg_ids_next
[
idx
]))
def
__len__
(
self
):
return
len
(
self
.
src_ids
)
    def get_dict(self, reverse=False):
        # return the word->id dictionaries for both languages, or the
        # id->word dictionaries when reverse is True
        src_dict, trg_dict = self.src_dict, self.trg_dict
        if reverse:
            src_dict = {v: k for k, v in six.iteritems(src_dict)}
            trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
        return src_dict, trg_dict
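As a quick sanity check of the class above, a minimal decoding sketch (not part of this commit; it assumes only the WMT14 API defined in this file) that uses get_dict(reverse=True) to map ids back to words:

    from paddle.incubate.hapi.datasets import WMT14

    wmt14 = WMT14(mode='train', dict_size=50)
    # id -> word lookup tables for both languages
    src_dict, trg_dict = wmt14.get_dict(reverse=True)

    src_ids, trg_ids, trg_ids_next = wmt14[0]
    # with such a small dict_size most tokens decode to <unk>
    print(' '.join(src_dict[int(i)] for i in src_ids))
    print(' '.join(trg_dict[int(i)] for i in trg_ids))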
python/paddle/incubate/hapi/datasets/wmt16.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
from __future__ import print_function

import os
import six
import tarfile
import numpy as np
from collections import defaultdict

import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['WMT16']

DATA_URL = "http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz"
DATA_MD5 = "0c38be43600334966403524a40dcd81e"

TOTAL_EN_WORDS = 11250
TOTAL_DE_WORDS = 19220

START_MARK = "<s>"
END_MARK = "<e>"
UNK_MARK = "<unk>"


class WMT16(Dataset):
"""
Implementation of `WMT16 <http://www.statmt.org/wmt16/>`_ test dataset.
ACL2016 Multimodal Machine Translation. Please see this website for more
details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
.. code-block:: text
@article{elliott-EtAl:2016:VL16,
author = {{Elliott}, D. and {Frank}, S. and {Sima"an}, K. and {Specia}, L.},
title = {Multi30K: Multilingual English-German Image Descriptions},
booktitle = {Proceedings of the 6th Workshop on Vision and Language},
year = {2016},
pages = {70--74},
year = 2016
}
Args:
data_file(str): path to data tar file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train', 'test' or 'val'. Default 'train'
src_dict_size(int): word dictionary size for source language word. Default -1.
trg_dict_size(int): word dictionary size for target language word. Default -1.
lang(str): source language, 'en' or 'de'. Default 'en'.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
Returns:
Dataset: instance of WMT16 dataset
Examples:
.. code-block:: python
import paddle
from paddle.incubate.hapi.datasets import WMT16
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
def forward(self, src_ids, trg_ids, trg_ids_next):
return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
paddle.disable_static()
wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)
for i in range(10):
src_ids, trg_ids, trg_ids_next = wmt16[i]
src_ids = paddle.to_tensor(src_ids)
trg_ids = paddle.to_tensor(trg_ids)
trg_ids_next = paddle.to_tensor(trg_ids_next)
model = SimpleNet()
src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())
"""
    def __init__(self,
                 data_file=None,
                 mode='train',
                 src_dict_size=-1,
                 trg_dict_size=-1,
                 lang='en',
                 download=True):
        assert mode.lower() in ['train', 'test', 'val'], \
            "mode should be 'train', 'test' or 'val', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, DATA_URL, DATA_MD5, 'wmt16', download)

        self.lang = lang
        assert src_dict_size > 0, "dict_size should be set as positive number"
        assert trg_dict_size > 0, "dict_size should be set as positive number"
        self.src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if lang == "en"
                                                 else TOTAL_DE_WORDS))
        self.trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if lang == "en"
                                                 else TOTAL_EN_WORDS))

        # load source and target word dict
        self.src_dict = self._load_dict(lang, src_dict_size)
        self.trg_dict = self._load_dict("de" if lang == "en" else "en",
                                        trg_dict_size)

        # load data
        self.data = self._load_data()
    def _load_dict(self, lang, dict_size, reverse=False):
        dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                                 "wmt16/%s_%d.dict" % (lang, dict_size))
        if not os.path.exists(dict_path) or (
                len(open(dict_path, "rb").readlines()) != dict_size):
            self._build_dict(dict_path, dict_size, lang)

        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = cpt.to_text(line.strip())
                else:
                    word_dict[cpt.to_text(line.strip())] = idx
        return word_dict
    def _build_dict(self, dict_path, dict_size, lang):
        word_dict = defaultdict(int)
        with tarfile.open(self.data_file, mode="r") as f:
            for line in f.extractfile("wmt16/train"):
                line = cpt.to_text(line)
                line_split = line.strip().split("\t")
                if len(line_split) != 2:
                    continue
                sen = line_split[0] if self.lang == "en" else line_split[1]
                for w in sen.split():
                    word_dict[w] += 1

        with open(dict_path, "wb") as fout:
            fout.write(
                cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK,
                                               UNK_MARK)))
            for idx, word in enumerate(
                    sorted(six.iteritems(word_dict),
                           key=lambda x: x[1],
                           reverse=True)):
                if idx + 3 == dict_size:
                    break
                fout.write(cpt.to_bytes(word[0]))
                fout.write(cpt.to_bytes('\n'))
    def _load_data(self):
        # the index for start mark, end mark, and unk are the same in source
        # language and target language. Here uses the source language
        # dictionary to determine their indices.
        start_id = self.src_dict[START_MARK]
        end_id = self.src_dict[END_MARK]
        unk_id = self.src_dict[UNK_MARK]

        src_col = 0 if self.lang == "en" else 1
        trg_col = 1 - src_col

        self.src_ids = []
        self.trg_ids = []
        self.trg_ids_next = []
        with tarfile.open(self.data_file, mode="r") as f:
            for line in f.extractfile("wmt16/{}".format(self.mode)):
                line = cpt.to_text(line)
                line_split = line.strip().split("\t")
                if len(line_split) != 2:
                    continue
                src_words = line_split[src_col].split()
                src_ids = [start_id] + [
                    self.src_dict.get(w, unk_id) for w in src_words
                ] + [end_id]

                trg_words = line_split[trg_col].split()
                trg_ids = [self.trg_dict.get(w, unk_id) for w in trg_words]

                trg_ids_next = trg_ids + [end_id]
                trg_ids = [start_id] + trg_ids

                self.src_ids.append(src_ids)
                self.trg_ids.append(trg_ids)
                self.trg_ids_next.append(trg_ids_next)
    def __getitem__(self, idx):
        return (np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]),
                np.array(self.trg_ids_next[idx]))

    def __len__(self):
        return len(self.src_ids)
    def get_dict(self, lang, reverse=False):
        """
        Return the word dictionary for the specified language.

        Args:
            lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.
            reverse(bool): If reverse is set to False, the returned python
                          dictionary will use word as key and use index as value.
                          If reverse is set to True, the returned python
                          dictionary will use index as key and word as value.

        Returns:
            dict: The word dictionary for the specific language.
        """
        dict_size = self.src_dict_size if lang == self.lang else self.trg_dict_size

        dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                                 "wmt16/%s_%d.dict" % (lang, dict_size))
        assert os.path.exists(dict_path), "Word dictionary does not exist. " \
            "Please invoke paddle.dataset.wmt16.train/test/validation first " \
            "to build the dictionary."
        return self._load_dict(lang, dict_size, reverse)
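A similar sketch for WMT16 (again not part of the diff, assuming only the API above): decode one validation pair through the per-language dictionary returned by get_dict.

    from paddle.incubate.hapi.datasets import WMT16

    wmt16 = WMT16(mode='val', src_dict_size=50, trg_dict_size=50, lang='en')
    en_dict = wmt16.get_dict('en', reverse=True)  # id -> word

    src_ids, trg_ids, trg_ids_next = wmt16[0]
    print(' '.join(en_dict[int(i)] for i in src_ids))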
python/paddle/incubate/hapi/tests/test_dataset_cifar.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestCifar10Train(unittest.TestCase):
    def test_main(self):
        cifar = Cifar10(mode='train')
        self.assertTrue(len(cifar) == 50000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 50000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 9)


class TestCifar10Test(unittest.TestCase):
    def test_main(self):
        cifar = Cifar10(mode='test')
        self.assertTrue(len(cifar) == 10000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 10000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 9)


class TestCifar100Train(unittest.TestCase):
    def test_main(self):
        cifar = Cifar100(mode='train')
        self.assertTrue(len(cifar) == 50000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 50000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 99)


class TestCifar100Test(unittest.TestCase):
    def test_main(self):
        cifar = Cifar100(mode='test')
        self.assertTrue(len(cifar) == 10000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 10000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 99)


if __name__ == '__main__':
    unittest.main()
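The tests above pin down the sample format: a flat 3072-vector plus a class id. A usage sketch (assuming the standard CIFAR channel-major byte layout, which the tests themselves do not assert):

    import numpy as np
    from paddle.incubate.hapi.datasets import Cifar10

    cifar = Cifar10(mode='train')
    data, label = cifar[0]
    # 3072 = 3 * 32 * 32; view the flat vector as a CHW image, scaled to [0, 1]
    image = data.reshape(3, 32, 32).astype('float32') / 255.0
    print(image.shape, int(label))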
python/paddle/incubate/hapi/tests/test_dataset_conll05.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestConll05st(unittest.TestCase):
    def test_main(self):
        conll05st = Conll05st()
        self.assertTrue(len(conll05st) == 5267)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 5267)
        sample = conll05st[idx]
        self.assertTrue(len(sample) == 9)
        for s in sample:
            self.assertTrue(len(s.shape) == 1)


if __name__ == '__main__':
    unittest.main()
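An inspection sketch for the nine fields of a Conll05st sample (that these are eight input id sequences plus one SRL label sequence is an assumption based on the classic semantic-role-labeling setup, not something the test asserts):

    from paddle.incubate.hapi.datasets import Conll05st

    conll05st = Conll05st()
    sample = conll05st[0]
    for i, field in enumerate(sample):
        # each field is a 1-D id array over the same sentence
        print(i, field.shape, field[:5])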
python/paddle/incubate/hapi/tests/test_dataset_imdb.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestImdbTrain(unittest.TestCase):
    def test_main(self):
        imdb = Imdb(mode='train')
        self.assertTrue(len(imdb) == 25000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 25000)
        data, label = imdb[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(label.shape[0] == 1)
        self.assertTrue(int(label) in [0, 1])


class TestImdbTest(unittest.TestCase):
    def test_main(self):
        imdb = Imdb(mode='test')
        self.assertTrue(len(imdb) == 25000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 25000)
        data, label = imdb[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(label.shape[0] == 1)
        self.assertTrue(int(label) in [0, 1])


if __name__ == '__main__':
    unittest.main()
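Each Imdb sample is a 1-D array of word ids plus a binary sentiment label, so simple corpus statistics come cheap; a sketch over the first thousand training reviews:

    import numpy as np
    from paddle.incubate.hapi.datasets import Imdb

    imdb = Imdb(mode='train')
    lengths = [len(imdb[i][0]) for i in range(1000)]
    labels = [int(imdb[i][1]) for i in range(1000)]
    print('mean review length:', np.mean(lengths))
    print('label counts:', np.bincount(labels))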
python/paddle/incubate/hapi/tests/test_dataset_imikolov.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestImikolovTrain(unittest.TestCase):
    def test_main(self):
        imikolov = Imikolov(mode='train', data_type='NGRAM', window_size=2)
        self.assertTrue(len(imikolov) == 929589)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 929589)
        data = imikolov[idx]
        self.assertTrue(len(data) == 2)


class TestImikolovTest(unittest.TestCase):
    def test_main(self):
        imikolov = Imikolov(mode='test', data_type='NGRAM', window_size=2)
        self.assertTrue(len(imikolov) == 82430)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 82430)
        data = imikolov[idx]
        self.assertTrue(len(data) == 2)


if __name__ == '__main__':
    unittest.main()
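With data_type='NGRAM' and window_size=2 each Imikolov sample is a pair of consecutive word ids, as the length check above implies; a sketch printing the first few n-grams:

    from paddle.incubate.hapi.datasets import Imikolov

    imikolov = Imikolov(mode='train', data_type='NGRAM', window_size=2)
    for i in range(5):
        print(imikolov[i])  # consecutive word-id pairs from the PTB text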
python/paddle/incubate/hapi/tests/test_dataset_movie_reviews.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestMovieReviewsTrain(unittest.TestCase):
    def test_main(self):
        movie_reviews = MovieReviews(mode='train')
        self.assertTrue(len(movie_reviews) == 1600)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 1600)
        data = movie_reviews[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(int(data[1]) in [0, 1])


class TestMovieReviewsTest(unittest.TestCase):
    def test_main(self):
        movie_reviews = MovieReviews(mode='test')
        self.assertTrue(len(movie_reviews) == 400)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 400)
        data = movie_reviews[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(int(data[1]) in [0, 1])


if __name__ == '__main__':
    unittest.main()
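A small sketch pairing the MovieReviews fields checked above: the first field is a 1-D word-id array, the second the 0/1 sentiment class.

    import numpy as np
    from paddle.incubate.hapi.datasets import MovieReviews

    reviews = MovieReviews(mode='train')
    words, label = reviews[0]
    print(words.shape, int(label))
    # distinct ids in the first sample, a rough vocabulary sanity check
    print('distinct ids:', len(np.unique(words)))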
python/paddle/incubate/hapi/tests/test_dataset_movielens.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestMovielensTrain(unittest.TestCase):
    def test_main(self):
        movielens = Movielens(mode='train')
        # the movielens dataset is split into train/test at random,
        # so the dataset length is not checked here

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 900000)
        data = movielens[idx]
        self.assertTrue(len(data) == 8)
        for i, d in enumerate(data):
            self.assertTrue(len(d.shape) == 1)
            if i not in [5, 6]:
                self.assertTrue(d.shape[0] == 1)


class TestMovielensTest(unittest.TestCase):
    def test_main(self):
        movielens = Movielens(mode='test')
        # the movielens dataset is split into train/test at random,
        # so the dataset length is not checked here

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 100000)
        data = movielens[idx]
        self.assertTrue(len(data) == 8)
        for i, d in enumerate(data):
            self.assertTrue(len(d.shape) == 1)
            if i not in [5, 6]:
                self.assertTrue(d.shape[0] == 1)


if __name__ == '__main__':
    unittest.main()
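The shape checks above say each Movielens sample has eight 1-D fields, of which indices 5 and 6 are variable-length (presumably the movie's category and title id sequences; the test does not name them). A sketch:

    from paddle.incubate.hapi.datasets import Movielens

    movielens = Movielens(mode='train')
    sample = movielens[0]
    for i, field in enumerate(sample):
        tag = 'variable-length' if i in [5, 6] else 'scalar'
        print(i, field.shape, tag)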
python/paddle/incubate/hapi/tests/test_dataset_uci_housing.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestUCIHousingTrain(unittest.TestCase):
    def test_main(self):
        uci_housing = UCIHousing(mode='train')
        self.assertTrue(len(uci_housing) == 404)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 404)
        data = uci_housing[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(data[0].shape[0] == 13)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(data[1].shape[0] == 1)


class TestUCIHousingTest(unittest.TestCase):
    def test_main(self):
        uci_housing = UCIHousing(mode='test')
        self.assertTrue(len(uci_housing) == 102)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 102)
        data = uci_housing[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(data[0].shape[0] == 13)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(data[1].shape[0] == 1)


class TestWMT14Train(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='train', dict_size=50)
        self.assertTrue(len(wmt14) == 191155)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 191155)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Test(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='test', dict_size=50)
        self.assertTrue(len(wmt14) == 5957)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 5957)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Gen(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='gen', dict_size=50)
        self.assertTrue(len(wmt14) == 3001)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 3001)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


if __name__ == '__main__':
    unittest.main()
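Since UCIHousing samples are a fixed 13-feature vector plus a 1-element target, the whole split stacks cleanly into numpy matrices; a sketch (e.g. as a first step toward feature normalization):

    import numpy as np
    from paddle.incubate.hapi.datasets import UCIHousing

    uci = UCIHousing(mode='train')
    feats = np.stack([uci[i][0] for i in range(len(uci))])    # (404, 13)
    targets = np.stack([uci[i][1] for i in range(len(uci))])  # (404, 1)
    mean, std = feats.mean(axis=0), feats.std(axis=0)
    feats_norm = (feats - mean) / std
    print(feats_norm.shape, targets.mean())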
python/paddle/incubate/hapi/tests/test_dataset_voc.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import voc2012, VOC2012
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download

# VOC2012 is too large for unittest to download, stub a small dataset here
voc2012.VOC_URL = 'https://paddlemodels.bj.bcebos.com/voc2012_stub/VOCtrainval_11-May-2012.tar'
voc2012.VOC_MD5 = '34cb1fe5bdc139a5454b25b16118fff8'


class TestVOC2012Train(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='train')
        self.assertTrue(len(voc2012) == 3)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 3)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


class TestVOC2012Valid(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='valid')
        self.assertTrue(len(voc2012) == 1)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 1)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


class TestVOC2012Test(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='test')
        self.assertTrue(len(voc2012) == 2)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 2)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


if __name__ == '__main__':
    unittest.main()
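The stubbed VOC2012 dataset above yields an image array plus a 2-D segmentation mask; a minimal inspection sketch (the tests only check the number of dimensions, so the exact channel layout is left unasserted):

    import numpy as np
    from paddle.incubate.hapi.datasets import VOC2012

    voc = VOC2012(mode='train')
    image, label = voc[0]
    print(image.shape, label.shape)
    # class ids present in this segmentation mask
    print(np.unique(label))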
python/paddle/incubate/hapi/tests/test_dataset_wmt.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestWMT14Train(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='train', dict_size=50)
        self.assertTrue(len(wmt14) == 191155)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 191155)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Test(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='test', dict_size=50)
        self.assertTrue(len(wmt14) == 5957)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 5957)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Gen(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='gen', dict_size=50)
        self.assertTrue(len(wmt14) == 3001)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 3001)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Train(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(
            mode='train', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 29000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 29000)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Test(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(
            mode='test', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 1000)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 1000)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Val(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(
            mode='val', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 1014)

        # traversing the whole dataset may take a long
        # time, so randomly check 1 sample only
        idx = np.random.randint(0, 1014)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


if __name__ == '__main__':
    unittest.main()
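Both WMT datasets yield variable-length id sequences, so batching them needs padding; a sketch with a hypothetical pad_batch helper (not part of this diff; pad id 0 is an assumption, so pick the real pad token for serious use):

    import numpy as np
    from paddle.incubate.hapi.datasets import WMT16

    def pad_batch(seqs, pad_id=0):
        # right-pad every sequence to the longest one in the batch
        max_len = max(len(s) for s in seqs)
        return np.array([
            np.pad(s, (0, max_len - len(s)), mode='constant',
                   constant_values=pad_id) for s in seqs
        ])

    wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)
    batch = [wmt16[i] for i in range(4)]
    src = pad_batch([b[0] for b in batch])
    print(src.shape)  # (4, max_src_len)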