提交 f4eac873 编写于 作者: W wukesong

modify lenet

上级 474016c6
...@@ -96,8 +96,8 @@ from mindspore import context ...@@ -96,8 +96,8 @@ from mindspore import context
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell): ...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell):
#define the operator required #define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
...@@ -98,8 +98,8 @@ from mindspore import context ...@@ -98,8 +98,8 @@ from mindspore import context
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
...@@ -237,7 +237,6 @@ class LeNet5(nn.Cell): ...@@ -237,7 +237,6 @@ class LeNet5(nn.Cell):
#define the operator required #define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
...@@ -142,7 +142,6 @@ class LeNet5(nn.Cell): ...@@ -142,7 +142,6 @@ class LeNet5(nn.Cell):
# define the operator required # define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
...@@ -192,8 +191,8 @@ def test_net(args, network, model, mnist_path): ...@@ -192,8 +191,8 @@ def test_net(args, network, model, mnist_path):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册