diff --git a/fluid/PaddleCV/ocr_recognition/train.py b/fluid/PaddleCV/ocr_recognition/train.py index 7954d23dc02c93159315e4220ec2db0289fddb44..2e294907a6bbac5f311c420ad22d51eafa972da7 100755 --- a/fluid/PaddleCV/ocr_recognition/train.py +++ b/fluid/PaddleCV/ocr_recognition/train.py @@ -24,6 +24,10 @@ add_arg('log_period', int, 1000, "Log period.") add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.") add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.") add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.") +add_arg('train_images', str, None, "The directory of images to be used for training.") +add_arg('train_list', str, None, "The list file of images to be used for training.") +add_arg('test_images', str, None, "The directory of images to be used for test.") +add_arg('test_list', str, None, "The list file of images to be used for training.") add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'") add_arg('init_model', str, None, "The init model file of directory.") add_arg('use_gpu', bool, True, "Whether use GPU to train.") @@ -48,10 +52,6 @@ def train(args): get_feeder_data = get_attention_feeder_data num_classes = None - train_images = None - train_list = None - test_images = None - test_list = None num_classes = data_reader.num_classes( ) if num_classes is None else num_classes data_shape = data_reader.data_shape() @@ -62,12 +62,12 @@ def train(args): # data reader train_reader = data_reader.train( args.batch_size, - train_images_dir=train_images, - train_list_file=train_list, + train_images_dir=args.train_images, + train_list_file=args.train_list, cycle=args.total_step > 0, model=args.model) test_reader = data_reader.test( - test_images_dir=test_images, test_list_file=test_list, model=args.model) + test_images_dir=args.test_images, test_list_file=args.test_list, model=args.model) # prepare environment place = fluid.CPUPlace()