提交 7fd20bd9 编写于 作者: W wukesong

modify lenet dataset_sink_mode=True

上级 5b02f1ae
...@@ -100,6 +100,7 @@ if __name__ == "__main__": ...@@ -100,6 +100,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
... ...
``` ```
...@@ -338,12 +339,12 @@ from mindspore.train.callback import LossMonitor ...@@ -338,12 +339,12 @@ from mindspore.train.callback import LossMonitor
from mindspore.train import Model from mindspore.train import Model
... ...
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""define the training method""" """define the training method"""
print("============== Starting Training ==============") print("============== Starting Training ==============")
#load training dataset #load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) # train model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode) # train
... ...
if __name__ == "__main__": if __name__ == "__main__":
...@@ -353,7 +354,7 @@ if __name__ == "__main__": ...@@ -353,7 +354,7 @@ if __name__ == "__main__":
mnist_path = "./MNIST_Data" mnist_path = "./MNIST_Data"
repeat_size = epoch_size repeat_size = epoch_size
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
... ...
``` ```
In the preceding information: In the preceding information:
......
...@@ -102,6 +102,7 @@ if __name__ == "__main__": ...@@ -102,6 +102,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
... ...
``` ```
...@@ -339,12 +340,12 @@ from mindspore.train.callback import LossMonitor ...@@ -339,12 +340,12 @@ from mindspore.train.callback import LossMonitor
from mindspore.train import Model from mindspore.train import Model
... ...
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""define the training method""" """define the training method"""
print("============== Starting Training ==============") print("============== Starting Training ==============")
#load training dataset #load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
... ...
if __name__ == "__main__": if __name__ == "__main__":
...@@ -354,7 +355,7 @@ if __name__ == "__main__": ...@@ -354,7 +355,7 @@ if __name__ == "__main__":
mnist_path = "./MNIST_Data" mnist_path = "./MNIST_Data"
repeat_size = epoch_size repeat_size = epoch_size
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
... ...
``` ```
其中, 其中,
......
...@@ -169,12 +169,12 @@ class LeNet5(nn.Cell): ...@@ -169,12 +169,12 @@ class LeNet5(nn.Cell):
return x return x
def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb): def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, sink_mode):
"""Define the training method.""" """Define the training method."""
print("============== Starting Training ==============") print("============== Starting Training ==============")
# load training dataset # load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size) ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
def test_net(args, network, model, mnist_path): def test_net(args, network, model, mnist_path):
...@@ -196,6 +196,7 @@ if __name__ == "__main__": ...@@ -196,6 +196,7 @@ if __name__ == "__main__":
help='device where the code will be implemented (default: CPU)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
dataset_sink_mode = not args.device_target == "CPU"
# download mnist dataset # download mnist dataset
download_dataset() download_dataset()
# learning rate setting # learning rate setting
...@@ -216,5 +217,5 @@ if __name__ == "__main__": ...@@ -216,5 +217,5 @@ if __name__ == "__main__":
# group layers into an object with training and evaluation features # group layers into an object with training and evaluation features
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb) train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb, dataset_sink_mode)
test_net(args, network, model, mnist_path) test_net(args, network, model, mnist_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册