提交 8cdf4919 编写于 作者: L lvmingfu

modify code formats in synchronization_training_and_evaluation

上级 1eec716b
...@@ -230,22 +230,22 @@ ...@@ -230,22 +230,22 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import matplotlib.pyplot as plt\n",
"from mindspore.train.callback import Callback\n", "from mindspore.train.callback import Callback\n",
"\n", "\n",
"class EvalCallBack(Callback):\n", "class EvalCallBack(Callback):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch):\n", " def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):\n",
" self.model = model\n", " self.model = model\n",
" self.eval_dataset = eval_dataset\n", " self.eval_dataset = eval_dataset\n",
" self.eval_per_epoch = eval_per_epoch\n", " self.eval_per_epoch = eval_per_epoch\n",
" self.epoch_per_eval = epoch_per_eval\n",
" \n", " \n",
" def epoch_end(self, run_context):\n", " def epoch_end(self, run_context):\n",
" cb_param = run_context.original_args()\n", " cb_param = run_context.original_args()\n",
" cur_epoch = cb_param.cur_epoch_num\n", " cur_epoch = cb_param.cur_epoch_num\n",
" if cur_epoch % self.eval_per_epoch == 0:\n", " if cur_epoch % self.eval_per_epoch == 0:\n",
" acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True)\n", " acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)\n",
" epoch_per_eval[\"epoch\"].append(cur_epoch)\n", " self.epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
" epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n", " self.epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" print(acc)\n" " print(acc)\n"
] ]
}, },
...@@ -351,7 +351,6 @@ ...@@ -351,7 +351,6 @@
} }
], ],
"source": [ "source": [
"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
"from mindspore.train import Model\n", "from mindspore.train import Model\n",
"from mindspore import context\n", "from mindspore import context\n",
...@@ -368,21 +367,21 @@ ...@@ -368,21 +367,21 @@
" repeat_size = 1\n", " repeat_size = 1\n",
" network = LeNet5()\n", " network = LeNet5()\n",
" \n", " \n",
" train_data = create_dataset(train_data_path,repeat_size = repeat_size)\n", " train_data = create_dataset(train_data_path, repeat_size=repeat_size)\n",
" eval_data = create_dataset(eval_data_path,repeat_size = repeat_size)\n", " eval_data = create_dataset(eval_data_path, repeat_size=repeat_size)\n",
" \n", " \n",
" # define the loss function\n", " # define the loss function\n",
" net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", " net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" # define the optimizer\n", " # define the optimizer\n",
" net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n", " net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
" config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n", " config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",directory=ckpt_save_dir, config=config_ck)\n", " ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", directory=ckpt_save_dir, config=config_ck)\n",
" model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n", " model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
" \n", " \n",
" epoch_per_eval = {\"epoch\":[],\"acc\":[]}\n", " epoch_per_eval = {\"epoch\": [], \"acc\": []}\n",
" eval_cb = EvalCallBack(model,eval_data,eval_per_epoch)\n", " eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)\n",
" \n", " \n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],\n", " model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],\n",
" dataset_sink_mode=True)" " dataset_sink_mode=True)"
] ]
}, },
...@@ -441,11 +440,13 @@ ...@@ -441,11 +440,13 @@
} }
], ],
"source": [ "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def eval_show(epoch_per_eval):\n", "def eval_show(epoch_per_eval):\n",
" plt.xlabel(\"epoch number\")\n", " plt.xlabel(\"epoch number\")\n",
" plt.ylabel(\"Model accuracy\")\n", " plt.ylabel(\"Model accuracy\")\n",
" plt.title(\"Model accuracy variation chart\")\n", " plt.title(\"Model accuracy variation chart\")\n",
" plt.plot(epoch_per_eval[\"epoch\"],epoch_per_eval[\"acc\"],\"red\")\n", " plt.plot(epoch_per_eval[\"epoch\"], epoch_per_eval[\"acc\"], \"red\")\n",
" plt.show()\n", " plt.show()\n",
" \n", " \n",
"eval_show(epoch_per_eval)" "eval_show(epoch_per_eval)"
......
# 同步训练和验证模型 # 同步训练和验证模型
`Ascend` `GPU` `CPU` `初级` `中级` `高级` `模型导出` `模型训练`
<!-- TOC --> <!-- TOC -->
...@@ -41,25 +43,25 @@ ...@@ -41,25 +43,25 @@
- `model`:即是MindSpore中的`Model`函数。 - `model`:即是MindSpore中的`Model`函数。
- `eval_dataset`:验证数据集。 - `eval_dataset`:验证数据集。
- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch":[],"acc":[]}` - `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch": [], "acc": []}`
```python ```python
import matplotlib.pyplot as plt
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
class EvalCallBack(Callback): class EvalCallBack(Callback):
def __init__(self, model, eval_dataset, eval_per_epoch): def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
self.model = model self.model = model
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.eval_per_epoch = eval_per_epoch self.eval_per_epoch = eval_per_epoch
self.epoch_per_eval = epoch_per_eval
def epoch_end(self, run_context): def epoch_end(self, run_context):
cb_param = run_context.original_args() cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_per_epoch == 0: if cur_epoch % self.eval_per_epoch == 0:
acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True) acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)
epoch_per_eval["epoch"].append(cur_epoch) self.epoch_per_eval["epoch"].append(cur_epoch)
epoch_per_eval["acc"].append(acc["Accuracy"]) self.epoch_per_eval["acc"].append(acc["Accuracy"])
print(acc) print(acc)
``` ```
...@@ -79,12 +81,10 @@ class EvalCallBack(Callback): ...@@ -79,12 +81,10 @@ class EvalCallBack(Callback):
- `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。 - `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。
```python ```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model from mindspore.train import Model
from mindspore import context from mindspore import context
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
if __name__ == "__main__": if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
...@@ -98,10 +98,10 @@ if __name__ == "__main__": ...@@ -98,10 +98,10 @@ if __name__ == "__main__":
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
epoch_per_eval = {"epoch":[],"acc":[]} epoch_per_eval = {"epoch": [], "acc": []}
eval_cb = EvalCallBack(model,eval_data,eval_per_epoch) eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)
model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb], model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],
dataset_sink_mode=True) dataset_sink_mode=True)
``` ```
...@@ -152,11 +152,13 @@ lenet_ckpt ...@@ -152,11 +152,13 @@ lenet_ckpt
```python ```python
import matplotlib.pyplot as plt
def eval_show(epoch_per_eval): def eval_show(epoch_per_eval):
plt.xlabel("epoch number") plt.xlabel("epoch number")
plt.ylabel("Model accuracy") plt.ylabel("Model accuracy")
plt.title("Model accuracy variation chart") plt.title("Model accuracy variation chart")
plt.plot(epoch_per_eval["epoch"],epoch_per_eval["acc"],"red") plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red")
plt.show() plt.show()
eval_show(epoch_per_eval) eval_show(epoch_per_eval)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册