Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
docs
提交
8cdf4919
D
docs
项目概览
MindSpore
/
docs
通知
4
Star
2
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
docs
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8cdf4919
编写于
8月 28, 2020
作者:
L
lvmingfu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify code formats in synchronization_training_and_evaluation
上级
1eec716b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
26 deletion
+29
-26
tutorials/notebook/synchronization_training_and_evaluation.ipynb
...ls/notebook/synchronization_training_and_evaluation.ipynb
+14
-13
tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md
...n/advanced_use/synchronization_training_and_evaluation.md
+15
-13
未找到文件。
tutorials/notebook/synchronization_training_and_evaluation.ipynb
浏览文件 @
8cdf4919
...
...
@@ -230,22 +230,22 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from mindspore.train.callback import Callback\n",
"\n",
"class EvalCallBack(Callback):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch
, epoch_per_eval
):\n",
" self.model = model\n",
" self.eval_dataset = eval_dataset\n",
" self.eval_per_epoch = eval_per_epoch\n",
" self.epoch_per_eval = epoch_per_eval\n",
" \n",
" def epoch_end(self, run_context):\n",
" cb_param = run_context.original_args()\n",
" cur_epoch = cb_param.cur_epoch_num\n",
" if cur_epoch % self.eval_per_epoch == 0:\n",
" acc = self.model.eval(self.eval_dataset,
dataset_sink_mode =
True)\n",
" epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
" epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" acc = self.model.eval(self.eval_dataset,
dataset_sink_mode=
True)\n",
"
self.
epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
"
self.
epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" print(acc)\n"
]
},
...
...
@@ -351,7 +351,6 @@
}
],
"source": [
"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
"from mindspore.train import Model\n",
"from mindspore import context\n",
...
...
@@ -368,21 +367,21 @@
" repeat_size = 1\n",
" network = LeNet5()\n",
" \n",
" train_data = create_dataset(train_data_path,
repeat_size =
repeat_size)\n",
" eval_data = create_dataset(eval_data_path,
repeat_size =
repeat_size)\n",
" train_data = create_dataset(train_data_path,
repeat_size=
repeat_size)\n",
" eval_data = create_dataset(eval_data_path,
repeat_size=
repeat_size)\n",
" \n",
" # define the loss function\n",
" net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" # define the optimizer\n",
" net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
" config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",directory=ckpt_save_dir, config=config_ck)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",
directory=ckpt_save_dir, config=config_ck)\n",
" model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
" \n",
" epoch_per_eval = {\"epoch\":
[],\"acc\":
[]}\n",
" eval_cb = EvalCallBack(model,
eval_data,eval_per_epoch
)\n",
" epoch_per_eval = {\"epoch\":
[], \"acc\":
[]}\n",
" eval_cb = EvalCallBack(model,
eval_data, eval_per_epoch, epoch_per_eval
)\n",
" \n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],\n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),
eval_cb],\n",
" dataset_sink_mode=True)"
]
},
...
...
@@ -441,11 +440,13 @@
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def eval_show(epoch_per_eval):\n",
" plt.xlabel(\"epoch number\")\n",
" plt.ylabel(\"Model accuracy\")\n",
" plt.title(\"Model accuracy variation chart\")\n",
" plt.plot(epoch_per_eval[\"epoch\"],
epoch_per_eval[\"acc\"],
\"red\")\n",
" plt.plot(epoch_per_eval[\"epoch\"],
epoch_per_eval[\"acc\"],
\"red\")\n",
" plt.show()\n",
" \n",
"eval_show(epoch_per_eval)"
...
...
tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md
浏览文件 @
8cdf4919
# 同步训练和验证模型
# 同步训练和验证模型
`Ascend`
`GPU`
`CPU`
`初级`
`中级`
`高级`
`模型导出`
`模型训练`
<!-- TOC -->
...
...
@@ -41,25 +43,25 @@
-
`model`
:即是MindSpore中的
`Model`
函数。
-
`eval_dataset`
:验证数据集。
-
`epoch_per_eval`
:记录验证模型的精度和相应的epoch数,其数据形式为
`{"epoch":
[],"acc":
[]}`
。
-
`epoch_per_eval`
:记录验证模型的精度和相应的epoch数,其数据形式为
`{"epoch":
[], "acc":
[]}`
。
```
python
import
matplotlib.pyplot
as
plt
from
mindspore.train.callback
import
Callback
class
EvalCallBack
(
Callback
):
def
__init__
(
self
,
model
,
eval_dataset
,
eval_per_epoch
):
def
__init__
(
self
,
model
,
eval_dataset
,
eval_per_epoch
,
epoch_per_eval
):
self
.
model
=
model
self
.
eval_dataset
=
eval_dataset
self
.
eval_per_epoch
=
eval_per_epoch
self
.
epoch_per_eval
=
epoch_per_eval
def
epoch_end
(
self
,
run_context
):
cb_param
=
run_context
.
original_args
()
cur_epoch
=
cb_param
.
cur_epoch_num
if
cur_epoch
%
self
.
eval_per_epoch
==
0
:
acc
=
self
.
model
.
eval
(
self
.
eval_dataset
,
dataset_sink_mode
=
True
)
epoch_per_eval
[
"epoch"
].
append
(
cur_epoch
)
epoch_per_eval
[
"acc"
].
append
(
acc
[
"Accuracy"
])
acc
=
self
.
model
.
eval
(
self
.
eval_dataset
,
dataset_sink_mode
=
True
)
self
.
epoch_per_eval
[
"epoch"
].
append
(
cur_epoch
)
self
.
epoch_per_eval
[
"acc"
].
append
(
acc
[
"Accuracy"
])
print
(
acc
)
```
...
...
@@ -79,12 +81,10 @@ class EvalCallBack(Callback):
-
`epoch_per_eval`
:定义收集
`epoch`
数和对应模型精度信息的字典。
```
python
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
from
mindspore.train
import
Model
from
mindspore
import
context
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.nn.loss
import
SoftmaxCrossEntropyWithLogits
if
__name__
==
"__main__"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
...
...
@@ -98,10 +98,10 @@ if __name__ == "__main__":
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
directory
=
ckpt_save_dir
,
config
=
config_ck
)
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
epoch_per_eval
=
{
"epoch"
:
[],
"acc"
:
[]}
eval_cb
=
EvalCallBack
(
model
,
eval_data
,
eval_per_epoch
)
epoch_per_eval
=
{
"epoch"
:
[],
"acc"
:
[]}
eval_cb
=
EvalCallBack
(
model
,
eval_data
,
eval_per_epoch
,
epoch_per_eval
)
model
.
train
(
epoch_size
,
train_data
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(
375
),
eval_cb
],
model
.
train
(
epoch_size
,
train_data
,
callbacks
=
[
ckpoint_cb
,
LossMonitor
(
375
),
eval_cb
],
dataset_sink_mode
=
True
)
```
...
...
@@ -152,11 +152,13 @@ lenet_ckpt
```
python
import
matplotlib.pyplot
as
plt
def
eval_show
(
epoch_per_eval
):
plt
.
xlabel
(
"epoch number"
)
plt
.
ylabel
(
"Model accuracy"
)
plt
.
title
(
"Model accuracy variation chart"
)
plt
.
plot
(
epoch_per_eval
[
"epoch"
],
epoch_per_eval
[
"acc"
],
"red"
)
plt
.
plot
(
epoch_per_eval
[
"epoch"
],
epoch_per_eval
[
"acc"
],
"red"
)
plt
.
show
()
eval_show
(
epoch_per_eval
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录