提交 e67f0c39 编写于 作者: D dengkaipeng

refine script and config.

上级 1b5ce43e
......@@ -61,7 +61,10 @@ def merge_configs(cfg, sec, args_dict):
return cfg
def print_configs(cfg):
import pprint
logger.info('Training with config:')
logger.info(pprint.pformat(cfg))
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
......@@ -83,8 +83,7 @@ def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
logger.info("############### infer config ###############")
print_configs(infer_config)
print_configs(infer_config, "Infer")
infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False)
infer_model.build_model()
......
python infer.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \
--filelist=./data/youtube8m/infer.list \
--filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/AttentionCluster_epoch0 \
--save_dir="./save"
python infer.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \
--filelist=./data/youtube8m/infer.list \
--filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/AttentionLSTM_epoch0 \
--save_dir="./save"
python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \
python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./dataset/youtube8m/infer.list \
--weights=./checkpoints/NEXTVLAD_epoch0 \
--save_dir="./save"
python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \
python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/STNET_epoch0 --save_dir=./save
python infer.py --model_name="TSM" --config=./configs/tsm.txt --filelist=./data/kinetics/infer.list \
python infer.py --model_name="TSM" --config=./configs/tsm.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/TSM_epoch0 --save_dir=./save
python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \
python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./dataset/kinetics/infer.list \
--log_interval=10 --weights=./checkpoints/TSN_epoch0 --save_dir=./save
......@@ -68,8 +68,7 @@ def test(args):
# parse config
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
logger.info("############### test config ###############")
print_configs(test_config)
print_configs(test_config, "Test")
# build model
test_model = models.get_model(args.model_name, test_config, mode='test')
......
......@@ -107,8 +107,7 @@ def train(args):
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
logger.info("############### train config ###############")
print_configs(train_config)
print_configs(train_config, 'Train')
train_model = models.get_model(args.model_name, train_config, mode='train')
valid_model = models.get_model(args.model_name, valid_config, mode='valid')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册