提交 17f96e95 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!17 change bool argument in argparse

Merge pull request !17 from panbingao/master
......@@ -21,6 +21,7 @@ import os
import random
import argparse
import numpy as np
from resnet import resnet50
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.ops.functional as F
......@@ -37,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from resnet import resnet50
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
......@@ -45,10 +45,10 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distributei.')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute.')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
help='implement phase, set to train or test')
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
......@@ -112,7 +112,7 @@ def create_dataset(repeat_num=1, training=True):
return ds
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
if args_opt.mode == 'train' and args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
......@@ -124,7 +124,8 @@ if __name__ == '__main__':
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
if args_opt.do_train:
if args_opt.mode == 'train': # train
print("============== Starting Training ==============")
dataset = create_dataset()
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
......@@ -133,7 +134,8 @@ if __name__ == '__main__':
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
if args_opt.do_eval:
if args_opt.mode == 'test': # test
print("============== Starting Testing ==============")
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册