diff --git a/tutorials/notebook/synchronization_training_and_evaluation.ipynb b/tutorials/notebook/synchronization_training_and_evaluation.ipynb index 486fd4cf2f762b193054db536c9576b5bdc5512f..236ae433c882d620ead10a0247dc321bab8122d3 100644 --- a/tutorials/notebook/synchronization_training_and_evaluation.ipynb +++ b/tutorials/notebook/synchronization_training_and_evaluation.ipynb @@ -230,22 +230,22 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", "from mindspore.train.callback import Callback\n", "\n", "class EvalCallBack(Callback):\n", - " def __init__(self, model, eval_dataset, eval_per_epoch):\n", + " def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):\n", " self.model = model\n", " self.eval_dataset = eval_dataset\n", " self.eval_per_epoch = eval_per_epoch\n", + " self.epoch_per_eval = epoch_per_eval\n", " \n", " def epoch_end(self, run_context):\n", " cb_param = run_context.original_args()\n", " cur_epoch = cb_param.cur_epoch_num\n", " if cur_epoch % self.eval_per_epoch == 0:\n", - " acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True)\n", - " epoch_per_eval[\"epoch\"].append(cur_epoch)\n", - " epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n", + " acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True)\n", + " self.epoch_per_eval[\"epoch\"].append(cur_epoch)\n", + " self.epoch_per_eval[\"acc\"].append(acc[\"Accuracy\"])\n", " print(acc)\n" ] }, @@ -351,7 +351,6 @@ } ], "source": [ - "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n", "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor\n", "from mindspore.train import Model\n", "from mindspore import context\n", @@ -368,21 +367,21 @@ " repeat_size = 1\n", " network = LeNet5()\n", " \n", - " train_data = create_dataset(train_data_path,repeat_size = repeat_size)\n", - " eval_data = create_dataset(eval_data_path,repeat_size = repeat_size)\n", + " train_data = create_dataset(train_data_path, repeat_size=repeat_size)\n", + " eval_data = create_dataset(eval_data_path, repeat_size=repeat_size)\n", " \n", " # define the loss function\n", " net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')\n", " # define the optimizer\n", " net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n", " config_ck = CheckpointConfig(save_checkpoint_steps=eval_per_epoch*1875, keep_checkpoint_max=15)\n", - " ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\",directory=ckpt_save_dir, config=config_ck)\n", + " ckpoint_cb = ModelCheckpoint(prefix=\"checkpoint_lenet\", directory=ckpt_save_dir, config=config_ck)\n", " model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n", " \n", - " epoch_per_eval = {\"epoch\":[],\"acc\":[]}\n", - " eval_cb = EvalCallBack(model,eval_data,eval_per_epoch)\n", + " epoch_per_eval = {\"epoch\": [], \"acc\": []}\n", + " eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval)\n", " \n", - " model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb],\n", + " model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb],\n", " dataset_sink_mode=True)" ] }, @@ -441,11 +440,13 @@ } ], "source": [ + "import matplotlib.pyplot as plt\n", + "\n", "def eval_show(epoch_per_eval):\n", " plt.xlabel(\"epoch number\")\n", " plt.ylabel(\"Model accuracy\")\n", " plt.title(\"Model accuracy variation chart\")\n", - " plt.plot(epoch_per_eval[\"epoch\"],epoch_per_eval[\"acc\"],\"red\")\n", + " plt.plot(epoch_per_eval[\"epoch\"], epoch_per_eval[\"acc\"], \"red\")\n", " plt.show()\n", " \n", "eval_show(epoch_per_eval)" diff --git a/tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md b/tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md index 6e6932f1894f5a6caa018e5a7684e738a323f294..773b4c4535380e34e324b64ef2abf4b429ed0b2d 100644 --- a/tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md +++ b/tutorials/source_zh_cn/advanced_use/synchronization_training_and_evaluation.md @@ -1,4 +1,6 @@ -# 同步训练和验证模型 +# 同步训练和验证模型 + +`Ascend` `GPU` `CPU` `初级` `中级` `高级` `模型导出` `模型训练` @@ -41,25 +43,25 @@ - `model`:即是MindSpore中的`Model`函数。 - `eval_dataset`:验证数据集。 -- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch":[],"acc":[]}`。 +- `epoch_per_eval`:记录验证模型的精度和相应的epoch数,其数据形式为`{"epoch": [], "acc": []}`。 ```python -import matplotlib.pyplot as plt from mindspore.train.callback import Callback class EvalCallBack(Callback): - def __init__(self, model, eval_dataset, eval_per_epoch): + def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval): self.model = model self.eval_dataset = eval_dataset self.eval_per_epoch = eval_per_epoch + self.epoch_per_eval = epoch_per_eval def epoch_end(self, run_context): cb_param = run_context.original_args() cur_epoch = cb_param.cur_epoch_num if cur_epoch % self.eval_per_epoch == 0: - acc = self.model.eval(self.eval_dataset,dataset_sink_mode = True) - epoch_per_eval["epoch"].append(cur_epoch) - epoch_per_eval["acc"].append(acc["Accuracy"]) + acc = self.model.eval(self.eval_dataset, dataset_sink_mode=True) + self.epoch_per_eval["epoch"].append(cur_epoch) + self.epoch_per_eval["acc"].append(acc["Accuracy"]) print(acc) ``` @@ -79,12 +81,10 @@ class EvalCallBack(Callback): - `epoch_per_eval`:定义收集`epoch`数和对应模型精度信息的字典。 ```python -from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train import Model from mindspore import context from mindspore.nn.metrics import Accuracy -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -98,10 +98,10 @@ if __name__ == "__main__": ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",directory=ckpt_save_dir, config=config_ck) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - epoch_per_eval = {"epoch":[],"acc":[]} - eval_cb = EvalCallBack(model,eval_data,eval_per_epoch) + epoch_per_eval = {"epoch": [], "acc": []} + eval_cb = EvalCallBack(model, eval_data, eval_per_epoch, epoch_per_eval) - model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375),eval_cb], + model.train(epoch_size, train_data, callbacks=[ckpoint_cb, LossMonitor(375), eval_cb], dataset_sink_mode=True) ``` @@ -152,11 +152,13 @@ lenet_ckpt ```python +import matplotlib.pyplot as plt + def eval_show(epoch_per_eval): plt.xlabel("epoch number") plt.ylabel("Model accuracy") plt.title("Model accuracy variation chart") - plt.plot(epoch_per_eval["epoch"],epoch_per_eval["acc"],"red") + plt.plot(epoch_per_eval["epoch"], epoch_per_eval["acc"], "red") plt.show() eval_show(epoch_per_eval)