diff --git a/tools/program.py b/tools/program.py index b65e91887eaa0999f0e2991fa295dd71ef68c823..72fca1482b4a74547ad77b9a0b28de111c8d7a5d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -560,12 +560,12 @@ def preprocess(is_train=False): loggers = [] - if config['Global']['use_visualdl']: + if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']: save_model_dir = config['Global']['save_model_dir'] vdl_writer_path = '{}/vdl/'.format(save_model_dir) log_writer = VDLLogger(save_model_dir) loggers.append(log_writer) - if config['Global']['use_wandb'] or 'wandb' in config: + if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: save_dir = config['Global']['save_model_dir'] wandb_writer_path = "{}/wandb".format(save_dir) if "wandb" in config: