Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b89b4e32
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b89b4e32
编写于
11月 04, 2020
作者:
L
LielinJiang
提交者:
GitHub
11月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fashion dataset (#28411)
上级
463075a8
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
109 addition
and
18 deletion
+109
-18
python/paddle/tests/test_datasets.py
python/paddle/tests/test_datasets.py
+46
-0
python/paddle/vision/datasets/cifar.py
python/paddle/vision/datasets/cifar.py
+1
-1
python/paddle/vision/datasets/mnist.py
python/paddle/vision/datasets/mnist.py
+62
-17
未找到文件。
python/paddle/tests/test_datasets.py
浏览文件 @
b89b4e32
...
...
@@ -134,6 +134,52 @@ class TestMNISTTrain(unittest.TestCase):
mnist
=
MNIST
(
mode
=
'train'
,
transform
=
transform
,
backend
=
1
)
class
TestFASHIONMNISTTest
(
unittest
.
TestCase
):
def
test_main
(
self
):
transform
=
T
.
Transpose
()
mnist
=
FashionMNIST
(
mode
=
'test'
,
transform
=
transform
)
self
.
assertTrue
(
len
(
mnist
)
==
10000
)
for
i
in
range
(
len
(
mnist
)):
image
,
label
=
mnist
[
i
]
self
.
assertTrue
(
image
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
image
.
shape
[
1
]
==
28
)
self
.
assertTrue
(
image
.
shape
[
2
]
==
28
)
self
.
assertTrue
(
label
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
class
TestFASHIONMNISTTrain
(
unittest
.
TestCase
):
def
test_main
(
self
):
transform
=
T
.
Transpose
()
mnist
=
FashionMNIST
(
mode
=
'train'
,
transform
=
transform
)
self
.
assertTrue
(
len
(
mnist
)
==
60000
)
for
i
in
range
(
len
(
mnist
)):
image
,
label
=
mnist
[
i
]
self
.
assertTrue
(
image
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
image
.
shape
[
1
]
==
28
)
self
.
assertTrue
(
image
.
shape
[
2
]
==
28
)
self
.
assertTrue
(
label
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
# test cv2 backend
mnist
=
FashionMNIST
(
mode
=
'train'
,
transform
=
transform
,
backend
=
'cv2'
)
self
.
assertTrue
(
len
(
mnist
)
==
60000
)
for
i
in
range
(
len
(
mnist
)):
image
,
label
=
mnist
[
i
]
self
.
assertTrue
(
image
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
image
.
shape
[
1
]
==
28
)
self
.
assertTrue
(
image
.
shape
[
2
]
==
28
)
self
.
assertTrue
(
label
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
break
with
self
.
assertRaises
(
ValueError
):
mnist
=
FashionMNIST
(
mode
=
'train'
,
transform
=
transform
,
backend
=
1
)
class
TestFlowersTrain
(
unittest
.
TestCase
):
def
test_main
(
self
):
flowers
=
Flowers
(
mode
=
'train'
)
...
...
python/paddle/vision/datasets/cifar.py
浏览文件 @
b89b4e32
...
...
@@ -161,7 +161,7 @@ class Cifar10(Dataset):
image
=
image
.
transpose
([
1
,
2
,
0
])
if
self
.
backend
==
'pil'
:
image
=
Image
.
fromarray
(
image
)
image
=
Image
.
fromarray
(
image
.
astype
(
'uint8'
)
)
if
self
.
transform
is
not
None
:
image
=
self
.
transform
(
image
)
...
...
python/paddle/vision/datasets/mnist.py
浏览文件 @
b89b4e32
...
...
@@ -24,17 +24,7 @@ import paddle
from
paddle.io
import
Dataset
from
paddle.dataset.common
import
_check_exists_and_download
__all__
=
[
"MNIST"
]
URL_PREFIX
=
'https://dataset.bj.bcebos.com/mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5
=
'9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL
=
URL_PREFIX
+
't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5
=
'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL
=
URL_PREFIX
+
'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5
=
'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL
=
URL_PREFIX
+
'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5
=
'd53e105ee54ea40749a09fcbcd1e9432'
__all__
=
[
"MNIST"
,
"FashionMNIST"
]
class
MNIST
(
Dataset
):
...
...
@@ -70,6 +60,16 @@ class MNIST(Dataset):
print(sample[0].size, sample[1])
"""
NAME
=
'mnist'
URL_PREFIX
=
'https://dataset.bj.bcebos.com/mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5
=
'9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL
=
URL_PREFIX
+
't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5
=
'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL
=
URL_PREFIX
+
'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5
=
'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL
=
URL_PREFIX
+
'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5
=
'd53e105ee54ea40749a09fcbcd1e9432'
def
__init__
(
self
,
image_path
=
None
,
...
...
@@ -93,18 +93,18 @@ class MNIST(Dataset):
self
.
image_path
=
image_path
if
self
.
image_path
is
None
:
assert
download
,
"image_path is not set and downloading automatically is disabled"
image_url
=
TRAIN_IMAGE_URL
if
mode
==
'train'
else
TEST_IMAGE_URL
image_md5
=
TRAIN_IMAGE_MD5
if
mode
==
'train'
else
TEST_IMAGE_MD5
image_url
=
self
.
TRAIN_IMAGE_URL
if
mode
==
'train'
else
self
.
TEST_IMAGE_URL
image_md5
=
self
.
TRAIN_IMAGE_MD5
if
mode
==
'train'
else
self
.
TEST_IMAGE_MD5
self
.
image_path
=
_check_exists_and_download
(
image_path
,
image_url
,
image_md5
,
'mnist'
,
download
)
image_path
,
image_url
,
image_md5
,
self
.
NAME
,
download
)
self
.
label_path
=
label_path
if
self
.
label_path
is
None
:
assert
download
,
"label_path is not set and downloading automatically is disabled"
label_url
=
TRAIN_LABEL_URL
if
self
.
mode
==
'train'
else
TEST_LABEL_URL
label_md5
=
TRAIN_LABEL_MD5
if
self
.
mode
==
'train'
else
TEST_LABEL_MD5
label_url
=
self
.
TRAIN_LABEL_URL
if
self
.
mode
==
'train'
else
self
.
TEST_LABEL_URL
label_md5
=
self
.
TRAIN_LABEL_MD5
if
self
.
mode
==
'train'
else
self
.
TEST_LABEL_MD5
self
.
label_path
=
_check_exists_and_download
(
label_path
,
label_url
,
label_md5
,
'mnist'
,
download
)
label_path
,
label_url
,
label_md5
,
self
.
NAME
,
download
)
self
.
transform
=
transform
...
...
@@ -175,3 +175,48 @@ class MNIST(Dataset):
def
__len__
(
self
):
return
len
(
self
.
labels
)
class
FashionMNIST
(
MNIST
):
"""
Implementation `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ dataset.
Args:
image_path(str): path to image file, can be set None if
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns:
Dataset: Fashion-MNIST Dataset.
Examples:
.. code-block:: python
from paddle.vision.datasets import FashionMNIST
mnist = FashionMNIST(mode='test')
for i in range(len(mnist)):
sample = mnist[i]
print(sample[0].size, sample[1])
"""
NAME
=
'fashion-mnist'
URL_PREFIX
=
'https://dataset.bj.bcebos.com/fashion_mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5
=
'bef4ecab320f06d8554ea6380940ec79'
TEST_LABEL_URL
=
URL_PREFIX
+
't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5
=
'bb300cfdad3c16e7a12a480ee83cd310'
TRAIN_IMAGE_URL
=
URL_PREFIX
+
'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5
=
'8d4fb7e6c68d591d4c3dfef9ec88bf0d'
TRAIN_LABEL_URL
=
URL_PREFIX
+
'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5
=
'25c81989df183df01b3e8a0aad5dffbe'
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录