提交 e67f0c39 编写于 作者: D dengkaipeng

refine script and config.

上级 1b5ce43e
...@@ -61,7 +61,10 @@ def merge_configs(cfg, sec, args_dict): ...@@ -61,7 +61,10 @@ def merge_configs(cfg, sec, args_dict):
return cfg return cfg
def print_configs(cfg): def print_configs(cfg, mode):
import pprint logger.info("---------------- {:>5} Arguments ----------------".format(mode))
logger.info('Training with config:') for sec, sec_items in cfg.items():
logger.info(pprint.pformat(cfg)) logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
...@@ -83,8 +83,7 @@ def infer(args): ...@@ -83,8 +83,7 @@ def infer(args):
# parse config # parse config
config = parse_config(args.config) config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args)) infer_config = merge_configs(config, 'infer', vars(args))
logger.info("############### infer config ###############") print_configs(infer_config, "Infer")
print_configs(infer_config)
infer_model = models.get_model(args.model_name, infer_config, mode='infer') infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False) infer_model.build_input(use_pyreader=False)
infer_model.build_model() infer_model.build_model()
......
python infer.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \ python infer.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \
--filelist=./data/youtube8m/infer.list \ --filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/AttentionCluster_epoch0 \ --weights=./checkpoints/AttentionCluster_epoch0 \
--save_dir="./save" --save_dir="./save"
python infer.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \ python infer.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \
--filelist=./data/youtube8m/infer.list \ --filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/AttentionLSTM_epoch0 \ --weights=./checkpoints/AttentionLSTM_epoch0 \
--save_dir="./save" --save_dir="./save"
python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \ python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/NEXTVLAD_epoch0 \ --weights=./checkpoints/NEXTVLAD_epoch0 \
--save_dir="./save" --save_dir="./save"
python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \ python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/STNET_epoch0 --save_dir=./save --log_interval=10 --weights=./checkpoints/STNET_epoch0 --save_dir=./save
python infer.py --model_name="TSM" --config=./configs/tsm.txt --filelist=./data/kinetics/infer.list \ python infer.py --model_name="TSM" --config=./configs/tsm.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/TSM_epoch0 --save_dir=./save --log_interval=10 --weights=./checkpoints/TSM_epoch0 --save_dir=./save
python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \ python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/TSN_epoch0 --save_dir=./save --log_interval=10 --weights=./checkpoints/TSN_epoch0 --save_dir=./save
...@@ -68,8 +68,7 @@ def test(args): ...@@ -68,8 +68,7 @@ def test(args):
# parse config # parse config
config = parse_config(args.config) config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args)) test_config = merge_configs(config, 'test', vars(args))
logger.info("############### test config ###############") print_configs(test_config, "Test")
print_configs(test_config)
# build model # build model
test_model = models.get_model(args.model_name, test_config, mode='test') test_model = models.get_model(args.model_name, test_config, mode='test')
......
...@@ -107,8 +107,7 @@ def train(args): ...@@ -107,8 +107,7 @@ def train(args):
config = parse_config(args.config) config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args)) train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args)) valid_config = merge_configs(config, 'valid', vars(args))
logger.info("############### train config ###############") print_configs(train_config, 'Train')
print_configs(train_config)
train_model = models.get_model(args.model_name, train_config, mode='train') train_model = models.get_model(args.model_name, train_config, mode='train')
valid_model = models.get_model(args.model_name, valid_config, mode='valid') valid_model = models.get_model(args.model_name, valid_config, mode='valid')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册