提交 7fd20bd9 编写于 作者: W wukesong

modify lenet dataset_sink_mode=True

上级 5b02f1ae
......@@ -100,6 +100,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
...
```
......@@ -338,12 +339,12 @@ from mindspore.train.callback import LossMonitor
from mindspore.train import Model
...
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""define the training method"""
print("============== Starting Training ==============")
#load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) # train
...
if __name__ == "__main__":
......@@ -353,7 +354,7 @@ if __name__ == "__main__":
mnist_path = "./MNIST_Data"
repeat_size = epoch_size
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
...
```
In the preceding information:
......
......@@ -102,6 +102,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
...
```
......@@ -339,12 +340,12 @@ from mindspore.train.callback import LossMonitor
from mindspore.train import Model
...
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""define the training method"""
print("============== Starting Training ==============")
#load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
...
if __name__ == "__main__":
......@@ -354,7 +355,7 @@ if __name__ == "__main__":
mnist_path = "./MNIST_Data"
repeat_size = epoch_size
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
...
```
其中,
......
......@@ -169,12 +169,12 @@ class LeNet5(nn.Cell):
return x
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""Define the training method."""
print("============== Starting Training ==============")
# load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
def test_net(args, network, model, mnist_path):
......@@ -196,6 +196,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
# download mnist dataset
download_dataset()
# learning rate setting
......@@ -216,5 +217,5 @@ if __name__ == "__main__":
# group layers into an object with training and evaluation features
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
test_net(args, network, model, mnist_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册