提交 7bf6e231 编写于 作者: 哆啦壹萌's avatar 哆啦壹萌 提交者: Mu Li

make user can load existing dataset from any folder (#92)

* make user can load existing dataset from any folder

* fix a typo
上级 a64c6e73
......@@ -58,7 +58,7 @@ out = nd.Convolution(data, w, b, kernel=w.shape[2:], num_filter=w.shape[0])
print('input:', data, '\n\nweight:', w, '\n\nbias:', b, '\n\noutput:', out)
```
当输需要多通道时,每个输出通道有对应权重,然后每个通道上做卷积。
当输需要多通道时,每个输出通道有对应权重,然后每个通道上做卷积。
$$conv(data, w, b)[:,i,:,:] = conv(data, w[i,:,:,:], b[i])$$
......
......@@ -38,7 +38,7 @@ class DataLoader(object):
def __len__(self):
return len(self.dataset)//self.batch_size
def load_data_fashion_mnist(batch_size, resize=None):
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
"""download the fashion mnist dataest and then load into memory"""
def transform_mnist(data, label):
# transform a batch of examples
......@@ -50,8 +50,8 @@ def load_data_fashion_mnist(batch_size, resize=None):
data = new_data
# change data from batch x height x weight x channel to batch x channel x height x weight
return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform_mnist)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform_mnist)
mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=transform_mnist)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=transform_mnist)
train_data = DataLoader(mnist_train, batch_size, shuffle=True)
test_data = DataLoader(mnist_test, batch_size, shuffle=False)
return (train_data, test_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册