提交 8dc08095 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!76 modify lenet

Merge pull request !76 from wukesong/master
......@@ -96,8 +96,8 @@ from mindspore import context
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
......@@ -235,7 +235,6 @@ class LeNet5(nn.Cell):
#define the operator required
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
......@@ -98,8 +98,8 @@ from mindspore import context
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
......@@ -235,7 +235,6 @@ class LeNet5(nn.Cell):
#define the operator required
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
......@@ -143,7 +143,6 @@ class LeNet5(nn.Cell):
# define the operator required
def __init__(self):
super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......@@ -193,8 +192,8 @@ def test_net(args, network, model, mnist_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册