未验证 提交 f3b4a64a 编写于 作者: Kaipeng Deng 提交者: GitHub

Fix CIFAR, MNIST and UCIHousing datasets. test=develop (#27368)

* fix CIFAR & MNIST dataset. test=develop
上级 f936adbd
...@@ -27,8 +27,10 @@ class TestCifar10Train(unittest.TestCase): ...@@ -27,8 +27,10 @@ class TestCifar10Train(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 50000) idx = np.random.randint(0, 50000)
data, label = cifar[idx] data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3072) self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
...@@ -41,8 +43,10 @@ class TestCifar10Test(unittest.TestCase): ...@@ -41,8 +43,10 @@ class TestCifar10Test(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 10000) idx = np.random.randint(0, 10000)
data, label = cifar[idx] data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3072) self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
...@@ -55,8 +59,10 @@ class TestCifar100Train(unittest.TestCase): ...@@ -55,8 +59,10 @@ class TestCifar100Train(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 50000) idx = np.random.randint(0, 50000)
data, label = cifar[idx] data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3072) self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
...@@ -69,8 +75,10 @@ class TestCifar100Test(unittest.TestCase): ...@@ -69,8 +75,10 @@ class TestCifar100Test(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 10000) idx = np.random.randint(0, 10000)
data, label = cifar[idx] data, label = cifar[idx]
self.assertTrue(len(data.shape) == 1) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3072) self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
......
...@@ -103,12 +103,14 @@ class TestMNISTTest(unittest.TestCase): ...@@ -103,12 +103,14 @@ class TestMNISTTest(unittest.TestCase):
class TestMNISTTrain(unittest.TestCase): class TestMNISTTrain(unittest.TestCase):
def test_main(self): def test_main(self):
mnist = MNIST(mode='train', chw_format=False) mnist = MNIST(mode='train')
self.assertTrue(len(mnist) == 60000) self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)): for i in range(len(mnist)):
image, label = mnist[i] image, label = mnist[i]
self.assertTrue(image.shape[0] == 784) self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
......
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import six import six
import numpy as np import numpy as np
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -88,6 +89,8 @@ class UCIHousing(Dataset): ...@@ -88,6 +89,8 @@ class UCIHousing(Dataset):
# read dataset into memory # read dataset into memory
self._load_data() self._load_data()
self.dtype = paddle.get_default_dtype()
def _load_data(self, feature_num=14, ratio=0.8): def _load_data(self, feature_num=14, ratio=0.8):
data = np.fromfile(self.data_file, sep=' ') data = np.fromfile(self.data_file, sep=' ')
data = data.reshape(data.shape[0] // feature_num, feature_num) data = data.reshape(data.shape[0] // feature_num, feature_num)
...@@ -103,7 +106,8 @@ class UCIHousing(Dataset): ...@@ -103,7 +106,8 @@ class UCIHousing(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
data = self.data[idx] data = self.data[idx]
return np.array(data[:-1]), np.array(data[-1:]) return np.array(data[:-1]).astype(self.dtype), \
np.array(data[-1:]).astype(self.dtype)
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
...@@ -139,6 +139,7 @@ class Cifar10(Dataset): ...@@ -139,6 +139,7 @@ class Cifar10(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.data[idx] image, label = self.data[idx]
image = np.reshape(image, [3, 32, 32])
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
return image, label return image, label
......
...@@ -44,8 +44,6 @@ class MNIST(Dataset): ...@@ -44,8 +44,6 @@ class MNIST(Dataset):
:attr:`download` is True. Default None :attr:`download` is True. Default None
label_path(str): path to label file, can be set None if label_path(str): path to label file, can be set None if
:attr:`download` is True. Default None :attr:`download` is True. Default None
chw_format(bool): If set True, the output shape is [1, 28, 28],
otherwise, output shape is [1, 784]. Default True.
mode(str): 'train' or 'test' mode. Default 'train'. mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True :attr:`image_path` :attr:`label_path` is not set. Default True
...@@ -70,14 +68,12 @@ class MNIST(Dataset): ...@@ -70,14 +68,12 @@ class MNIST(Dataset):
def __init__(self, def __init__(self,
image_path=None, image_path=None,
label_path=None, label_path=None,
chw_format=True,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True):
assert mode.lower() in ['train', 'test'], \ assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode) "mode should be 'train' or 'test', but got {}".format(mode)
self.mode = mode.lower() self.mode = mode.lower()
self.chw_format = chw_format
self.image_path = image_path self.image_path = image_path
if self.image_path is None: if self.image_path is None:
assert download, "image_path is not set and downloading automatically is disabled" assert download, "image_path is not set and downloading automatically is disabled"
...@@ -139,10 +135,6 @@ class MNIST(Dataset): ...@@ -139,10 +135,6 @@ class MNIST(Dataset):
cols)).astype('float32') cols)).astype('float32')
offset_img += struct.calcsize(fmt_images) offset_img += struct.calcsize(fmt_images)
images = images / 255.0
images = images * 2.0
images = images - 1.0
for i in range(buffer_size): for i in range(buffer_size):
self.images.append(images[i, :]) self.images.append(images[i, :])
self.labels.append( self.labels.append(
...@@ -150,7 +142,6 @@ class MNIST(Dataset): ...@@ -150,7 +142,6 @@ class MNIST(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx] image, label = self.images[idx], self.labels[idx]
if self.chw_format:
image = np.reshape(image, [1, 28, 28]) image = np.reshape(image, [1, 28, 28])
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册