提交 a5684e58 编写于 作者: S SunGaofeng

add print_configs in infer and test

上级 3b779edd
......@@ -82,7 +82,7 @@ use_multi_crop = 1
[INFER]
num_reader_threads = 8
batch_size = 1
filelist = 'dataset/nonlocal/inferencelist.txt'
filelist = 'dataset/nonlocal/inferlist.txt'
crop_size = 256
sample_rate = 8
video_length = 8
......
......@@ -79,7 +79,7 @@ class NonlocalReader(DataReader):
sample_times = 1
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
elif self.mode == 'test':
elif self.mode == 'test' or self.mode == 'infer':
sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3)
......
......@@ -83,8 +83,9 @@ def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
logger.info("############### infer config ###############")
print_configs(infer_config)
infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False)
infer_model.build_model()
infer_feeds = infer_model.feeds()
......@@ -105,10 +106,8 @@ def infer(args):
# if no weight files specified, download weights from paddle
weights = args.weights or infer_model.get_weights()
def if_exist(var):
return os.path.exists(os.path.join(weights, var.name))
fluid.io.load_vars(exe, weights, predicate=if_exist)
infer_model.load_test_weights(exe, weights,
fluid.default_main_program(), place)
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
fetch_list = [x.name for x in infer_outputs]
......
......@@ -123,7 +123,7 @@ def obtain_arc(arc_type, video_length):
def create_model(data, label, cfg, is_training=True, mode='train'):
group = cfg.RESNETS.num_groups
width_per_group = cfg.RESNETS.width_per_group
batch_size = int(cfg.TRAIN.batch_size / cfg.NUM_GPUS)
batch_size = int(cfg.TRAIN.batch_size / cfg.TRAIN.num_gpus)
logger.info('--------------- ResNet-{} {}x{}d-{}, {} ---------------'.
format(cfg.MODEL.depth, group, width_per_group,
......
python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/infer.list \
python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/inferlist.txt \
--log_interval=10 --weights=./checkpoints/NONLOCAL_epoch0 --save_dir=./save
......@@ -68,6 +68,8 @@ def test(args):
# parse config
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
logger.info("############### test config ###############")
print_configs(test_config)
# build model
test_model = models.get_model(args.model_name, test_config, mode='test')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册