From 7fd20bd96c772b8b364e141b0e286f7f2a2bd1e4 Mon Sep 17 00:00:00 2001 From: wukesong Date: Wed, 3 Jun 2020 13:00:18 +0800 Subject: [PATCH] modify lenet dataset_sink_mode=True --- tutorials/source_en/quick_start/quick_start.md | 7 ++++--- tutorials/source_zh_cn/quick_start/quick_start.md | 7 ++++--- tutorials/tutorial_code/lenet.py | 7 ++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tutorials/source_en/quick_start/quick_start.md b/tutorials/source_en/quick_start/quick_start.md index 0ad49251..801aa59c 100644 --- a/tutorials/source_en/quick_start/quick_start.md +++ b/tutorials/source_en/quick_start/quick_start.md @@ -100,6 +100,7 @@ if __name__ == "__main__": help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + dataset_sink_mode = not args.device_target == "CPU" ... ``` @@ -338,12 +339,12 @@ from mindspore.train.callback import LossMonitor from mindspore.train import Model ... -def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): +def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode): """define the training method""" print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) # train ... if __name__ == "__main__": @@ -353,7 +354,7 @@ if __name__ == "__main__": mnist_path = "./MNIST_Data" repeat_size = epoch_size model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) + train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode) ... ``` In the preceding information: diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index f4ffebd7..0e83e675 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -102,6 +102,7 @@ if __name__ == "__main__": help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + dataset_sink_mode = not args.device_target == "CPU" ... ``` @@ -339,12 +340,12 @@ from mindspore.train.callback import LossMonitor from mindspore.train import Model ... -def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): +def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode): """define the training method""" print("============== Starting Training ==============") #load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) ... if __name__ == "__main__": @@ -354,7 +355,7 @@ if __name__ == "__main__": mnist_path = "./MNIST_Data" repeat_size = epoch_size model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) + train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode) ... ``` 其中, diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index 84d0a9c5..441f4233 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -169,12 +169,12 @@ class LeNet5(nn.Cell): return x -def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): +def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode): """Define the training method.""" print("============== Starting Training ==============") # load training dataset ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) - model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) + model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) def test_net(args, network, model, mnist_path): @@ -196,6 +196,7 @@ if __name__ == "__main__": help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + dataset_sink_mode = not args.device_target == "CPU" # download mnist dataset download_dataset() # learning rate setting @@ -216,5 +217,5 @@ if __name__ == "__main__": # group layers into an object with training and evaluation features model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) + train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode) test_net(args, network, model, mnist_path) -- GitLab