提交 8dc08095 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!76 modify lenet

Merge pull request !76 from wukesong/master
...@@ -96,8 +96,8 @@ from mindspore import context ...@@ -96,8 +96,8 @@ from mindspore import context
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell): ...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell):
#define the operator required #define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
...@@ -98,8 +98,8 @@ from mindspore import context ...@@ -98,8 +98,8 @@ from mindspore import context
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell): ...@@ -235,7 +235,6 @@ class LeNet5(nn.Cell):
#define the operator required #define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
......
...@@ -143,7 +143,6 @@ class LeNet5(nn.Cell): ...@@ -143,7 +143,6 @@ class LeNet5(nn.Cell):
# define the operator required # define the operator required
def __init__(self): def __init__(self):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.batch_size = 32
self.conv1 = conv(1, 6, 5) self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5) self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
...@@ -193,8 +192,8 @@ def test_net(args, network, model, mnist_path): ...@@ -193,8 +192,8 @@ def test_net(args, network, model, mnist_path):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: CPU)')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False) enable_mem_reuse=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册