未验证 提交 8e7dff30 编写于 作者: A Aston Zhang 提交者: GitHub

Merge pull request #212 from astonzhang/nin

Add transform=None in utils.DataLoader
...@@ -16,7 +16,7 @@ class DataLoader(object): ...@@ -16,7 +16,7 @@ class DataLoader(object):
time. But the limits are 1) all examples in dataset have the same shape, 2) time. But the limits are 1) all examples in dataset have the same shape, 2)
data transfomer needs to process multiple examples at each time data transfomer needs to process multiple examples at each time
""" """
def __init__(self, dataset, batch_size, shuffle, transform): def __init__(self, dataset, batch_size, shuffle, transform=None):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle self.shuffle = shuffle
...@@ -47,7 +47,7 @@ class DataLoader(object): ...@@ -47,7 +47,7 @@ class DataLoader(object):
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"): def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
"""download the fashion mnist dataest and then load into memory""" """download the fashion mnist dataest and then load into memory"""
def transform_mnist(data, label): def transform_mnist(data, label):
# transform a batch of examples # Transform a batch of examples.
if resize: if resize:
n = data.shape[0] n = data.shape[0]
new_data = nd.zeros((n, resize, resize, data.shape[3])) new_data = nd.zeros((n, resize, resize, data.shape[3]))
...@@ -56,11 +56,12 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas ...@@ -56,11 +56,12 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas
data = new_data data = new_data
# change data from batch x height x width x channel to batch x channel x height x width # change data from batch x height x width x channel to batch x channel x height x width
return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32') return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None) mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None) mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist) # Transform later to avoid memory explosion.
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist) train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform=transform_mnist)
return (train_data, test_data) return (train_data, test_data)
def try_gpu(): def try_gpu():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册