提交 a52f39a2 编写于 作者: M mir-of

fix bool argument for argparse

上级 04a16945
......@@ -12,11 +12,21 @@ from ofrecord_util import add_ofrecord_args
def get_parser(parser=None):
def str_list(x):
return x.split(',')
def int_list(x):
return list(map(int, x.split(',')))
def float_list(x):
return list(map(float, x.split(',')))
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
if parser is None:
parser = argparse.ArgumentParser("flags for cnn benchmark")
......@@ -30,10 +40,28 @@ def get_parser(parser=None):
help='nodes ip list for training, devided by ",", length >= num_nodes')
parser.add_argument("--model", type=str, default="vgg16", help="vgg16 or resnet50")
parser.add_argument("--use_fp16", type=bool, default=False, help="fp16")
parser.add_argument("--use_boxing_v2", type=bool, default=False, help="use boxing v2")
parser.add_argument("--use_new_dataloader", type=bool, default=False, help="use new dataloader")
parser.add_argument(
'--use_fp16',
type=str2bool,
nargs='?',
const=True,
help='Whether to use use fp16'
)
parser.add_argument(
'--use_boxing_v2',
type=str2bool,
nargs='?',
const=True,
help='Whether to use boxing v2'
)
parser.add_argument(
'--use_new_dataloader',
type=str2bool,
nargs='?',
const=True,
help='Whether to use new dataloader'
)
# train and validaion
parser.add_argument('--num_epochs', type=int, default=90, help='number of epochs')
parser.add_argument("--model_load_dir", type=str, default=None, help="model load directory if need")
......@@ -54,8 +82,7 @@ def get_parser(parser=None):
## snapshot
parser.add_argument("--model_save_dir", type=str,
#default="./output/model_save-{}".format(str(datetime.now().strftime("%Y%m%d%H%M%S"))),
default="./output/snapshots",
default="./output/snapshots/model_save-{}".format(str(datetime.now().strftime("%Y%m%d%H%M%S"))),
help="model save directory",
)
......@@ -67,7 +94,6 @@ def get_parser(parser=None):
default=1,
help="print loss every n iteration",
)
#add_dali_args(parser)
add_ofrecord_args(parser)
add_optimizer_args(parser)
return parser
......
......@@ -28,12 +28,9 @@ def get_train_config(args):
if args.use_boxing_v2:
train_config.use_boxing_v2(True)
train_config.prune_parallel_cast_ops(True)
train_config.train.model_update_conf(get_optimizer(args))
#if args.use_fp16:
# train_config.enable_auto_mixed_precision()
train_config.enable_inplace(True)
return train_config
......
......@@ -16,9 +16,11 @@ from util import Snapshot, Summary, InitNodes, Metric
import ofrecord_util
from job_function_util import get_train_config, get_val_config
import oneflow as flow
#import vgg_model
import alexnet_model
import vgg_model
import resnet_model
#import alexnet_model
total_device_num = args.num_nodes * args.gpu_num_per_node
......@@ -31,8 +33,8 @@ num_val_steps = int(args.num_val_examples / val_batch_size)
model_dict = {
"resnet50": resnet_model.resnet50,
#"vgg16": vgg_model.vgg16,
#"alexnet": alexnet_model.alexnet,
"vgg16": vgg_model.vgg16,
"alexnet": alexnet_model.alexnet,
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册