提交 17f96e95 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!17 change bool argument in argparse

Merge pull request !17 from panbingao/master
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import random import random
import argparse import argparse
import numpy as np import numpy as np
from resnet import resnet50
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.ops.functional as F import mindspore.ops.functional as F
...@@ -37,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net ...@@ -37,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from resnet import resnet50
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
de.config.set_seed(1) de.config.set_seed(1)
...@@ -45,10 +45,10 @@ de.config.set_seed(1) ...@@ -45,10 +45,10 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification.') parser = argparse.ArgumentParser(description='Image classification.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distributei.') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute.')
parser.add_argument('--device_num', type=int, default=1, help='Device num.') parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') help='implement phase, set to train or test')
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.') parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.') parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--num_classes', type=int, default=10, help='Num classes.') parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
...@@ -112,7 +112,7 @@ def create_dataset(repeat_num=1, training=True): ...@@ -112,7 +112,7 @@ def create_dataset(repeat_num=1, training=True):
return ds return ds
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if args_opt.mode == 'train' and args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140]) auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init() init()
...@@ -124,7 +124,8 @@ if __name__ == '__main__': ...@@ -124,7 +124,8 @@ if __name__ == '__main__':
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'}) model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
if args_opt.do_train: if args_opt.mode == 'train': # train
print("============== Starting Training ==============")
dataset = create_dataset() dataset = create_dataset()
batch_num = dataset.get_dataset_size() batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
...@@ -133,7 +134,8 @@ if __name__ == '__main__': ...@@ -133,7 +134,8 @@ if __name__ == '__main__':
loss_cb = LossMonitor() loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb]) model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
if args_opt.do_eval: if args_opt.mode == 'test': # test
print("============== Starting Testing ==============")
if args_opt.checkpoint_path: if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册