From 7bf6e2314b2937f4f7a2857ccdefb78856a5ffe5 Mon Sep 17 00:00:00 2001 From: LinkHS <381082014@qq.com> Date: Sat, 4 Nov 2017 17:39:37 -0500 Subject: [PATCH] make user can load existing dataset from any folder (#92) * make user can load existing dataset from any folder * fix a typo --- chapter_convolutional-neural-networks/cnn-scratch.md | 2 +- utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chapter_convolutional-neural-networks/cnn-scratch.md b/chapter_convolutional-neural-networks/cnn-scratch.md index 9fe551d..013db39 100644 --- a/chapter_convolutional-neural-networks/cnn-scratch.md +++ b/chapter_convolutional-neural-networks/cnn-scratch.md @@ -58,7 +58,7 @@ out = nd.Convolution(data, w, b, kernel=w.shape[2:], num_filter=w.shape[0]) print('input:', data, '\n\nweight:', w, '\n\nbias:', b, '\n\noutput:', out) ``` -当输入需要多通道时,每个输出通道有对应权重,然后每个通道上做卷积。 +当输出需要多通道时,每个输出通道有对应权重,然后每个通道上做卷积。 $$conv(data, w, b)[:,i,:,:] = conv(data, w[i,:,:,:], b[i])$$ diff --git a/utils.py b/utils.py index db2566a..ee277d8 100644 --- a/utils.py +++ b/utils.py @@ -38,7 +38,7 @@ class DataLoader(object): def __len__(self): return len(self.dataset)//self.batch_size -def load_data_fashion_mnist(batch_size, resize=None): +def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"): """download the fashion mnist dataest and then load into memory""" def transform_mnist(data, label): # transform a batch of examples @@ -50,8 +50,8 @@ def load_data_fashion_mnist(batch_size, resize=None): data = new_data # change data from batch x height x weight x channel to batch x channel x height x weight return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32') - mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform_mnist) - mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform_mnist) + mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=transform_mnist) + mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=transform_mnist) train_data = DataLoader(mnist_train, batch_size, shuffle=True) test_data = DataLoader(mnist_test, batch_size, shuffle=False) return (train_data, test_data) -- GitLab