未验证 提交 8e7dff30 编写于 作者: A Aston Zhang 提交者: GitHub

Merge pull request #212 from astonzhang/nin

Add transform=None in utils.DataLoader
......@@ -16,7 +16,7 @@ class DataLoader(object):
time. But the limits are 1) all examples in dataset have the same shape, 2)
data transfomer needs to process multiple examples at each time
"""
def __init__(self, dataset, batch_size, shuffle, transform):
def __init__(self, dataset, batch_size, shuffle, transform=None):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
......@@ -47,7 +47,7 @@ class DataLoader(object):
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
"""download the fashion mnist dataest and then load into memory"""
def transform_mnist(data, label):
# transform a batch of examples
# Transform a batch of examples.
if resize:
n = data.shape[0]
new_data = nd.zeros((n, resize, resize, data.shape[3]))
......@@ -56,11 +56,12 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas
data = new_data
# change data from batch x height x width x channel to batch x channel x height x width
return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist)
# Transform later to avoid memory explosion.
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform=transform_mnist)
return (train_data, test_data)
def try_gpu():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册