未验证 提交 f3b4a64a 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix CIFAR MNIST UCIHousing dataset. test=develop (#27368)

* fix CIFAR & MNIST dataset. test=develop
上级 f936adbd
......@@ -27,8 +27,10 @@ class TestCifar10Train(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 9)
......@@ -41,8 +43,10 @@ class TestCifar10Test(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 9)
......@@ -55,8 +59,10 @@ class TestCifar100Train(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 99)
......@@ -69,8 +75,10 @@ class TestCifar100Test(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1)
self.assertTrue(data.shape[0] == 3072)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 99)
......
......@@ -103,12 +103,14 @@ class TestMNISTTest(unittest.TestCase):
class TestMNISTTrain(unittest.TestCase):
def test_main(self):
mnist = MNIST(mode='train', chw_format=False)
mnist = MNIST(mode='train')
self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)):
image, label = mnist[i]
self.assertTrue(image.shape[0] == 784)
self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import six
import numpy as np
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
......@@ -88,6 +89,8 @@ class UCIHousing(Dataset):
# read dataset into memory
self._load_data()
self.dtype = paddle.get_default_dtype()
def _load_data(self, feature_num=14, ratio=0.8):
data = np.fromfile(self.data_file, sep=' ')
data = data.reshape(data.shape[0] // feature_num, feature_num)
......@@ -103,7 +106,8 @@ class UCIHousing(Dataset):
def __getitem__(self, idx):
data = self.data[idx]
return np.array(data[:-1]), np.array(data[-1:])
return np.array(data[:-1]).astype(self.dtype), \
np.array(data[-1:]).astype(self.dtype)
def __len__(self):
return len(self.data)
......@@ -139,6 +139,7 @@ class Cifar10(Dataset):
def __getitem__(self, idx):
image, label = self.data[idx]
image = np.reshape(image, [3, 32, 32])
if self.transform is not None:
image = self.transform(image)
return image, label
......
......@@ -44,8 +44,6 @@ class MNIST(Dataset):
:attr:`download` is True. Default None
label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True
......@@ -70,14 +68,12 @@ class MNIST(Dataset):
def __init__(self,
image_path=None,
label_path=None,
chw_format=True,
mode='train',
transform=None,
download=True):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower()
self.chw_format = chw_format
self.image_path = image_path
if self.image_path is None:
assert download, "image_path is not set and downloading automatically is disabled"
......@@ -139,10 +135,6 @@ class MNIST(Dataset):
cols)).astype('float32')
offset_img += struct.calcsize(fmt_images)
images = images / 255.0
images = images * 2.0
images = images - 1.0
for i in range(buffer_size):
self.images.append(images[i, :])
self.labels.append(
......@@ -150,8 +142,7 @@ class MNIST(Dataset):
def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx]
if self.chw_format:
image = np.reshape(image, [1, 28, 28])
image = np.reshape(image, [1, 28, 28])
if self.transform is not None:
image = self.transform(image)
return image, label
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册