From f4eac8737808dfac0e07854d69e7d9f8196a3fb8 Mon Sep 17 00:00:00 2001 From: wukesong Date: Tue, 28 Apr 2020 11:50:24 +0800 Subject: [PATCH] modify lenet --- tutorials/source_en/quick_start/quick_start.md | 5 ++--- tutorials/source_zh_cn/quick_start/quick_start.md | 5 ++--- tutorials/tutorial_code/lenet.py | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tutorials/source_en/quick_start/quick_start.md b/tutorials/source_en/quick_start/quick_start.md index f0d1b7f6..c14bacee 100644 --- a/tutorials/source_en/quick_start/quick_start.md +++ b/tutorials/source_en/quick_start/quick_start.md @@ -96,8 +96,8 @@ from mindspore import context if __name__ == "__main__": parser = argparse.ArgumentParser(description='MindSpore LeNet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) @@ -235,7 +235,6 @@ class LeNet5(nn.Cell): #define the operator required def __init__(self): super(LeNet5, self).__init__() - self.batch_size = 32 self.conv1 = conv(1, 6, 5) self.conv2 = conv(6, 16, 5) self.fc1 = fc_with_initialize(16 * 5 * 5, 120) diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index c00b41ef..86c57b82 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -98,8 +98,8 @@ from mindspore import context if __name__ == "__main__": parser = argparse.ArgumentParser(description='MindSpore LeNet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) @@ -237,7 +237,6 @@ class LeNet5(nn.Cell): #define the operator required def __init__(self): super(LeNet5, self).__init__() - self.batch_size = 32 self.conv1 = conv(1, 6, 5) self.conv2 = conv(6, 16, 5) self.fc1 = fc_with_initialize(16 * 5 * 5, 120) diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index f3899de1..5287c880 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -142,7 +142,6 @@ class LeNet5(nn.Cell): # define the operator required def __init__(self): super(LeNet5, self).__init__() - self.batch_size = 32 self.conv1 = conv(1, 6, 5) self.conv2 = conv(6, 16, 5) self.fc1 = fc_with_initialize(16 * 5 * 5, 120) @@ -192,8 +191,8 @@ def test_net(args, network, model, mnist_path): if __name__ == "__main__": parser = argparse.ArgumentParser(description='MindSpore LeNet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: CPU)') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) -- GitLab