提交 e9a8806a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!773 synchronization_training_and_evaluation.md: Modify some error.

Merge pull request !773 from lvmingfu/master
...@@ -230,20 +230,20 @@ ...@@ -230,20 +230,20 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import matplotlib.pyplot as plt\n",
"from mindspore.train.callback import Callback\n", "from mindspore.train.callback import Callback\n",
"\n", "\n",
"class EvalCallBack(Callback):\n", "class EvalCallBack(Callback):\n",
" def __init__(self, model, eval_dataset, eval_per_epoch):\n", " def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):\n",
" self.model = model\n", " self.model = model\n",
" self.eval_dataset = eval_dataset\n", " self.eval_dataset = eval_dataset\n",
" self.eval_per_epoch = eval_per_epoch\n", " self.eval_per_epoch = eval_per_epoch\n",
" self.epoch_per_eval = epoch_per_eval\n",
" \n", " \n",
" def epoch_end(self, run_context):\n", " def epoch_end(self, run_context):\n",
" cb_param = run_context.original_args()\n", " cb_param = run_context.original_args()\n",
" cur_epoch = cb_param.cur_epoch_num\n", " cur_epoch = cb_param.cur_epoch_num\n",
" if cur_epoch % self.eval_per_epoch == 0:\n", " if cur_epoch % self.eval_per_epoch == 0:\n",
" acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True)\n", " acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)\n",
" epoch_per_eval[\"epoch\"].append(cur_epoch)\n", " epoch_per_eval[\"epoch\"].append(cur_epoch)\n",
" epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n", " epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n",
" print(acc)\n" " print(acc)\n"
...@@ -351,7 +351,6 @@ ...@@ -351,7 +351,6 @@
} }
], ],
"source": [ "source": [
"from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
"from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n",
"from mindspore.train import Model\n", "from mindspore.train import Model\n",
"from mindspore import context\n", "from mindspore import context\n",
...@@ -368,21 +367,21 @@ ...@@ -368,21 +367,21 @@
" repeat_size = 1\n", " repeat_size = 1\n",
" network = LeNet5()\n", " network = LeNet5()\n",
" \n", " \n",
" train_data = create_dataset(train_data_path,repeat_size = repeat_size)\n", " train_data = create_dataset(train_data_path, repeat_size=repeat_size)\n",
" eval_data = create_dataset(eval_data_path,repeat_size = repeat_size)\n", " eval_data = create_dataset(eval_data_path, repeat_size=repeat_size)\n",
" \n", " \n",
" # define the loss function\n", " # define the loss function\n",
" net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", " net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n",
" # define the optimizer\n", " # define the optimizer\n",
" net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n", " net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
" config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n", " config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n",
" ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",directory=ckpt_save_dir, config=config_ck)\n", " ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", directory=ckpt_save_dir, config=config_ck)\n",
" model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n", " model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
" \n", " \n",
" epoch_per_eval = {\"epoch\":[],\"acc\":[]}\n", " epoch_per_eval = {\"epoch\": [], \"acc\": []}\n",
" eval_cb = EvalCallBack(model,eval_data,eval_per_epoch)\n", " eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)\n",
" \n", " \n",
" model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],\n", " model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],\n",
" dataset_sink_mode=True)" " dataset_sink_mode=True)"
] ]
}, },
...@@ -441,11 +440,13 @@ ...@@ -441,11 +440,13 @@
} }
], ],
"source": [ "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def eval_show(epoch_per_eval):\n", "def eval_show(epoch_per_eval):\n",
" plt.xlabel(\"epoch number\")\n", " plt.xlabel(\"epoch number\")\n",
" plt.ylabel(\"Model accuracy\")\n", " plt.ylabel(\"Model accuracy\")\n",
" plt.title(\"Model accuracy variation chart\")\n", " plt.title(\"Model accuracy variation chart\")\n",
" plt.plot(epoch_per_eval[\"epoch\"],epoch_per_eval[\"acc\"],\"red\")\n", " plt.plot(epoch_per_eval[\"epoch\"], epoch_per_eval[\"acc\"], \"red\")\n",
" plt.show()\n", " plt.show()\n",
" \n", " \n",
"eval_show(epoch_per_eval)" "eval_show(epoch_per_eval)"
......
...@@ -41,23 +41,23 @@ ...@@ -41,23 +41,23 @@
- `model`:即是MindSpore中的`Model`函数。 - `model`:即是MindSpore中的`Model`函数。
- `eval_dataset`:验证数据集。 - `eval_dataset`:验证数据集。
- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch":[],"acc":[]}` - `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch": [], "acc": []}`
```python ```python
import matplotlib.pyplot as plt
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
class EvalCallBack(Callback): class EvalCallBack(Callback):
def __init__(self, model, eval_dataset, eval_per_epoch): def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
self.model = model self.model = model
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.eval_per_epoch = eval_per_epoch self.eval_per_epoch = eval_per_epoch
self.epoch_per_eval = epoch_per_eval
def epoch_end(self, run_context): def epoch_end(self, run_context):
cb_param = run_context.original_args() cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_per_epoch == 0: if cur_epoch % self.eval_per_epoch == 0:
acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True) acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)
epoch_per_eval["epoch"].append(cur_epoch) epoch_per_eval["epoch"].append(cur_epoch)
epoch_per_eval["acc"].append(acc["Accuracy"]) epoch_per_eval["acc"].append(acc["Accuracy"])
print(acc) print(acc)
...@@ -79,12 +79,10 @@ class EvalCallBack(Callback): ...@@ -79,12 +79,10 @@ class EvalCallBack(Callback):
- `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。 - `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。
```python ```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model from mindspore.train import Model
from mindspore import context from mindspore import context
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
if __name__ == "__main__": if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
...@@ -98,10 +96,10 @@ if __name__ == "__main__": ...@@ -98,10 +96,10 @@ if __name__ == "__main__":
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
epoch_per_eval = {"epoch":[],"acc":[]} epoch_per_eval = {"epoch": [], "acc": []}
eval_cb = EvalCallBack(model,eval_data,eval_per_epoch) eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)
model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb], model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],
dataset_sink_mode=True) dataset_sink_mode=True)
``` ```
...@@ -152,11 +150,13 @@ lenet_ckpt ...@@ -152,11 +150,13 @@ lenet_ckpt
```python ```python
import matplotlib.pyplot as plt
def eval_show(epoch_per_eval): def eval_show(epoch_per_eval):
plt.xlabel("epoch number") plt.xlabel("epoch number")
plt.ylabel("Model accuracy") plt.ylabel("Model accuracy")
plt.title("Model accuracy variation chart") plt.title("Model accuracy variation chart")
plt.plot(epoch_per_eval["epoch"],epoch_per_eval["acc"],"red") plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red")
plt.show() plt.show()
eval_show(epoch_per_eval) eval_show(epoch_per_eval)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册