Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
course
提交
90d94759
C
course
项目概览
MindSpore
/
course
通知
4
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
course
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
90d94759
编写于
5月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3 简化experiment_1的训练代码,去掉部分高级功能,展示简介高效的效果
Merge pull request !3 from dyonghan/experiment_1
上级
3b5e0052
cf5770a6
变更
3
展开全部
显示空白变更内容
内联
并排
Showing
3 changed file
with
35 addition
and
192 deletion
+35
-192
.gitignore
.gitignore
+1
-0
experiment_1/1-LeNet5_MNIST.ipynb
experiment_1/1-LeNet5_MNIST.ipynb
+24
-161
experiment_1/main.py
experiment_1/main.py
+10
-31
未找到文件。
.gitignore
浏览文件 @
90d94759
...
@@ -134,3 +134,4 @@ dmypy.json
...
@@ -134,3 +134,4 @@ dmypy.json
# IDE
# IDE
.idea/
.idea/
.vscode/
experiment_1/1-LeNet5_MNIST.ipynb
浏览文件 @
90d94759
此差异已折叠。
点击以展开。
experiment_1/main.py
浏览文件 @
90d94759
...
@@ -3,19 +3,14 @@
...
@@ -3,19 +3,14 @@
import
os
import
os
# os.environ['DEVICE_ID'] = '0'
# os.environ['DEVICE_ID'] = '0'
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
mindspore
as
ms
import
mindspore
as
ms
import
mindspore.context
as
context
import
mindspore.context
as
context
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.dataset.transforms.vision.c_transforms
as
CV
import
mindspore.dataset.transforms.vision.c_transforms
as
CV
from
mindspore.dataset.transforms.vision
import
Inter
from
mindspore
import
nn
from
mindspore
import
nn
,
Tensor
from
mindspore.train
import
Model
from
mindspore.train
import
Model
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
from
mindspore.train.callback
import
LossMonitor
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'Ascend'
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'Ascend'
)
...
@@ -26,26 +21,16 @@ DATA_DIR_TEST = "MNIST/test" # 测试集信息
...
@@ -26,26 +21,16 @@ DATA_DIR_TEST = "MNIST/test" # 测试集信息
def
create_dataset
(
training
=
True
,
num_epoch
=
1
,
batch_size
=
32
,
resize
=
(
32
,
32
),
def
create_dataset
(
training
=
True
,
num_epoch
=
1
,
batch_size
=
32
,
resize
=
(
32
,
32
),
rescale
=
1
/
(
255
*
0.3081
),
shift
=-
0.1307
/
0.3081
,
buffer_size
=
64
):
rescale
=
1
/
(
255
*
0.3081
),
shift
=-
0.1307
/
0.3081
,
buffer_size
=
64
):
ds
=
ms
.
dataset
.
MnistDataset
(
DATA_DIR_TRAIN
if
training
else
DATA_DIR_TEST
)
ds
=
ms
.
dataset
.
MnistDataset
(
DATA_DIR_TRAIN
if
training
else
DATA_DIR_TEST
)
ds
=
ds
.
map
(
input_columns
=
"image"
,
operations
=
[
CV
.
Resize
(
resize
),
CV
.
Rescale
(
rescale
,
shift
),
CV
.
HWC2CHW
()])
# define map operations
resize_op
=
CV
.
Resize
(
resize
)
rescale_op
=
CV
.
Rescale
(
rescale
,
shift
)
hwc2chw_op
=
CV
.
HWC2CHW
()
# apply map operations on images
ds
=
ds
.
map
(
input_columns
=
"image"
,
operations
=
[
resize_op
,
rescale_op
,
hwc2chw_op
])
ds
=
ds
.
map
(
input_columns
=
"label"
,
operations
=
C
.
TypeCast
(
ms
.
int32
))
ds
=
ds
.
map
(
input_columns
=
"label"
,
operations
=
C
.
TypeCast
(
ms
.
int32
))
ds
=
ds
.
shuffle
(
buffer_size
=
buffer_size
).
batch
(
batch_size
,
drop_remainder
=
True
).
repeat
(
num_epoch
)
ds
=
ds
.
shuffle
(
buffer_size
=
buffer_size
)
ds
=
ds
.
batch
(
batch_size
,
drop_remainder
=
True
)
ds
=
ds
.
repeat
(
num_epoch
)
return
ds
return
ds
class
LeNet
(
nn
.
Cell
):
class
LeNet
5
(
nn
.
Cell
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
LeNet
,
self
).
__init__
()
super
(
LeNet
5
,
self
).
__init__
()
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
6
,
5
,
stride
=
1
,
pad_mode
=
'valid'
)
self
.
conv1
=
nn
.
Conv2d
(
1
,
6
,
5
,
stride
=
1
,
pad_mode
=
'valid'
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
,
stride
=
1
,
pad_mode
=
'valid'
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
,
stride
=
1
,
pad_mode
=
'valid'
)
...
@@ -70,26 +55,22 @@ class LeNet(nn.Cell):
...
@@ -70,26 +55,22 @@ class LeNet(nn.Cell):
return
output
return
output
LOOP_SINK
=
context
.
get_context
(
'enable_loop_sink'
)
def
test_train
(
lr
=
0.01
,
momentum
=
0.9
,
num_epoch
=
3
,
ckpt_name
=
"a_lenet"
):
def
test_train
(
lr
=
0.01
,
momentum
=
0.9
,
num_epoch
=
3
,
ckpt_name
=
"a_lenet"
):
ds_train
=
create_dataset
(
num_epoch
=
num_epoch
)
ds_train
=
create_dataset
(
num_epoch
=
num_epoch
)
ds_eval
=
create_dataset
(
training
=
False
)
ds_eval
=
create_dataset
(
training
=
False
)
steps_per_epoch
=
ds_train
.
get_dataset_size
()
net
=
LeNet
()
net
=
LeNet
5
()
loss
=
nn
.
loss
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
'mean'
)
loss
=
nn
.
loss
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
'mean'
)
opt
=
nn
.
Momentum
(
net
.
trainable_params
(),
lr
,
momentum
)
opt
=
nn
.
Momentum
(
net
.
trainable_params
(),
lr
,
momentum
)
ckpt_cfg
=
CheckpointConfig
(
save_checkpoint_steps
=
steps_per_epoch
,
keep_checkpoint_max
=
5
)
loss_cb
=
LossMonitor
(
per_print_times
=
1
)
ckpt_cb
=
ModelCheckpoint
(
prefix
=
ckpt_name
,
config
=
ckpt_cfg
)
loss_cb
=
LossMonitor
(
per_print_times
=
1
if
LOOP_SINK
else
steps_per_epoch
)
model
=
Model
(
net
,
loss
,
opt
,
metrics
=
{
'acc'
,
'loss'
})
model
=
Model
(
net
,
loss
,
opt
,
metrics
=
{
'acc'
,
'loss'
})
model
.
train
(
num_epoch
,
ds_train
,
callbacks
=
[
ckpt_cb
,
loss_cb
],
dataset_sink_mode
=
True
)
model
.
train
(
num_epoch
,
ds_train
,
callbacks
=
[
loss_cb
]
)
metrics
=
model
.
eval
(
ds_eval
)
metrics
=
model
.
eval
(
ds_eval
)
print
(
'Metrics:'
,
metrics
)
print
(
'Metrics:'
,
metrics
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
argparse
import
argparse
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
@@ -101,6 +82,4 @@ if __name__ == "__main__":
...
@@ -101,6 +82,4 @@ if __name__ == "__main__":
import
moxing
as
mox
import
moxing
as
mox
mox
.
file
.
copy_parallel
(
src_url
=
args
.
data_url
,
dst_url
=
'MNIST/'
)
mox
.
file
.
copy_parallel
(
src_url
=
args
.
data_url
,
dst_url
=
'MNIST/'
)
os
.
system
(
'rm -f *.ckpt *.ir *.meta'
)
# 清理旧的运行文件
test_train
()
test_train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录