Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CoolBran
d2l-zh
提交
8e7dff30
D
d2l-zh
项目概览
CoolBran
/
d2l-zh
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
d2l-zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
8e7dff30
编写于
2月 27, 2018
作者:
A
Aston Zhang
提交者:
GitHub
2月 27, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #212 from astonzhang/nin
Add transform=None in utils.DataLoader
上级
f111a6c2
f1203616
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
5 deletion
+6
-5
utils.py
utils.py
+6
-5
未找到文件。
utils.py
浏览文件 @
8e7dff30
...
@@ -16,7 +16,7 @@ class DataLoader(object):
...
@@ -16,7 +16,7 @@ class DataLoader(object):
time. But the limits are 1) all examples in dataset have the same shape, 2)
time. But the limits are 1) all examples in dataset have the same shape, 2)
data transfomer needs to process multiple examples at each time
data transfomer needs to process multiple examples at each time
"""
"""
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
,
transform
):
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
,
transform
=
None
):
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
...
@@ -47,7 +47,7 @@ class DataLoader(object):
...
@@ -47,7 +47,7 @@ class DataLoader(object):
def
load_data_fashion_mnist
(
batch_size
,
resize
=
None
,
root
=
"~/.mxnet/datasets/fashion-mnist"
):
def
load_data_fashion_mnist
(
batch_size
,
resize
=
None
,
root
=
"~/.mxnet/datasets/fashion-mnist"
):
"""download the fashion mnist dataest and then load into memory"""
"""download the fashion mnist dataest and then load into memory"""
def
transform_mnist
(
data
,
label
):
def
transform_mnist
(
data
,
label
):
#
transform a batch of examples
#
Transform a batch of examples.
if
resize
:
if
resize
:
n
=
data
.
shape
[
0
]
n
=
data
.
shape
[
0
]
new_data
=
nd
.
zeros
((
n
,
resize
,
resize
,
data
.
shape
[
3
]))
new_data
=
nd
.
zeros
((
n
,
resize
,
resize
,
data
.
shape
[
3
]))
...
@@ -56,11 +56,12 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas
...
@@ -56,11 +56,12 @@ def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fas
data
=
new_data
data
=
new_data
# change data from batch x height x width x channel to batch x channel x height x width
# change data from batch x height x width x channel to batch x channel x height x width
return
nd
.
transpose
(
data
.
astype
(
'float32'
),
(
0
,
3
,
1
,
2
))
/
255
,
label
.
astype
(
'float32'
)
return
nd
.
transpose
(
data
.
astype
(
'float32'
),
(
0
,
3
,
1
,
2
))
/
255
,
label
.
astype
(
'float32'
)
mnist_train
=
gluon
.
data
.
vision
.
FashionMNIST
(
root
=
root
,
train
=
True
,
transform
=
None
)
mnist_train
=
gluon
.
data
.
vision
.
FashionMNIST
(
root
=
root
,
train
=
True
,
transform
=
None
)
mnist_test
=
gluon
.
data
.
vision
.
FashionMNIST
(
root
=
root
,
train
=
False
,
transform
=
None
)
mnist_test
=
gluon
.
data
.
vision
.
FashionMNIST
(
root
=
root
,
train
=
False
,
transform
=
None
)
train_data
=
DataLoader
(
mnist_train
,
batch_size
,
shuffle
=
True
,
transform
=
transform_mnist
)
# Transform later to avoid memory explosion.
test_data
=
DataLoader
(
mnist_test
,
batch_size
,
shuffle
=
False
,
transform
=
transform_mnist
)
train_data
=
DataLoader
(
mnist_train
,
batch_size
,
shuffle
=
True
,
transform
=
transform_mnist
)
test_data
=
DataLoader
(
mnist_test
,
batch_size
,
shuffle
=
False
,
transform
=
transform_mnist
)
return
(
train_data
,
test_data
)
return
(
train_data
,
test_data
)
def
try_gpu
():
def
try_gpu
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录