Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f3b4a64a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3b4a64a
编写于
9月 21, 2020
作者:
K
Kaipeng Deng
提交者:
GitHub
9月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix CIFAR MNIST UCIHousing dataset. test=develop (#27368)
* fix CIFAR & MNIST dataset. test=develop
上级
f936adbd
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
27 addition
and
21 deletion
+27
-21
python/paddle/tests/test_dataset_cifar.py
python/paddle/tests/test_dataset_cifar.py
+16
-8
python/paddle/tests/test_datasets.py
python/paddle/tests/test_datasets.py
+4
-2
python/paddle/text/datasets/uci_housing.py
python/paddle/text/datasets/uci_housing.py
+5
-1
python/paddle/vision/datasets/cifar.py
python/paddle/vision/datasets/cifar.py
+1
-0
python/paddle/vision/datasets/mnist.py
python/paddle/vision/datasets/mnist.py
+1
-10
未找到文件。
python/paddle/tests/test_dataset_cifar.py
浏览文件 @
f3b4a64a
...
@@ -27,8 +27,10 @@ class TestCifar10Train(unittest.TestCase):
...
@@ -27,8 +27,10 @@ class TestCifar10Train(unittest.TestCase):
# long time, randomly check 1 sample
# long time, randomly check 1 sample
idx
=
np
.
random
.
randint
(
0
,
50000
)
idx
=
np
.
random
.
randint
(
0
,
50000
)
data
,
label
=
cifar
[
idx
]
data
,
label
=
cifar
[
idx
]
self
.
assertTrue
(
len
(
data
.
shape
)
==
1
)
self
.
assertTrue
(
len
(
data
.
shape
)
==
3
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3072
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3
)
self
.
assertTrue
(
data
.
shape
[
1
]
==
32
)
self
.
assertTrue
(
data
.
shape
[
2
]
==
32
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
...
@@ -41,8 +43,10 @@ class TestCifar10Test(unittest.TestCase):
...
@@ -41,8 +43,10 @@ class TestCifar10Test(unittest.TestCase):
# long time, randomly check 1 sample
# long time, randomly check 1 sample
idx
=
np
.
random
.
randint
(
0
,
10000
)
idx
=
np
.
random
.
randint
(
0
,
10000
)
data
,
label
=
cifar
[
idx
]
data
,
label
=
cifar
[
idx
]
self
.
assertTrue
(
len
(
data
.
shape
)
==
1
)
self
.
assertTrue
(
len
(
data
.
shape
)
==
3
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3072
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3
)
self
.
assertTrue
(
data
.
shape
[
1
]
==
32
)
self
.
assertTrue
(
data
.
shape
[
2
]
==
32
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
...
@@ -55,8 +59,10 @@ class TestCifar100Train(unittest.TestCase):
...
@@ -55,8 +59,10 @@ class TestCifar100Train(unittest.TestCase):
# long time, randomly check 1 sample
# long time, randomly check 1 sample
idx
=
np
.
random
.
randint
(
0
,
50000
)
idx
=
np
.
random
.
randint
(
0
,
50000
)
data
,
label
=
cifar
[
idx
]
data
,
label
=
cifar
[
idx
]
self
.
assertTrue
(
len
(
data
.
shape
)
==
1
)
self
.
assertTrue
(
len
(
data
.
shape
)
==
3
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3072
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3
)
self
.
assertTrue
(
data
.
shape
[
1
]
==
32
)
self
.
assertTrue
(
data
.
shape
[
2
]
==
32
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
99
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
99
)
...
@@ -69,8 +75,10 @@ class TestCifar100Test(unittest.TestCase):
...
@@ -69,8 +75,10 @@ class TestCifar100Test(unittest.TestCase):
# long time, randomly check 1 sample
# long time, randomly check 1 sample
idx
=
np
.
random
.
randint
(
0
,
10000
)
idx
=
np
.
random
.
randint
(
0
,
10000
)
data
,
label
=
cifar
[
idx
]
data
,
label
=
cifar
[
idx
]
self
.
assertTrue
(
len
(
data
.
shape
)
==
1
)
self
.
assertTrue
(
len
(
data
.
shape
)
==
3
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3072
)
self
.
assertTrue
(
data
.
shape
[
0
]
==
3
)
self
.
assertTrue
(
data
.
shape
[
1
]
==
32
)
self
.
assertTrue
(
data
.
shape
[
2
]
==
32
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
99
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
99
)
...
...
python/paddle/tests/test_datasets.py
浏览文件 @
f3b4a64a
...
@@ -103,12 +103,14 @@ class TestMNISTTest(unittest.TestCase):
...
@@ -103,12 +103,14 @@ class TestMNISTTest(unittest.TestCase):
class
TestMNISTTrain
(
unittest
.
TestCase
):
class
TestMNISTTrain
(
unittest
.
TestCase
):
def
test_main
(
self
):
def
test_main
(
self
):
mnist
=
MNIST
(
mode
=
'train'
,
chw_format
=
False
)
mnist
=
MNIST
(
mode
=
'train'
)
self
.
assertTrue
(
len
(
mnist
)
==
60000
)
self
.
assertTrue
(
len
(
mnist
)
==
60000
)
for
i
in
range
(
len
(
mnist
)):
for
i
in
range
(
len
(
mnist
)):
image
,
label
=
mnist
[
i
]
image
,
label
=
mnist
[
i
]
self
.
assertTrue
(
image
.
shape
[
0
]
==
784
)
self
.
assertTrue
(
image
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
image
.
shape
[
1
]
==
28
)
self
.
assertTrue
(
image
.
shape
[
2
]
==
28
)
self
.
assertTrue
(
label
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
label
.
shape
[
0
]
==
1
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
self
.
assertTrue
(
0
<=
int
(
label
)
<=
9
)
...
...
python/paddle/text/datasets/uci_housing.py
浏览文件 @
f3b4a64a
...
@@ -17,6 +17,7 @@ from __future__ import print_function
...
@@ -17,6 +17,7 @@ from __future__ import print_function
import
six
import
six
import
numpy
as
np
import
numpy
as
np
import
paddle
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
paddle.dataset.common
import
_check_exists_and_download
from
paddle.dataset.common
import
_check_exists_and_download
...
@@ -88,6 +89,8 @@ class UCIHousing(Dataset):
...
@@ -88,6 +89,8 @@ class UCIHousing(Dataset):
# read dataset into memory
# read dataset into memory
self
.
_load_data
()
self
.
_load_data
()
self
.
dtype
=
paddle
.
get_default_dtype
()
def
_load_data
(
self
,
feature_num
=
14
,
ratio
=
0.8
):
def
_load_data
(
self
,
feature_num
=
14
,
ratio
=
0.8
):
data
=
np
.
fromfile
(
self
.
data_file
,
sep
=
' '
)
data
=
np
.
fromfile
(
self
.
data_file
,
sep
=
' '
)
data
=
data
.
reshape
(
data
.
shape
[
0
]
//
feature_num
,
feature_num
)
data
=
data
.
reshape
(
data
.
shape
[
0
]
//
feature_num
,
feature_num
)
...
@@ -103,7 +106,8 @@ class UCIHousing(Dataset):
...
@@ -103,7 +106,8 @@ class UCIHousing(Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
data
=
self
.
data
[
idx
]
data
=
self
.
data
[
idx
]
return
np
.
array
(
data
[:
-
1
]),
np
.
array
(
data
[
-
1
:])
return
np
.
array
(
data
[:
-
1
]).
astype
(
self
.
dtype
),
\
np
.
array
(
data
[
-
1
:]).
astype
(
self
.
dtype
)
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
data
)
return
len
(
self
.
data
)
python/paddle/vision/datasets/cifar.py
浏览文件 @
f3b4a64a
...
@@ -139,6 +139,7 @@ class Cifar10(Dataset):
...
@@ -139,6 +139,7 @@ class Cifar10(Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
image
,
label
=
self
.
data
[
idx
]
image
,
label
=
self
.
data
[
idx
]
image
=
np
.
reshape
(
image
,
[
3
,
32
,
32
])
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
return
image
,
label
return
image
,
label
...
...
python/paddle/vision/datasets/mnist.py
浏览文件 @
f3b4a64a
...
@@ -44,8 +44,6 @@ class MNIST(Dataset):
...
@@ -44,8 +44,6 @@ class MNIST(Dataset):
:attr:`download` is True. Default None
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
:attr:`download` is True. Default None
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True
:attr:`image_path` :attr:`label_path` is not set. Default True
...
@@ -70,14 +68,12 @@ class MNIST(Dataset):
...
@@ -70,14 +68,12 @@ class MNIST(Dataset):
def
__init__
(
self
,
def
__init__
(
self
,
image_path
=
None
,
image_path
=
None
,
label_path
=
None
,
label_path
=
None
,
chw_format
=
True
,
mode
=
'train'
,
mode
=
'train'
,
transform
=
None
,
transform
=
None
,
download
=
True
):
download
=
True
):
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
assert
mode
.
lower
()
in
[
'train'
,
'test'
],
\
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
"mode should be 'train' or 'test', but got {}"
.
format
(
mode
)
self
.
mode
=
mode
.
lower
()
self
.
mode
=
mode
.
lower
()
self
.
chw_format
=
chw_format
self
.
image_path
=
image_path
self
.
image_path
=
image_path
if
self
.
image_path
is
None
:
if
self
.
image_path
is
None
:
assert
download
,
"image_path is not set and downloading automatically is disabled"
assert
download
,
"image_path is not set and downloading automatically is disabled"
...
@@ -139,10 +135,6 @@ class MNIST(Dataset):
...
@@ -139,10 +135,6 @@ class MNIST(Dataset):
cols
)).
astype
(
'float32'
)
cols
)).
astype
(
'float32'
)
offset_img
+=
struct
.
calcsize
(
fmt_images
)
offset_img
+=
struct
.
calcsize
(
fmt_images
)
images
=
images
/
255.0
images
=
images
*
2.0
images
=
images
-
1.0
for
i
in
range
(
buffer_size
):
for
i
in
range
(
buffer_size
):
self
.
images
.
append
(
images
[
i
,
:])
self
.
images
.
append
(
images
[
i
,
:])
self
.
labels
.
append
(
self
.
labels
.
append
(
...
@@ -150,8 +142,7 @@ class MNIST(Dataset):
...
@@ -150,8 +142,7 @@ class MNIST(Dataset):
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
image
,
label
=
self
.
images
[
idx
],
self
.
labels
[
idx
]
if
self
.
chw_format
:
image
=
np
.
reshape
(
image
,
[
1
,
28
,
28
])
image
=
np
.
reshape
(
image
,
[
1
,
28
,
28
])
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
image
=
self
.
transform
(
image
)
image
=
self
.
transform
(
image
)
return
image
,
label
return
image
,
label
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录