提交 8cdf4919 编写于 作者: L lvmingfu

modify code formats in synchronization_training_and_evaluation

上级 1eec716b
......@@ -230,22 +230,22 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from mindspore.train.callback import Callback\n",
"\n",
"class EvalCallBack(Callback):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):\n",
" self.model = model\n",
" self.eval_dataset = eval_dataset\n",
" self.eval_per_epoch = eval_per_epoch\n",
" self.epoch_per_eval = epoch_per_eval\n",
" \n",
" def epoch_end(self, run_context):\n",
" cb_param = run_context.original_args()\n",
" cur_epoch = cb_param.cur_epoch_num\n",
" if cur_epoch % self.eval_per_epoch == 0:\n",
" acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True)\n",
" epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
" epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)\n",
" self.epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
" self.epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" print(acc)\n"
]
},
......@@ -351,7 +351,6 @@
}
],
"source": [
"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
"from mindspore.train import Model\n",
"from mindspore import context\n",
......@@ -368,21 +367,21 @@
" repeat_size = 1\n",
" network = LeNet5()\n",
" \n",
" train_data = create_dataset(train_data_path,repeat_size = repeat_size)\n",
" eval_data = create_dataset(eval_data_path,repeat_size = repeat_size)\n",
" train_data = create_dataset(train_data_path, repeat_size=repeat_size)\n",
" eval_data = create_dataset(eval_data_path, repeat_size=repeat_size)\n",
" \n",
" # define the loss function\n",
" net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" # define the optimizer\n",
" net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
" config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",directory=ckpt_save_dir, config=config_ck)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", directory=ckpt_save_dir, config=config_ck)\n",
" model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
" \n",
" epoch_per_eval = {\"epoch\":[],\"acc\":[]}\n",
" eval_cb = EvalCallBack(model,eval_data,eval_per_epoch)\n",
" epoch_per_eval = {\"epoch\": [], \"acc\": []}\n",
" eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)\n",
" \n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],\n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],\n",
" dataset_sink_mode=True)"
]
},
......@@ -441,11 +440,13 @@
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def eval_show(epoch_per_eval):\n",
" plt.xlabel(\"epoch number\")\n",
" plt.ylabel(\"Model accuracy\")\n",
" plt.title(\"Model accuracy variation chart\")\n",
" plt.plot(epoch_per_eval[\"epoch\"],epoch_per_eval[\"acc\"],\"red\")\n",
" plt.plot(epoch_per_eval[\"epoch\"], epoch_per_eval[\"acc\"], \"red\")\n",
" plt.show()\n",
" \n",
"eval_show(epoch_per_eval)"
......
# 同步训练和验证模型
# 同步训练和验证模型
`Ascend` `GPU` `CPU` `初级` `中级` `高级` `模型导出` `模型训练`
<!-- TOC -->
......@@ -41,25 +43,25 @@
- `model`:即是MindSpore中的`Model`函数。
- `eval_dataset`:验证数据集。
- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch":[],"acc":[]}`
- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch": [], "acc": []}`
```python
import matplotlib.pyplot as plt
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
def __init__(self, model, eval_dataset, eval_per_epoch):
def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
self.model = model
self.eval_dataset = eval_dataset
self.eval_per_epoch = eval_per_epoch
self.epoch_per_eval = epoch_per_eval
def epoch_end(self, run_context):
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_per_epoch == 0:
acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True)
epoch_per_eval["epoch"].append(cur_epoch)
epoch_per_eval["acc"].append(acc["Accuracy"])
acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)
self.epoch_per_eval["epoch"].append(cur_epoch)
self.epoch_per_eval["acc"].append(acc["Accuracy"])
print(acc)
```
......@@ -79,12 +81,10 @@ class EvalCallBack(Callback):
- `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。
```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
......@@ -98,10 +98,10 @@ if __name__ == "__main__":
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
epoch_per_eval = {"epoch":[],"acc":[]}
eval_cb = EvalCallBack(model,eval_data,eval_per_epoch)
epoch_per_eval = {"epoch": [], "acc": []}
eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)
model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],
model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],
dataset_sink_mode=True)
```
......@@ -152,11 +152,13 @@ lenet_ckpt
```python
import matplotlib.pyplot as plt
def eval_show(epoch_per_eval):
plt.xlabel("epoch number")
plt.ylabel("Model accuracy")
plt.title("Model accuracy variation chart")
plt.plot(epoch_per_eval["epoch"],epoch_per_eval["acc"],"red")
plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red")
plt.show()
eval_show(epoch_per_eval)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册