未验证 提交 f111a6c2 编写于 作者: A Aston Zhang 提交者: GitHub

Merge pull request #140 from starsdeep/fix_for_resize_to_224_when_training_AlexNet

fix for resize to 224 when training AlextNet in alexnet-gluon.md
......@@ -16,10 +16,11 @@ class DataLoader(object):
time. But the limits are 1) all examples in dataset have the same shape, 2)
data transfomer needs to process multiple examples at each time
"""
def __init__(self, dataset, batch_size, shuffle):
def __init__(self, dataset, batch_size, shuffle, transform):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.transform = transform
def __iter__(self):
data = self.dataset[:]
......@@ -33,8 +34,12 @@ class DataLoader(object):
y = nd.array(y.asnumpy()[idx])
for i in range(n//self.batch_size):
yield (X[i*self.batch_size:(i+1)*self.batch_size],
y[i*self.batch_size:(i+1)*self.batch_size])
if self.transform is not None:
yield self.transform(X[i*self.batch_size:(i+1)*self.batch_size],
y[i*self.batch_size:(i+1)*self.batch_size])
else:
yield (X[i*self.batch_size:(i+1)*self.batch_size],
y[i*self.batch_size:(i+1)*self.batch_size])
def __len__(self):
return len(self.dataset)//self.batch_size
......@@ -51,10 +56,11 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas
data = new_data
# change data from batch x height x width x channel to batch x channel x height x width
return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=transform_mnist)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=transform_mnist)
train_data = DataLoader(mnist_train, batch_size, shuffle=True)
test_data = DataLoader(mnist_test, batch_size, shuffle=False)
mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist)
return (train_data, test_data)
def try_gpu():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册