diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b793254da688079c5a6782f2c071f1c3d8f992d4..c3d294e60091f68d93cab244dc495e4fca2aa5a6 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -34,7 +34,6 @@ def parse_args(): parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_fp16", type=str2bool, default=False) - parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--gpu_mem", type=int, default=8000) # params for text detector diff --git a/tools/program.py b/tools/program.py index 787a59d49b9963421c99b17bd563ddc10a2a601b..4331f9d46d053c8472389d338922ae23b3d2a4cf 100755 --- a/tools/program.py +++ b/tools/program.py @@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc -def preprocess(): +def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) merge_config(FLAGS.opt) @@ -350,15 +350,17 @@ def preprocess(): device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1 - - # save_config - save_model_dir = config['Global']['save_model_dir'] - os.makedirs(save_model_dir, exist_ok=True) - with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: - yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) - - logger = get_logger( - name='root', log_file='{}/train.log'.format(save_model_dir)) + if is_train: + # save_config + save_model_dir = config['Global']['save_model_dir'] + os.makedirs(save_model_dir, exist_ok=True) + with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: + yaml.dump( + dict(config), f, default_flow_style=False, sort_keys=False) + log_file = '{}/train.log'.format(save_model_dir) + else: + log_file = None + logger = get_logger(name='root', log_file=log_file) if config['Global']['use_visualdl']: from visualdl import LogWriter vdl_writer_path = '{}/vdl/'.format(save_model_dir) diff --git a/tools/train.py b/tools/train.py index 6e44c5982ec5595c9202d83b14c058a7579c6a27..383f8d83919b054999f19be1490b92e3d90d7eee 100755 --- a/tools/train.py +++ b/tools/train.py @@ -110,6 +110,6 @@ def test_reader(config, device, logger): if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() + config, device, logger, vdl_writer = program.preprocess(is_train=True) main(config, device, logger, vdl_writer) # test_reader(config, device, logger)