未验证 提交 eb36421c 编写于 作者: W whs 提交者: GitHub

Add args for train.py (#1472)

上级 19b08cfd
...@@ -24,6 +24,10 @@ add_arg('log_period', int, 1000, "Log period.") ...@@ -24,6 +24,10 @@ add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.") add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.") add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.") add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'") add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('init_model', str, None, "The init model file of directory.") add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.") add_arg('use_gpu', bool, True, "Whether use GPU to train.")
...@@ -48,10 +52,6 @@ def train(args): ...@@ -48,10 +52,6 @@ def train(args):
get_feeder_data = get_attention_feeder_data get_feeder_data = get_attention_feeder_data
num_classes = None num_classes = None
train_images = None
train_list = None
test_images = None
test_list = None
num_classes = data_reader.num_classes( num_classes = data_reader.num_classes(
) if num_classes is None else num_classes ) if num_classes is None else num_classes
data_shape = data_reader.data_shape() data_shape = data_reader.data_shape()
...@@ -62,12 +62,12 @@ def train(args): ...@@ -62,12 +62,12 @@ def train(args):
# data reader # data reader
train_reader = data_reader.train( train_reader = data_reader.train(
args.batch_size, args.batch_size,
train_images_dir=train_images, train_images_dir=args.train_images,
train_list_file=train_list, train_list_file=args.train_list,
cycle=args.total_step > 0, cycle=args.total_step > 0,
model=args.model) model=args.model)
test_reader = data_reader.test( test_reader = data_reader.test(
test_images_dir=test_images, test_list_file=test_list, model=args.model) test_images_dir=args.test_images, test_list_file=args.test_list, model=args.model)
# prepare environment # prepare environment
place = fluid.CPUPlace() place = fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册