diff --git a/utils.py b/utils.py index 325bc3147a7338a02c4ead5fa644f757ea1c9a5c..2c90040a00f6b419bbafec19b23c145e2b178aaa 100644 --- a/utils.py +++ b/utils.py @@ -49,7 +49,7 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas for i in range(n): new_data[i] = image.imresize(data[i], resize, resize) data = new_data - # change data from batch x height x weight x channel to batch x channel x height x weight + # change data from batch x height x width x channel to batch x channel x height x width return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32') mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=transform_mnist) mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=transform_mnist)