提交 a5684e58 编写于 作者: S SunGaofeng

add print_configs in infer and test

上级 3b779edd
...@@ -82,7 +82,7 @@ use_multi_crop = 1 ...@@ -82,7 +82,7 @@ use_multi_crop = 1
[INFER] [INFER]
num_reader_threads = 8 num_reader_threads = 8
batch_size = 1 batch_size = 1
filelist = 'dataset/nonlocal/inferencelist.txt' filelist = 'dataset/nonlocal/inferlist.txt'
crop_size = 256 crop_size = 256
sample_rate = 8 sample_rate = 8
video_length = 8 video_length = 8
......
...@@ -79,7 +79,7 @@ class NonlocalReader(DataReader): ...@@ -79,7 +79,7 @@ class NonlocalReader(DataReader):
sample_times = 1 sample_times = 1
return reader_func(filelist, batch_size, sample_times, False, False, return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args) **dataset_args)
elif self.mode == 'test': elif self.mode == 'test' or self.mode == 'infer':
sample_times = cfg['TEST']['num_test_clips'] sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1: if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3) sample_times = int(sample_times / 3)
......
...@@ -83,8 +83,9 @@ def infer(args): ...@@ -83,8 +83,9 @@ def infer(args):
# parse config # parse config
config = parse_config(args.config) config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args)) infer_config = merge_configs(config, 'infer', vars(args))
logger.info("############### infer config ###############")
print_configs(infer_config)
infer_model = models.get_model(args.model_name, infer_config, mode='infer') infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False) infer_model.build_input(use_pyreader=False)
infer_model.build_model() infer_model.build_model()
infer_feeds = infer_model.feeds() infer_feeds = infer_model.feeds()
...@@ -105,10 +106,8 @@ def infer(args): ...@@ -105,10 +106,8 @@ def infer(args):
# if no weight files specified, download weights from paddle # if no weight files specified, download weights from paddle
weights = args.weights or infer_model.get_weights() weights = args.weights or infer_model.get_weights()
def if_exist(var): infer_model.load_test_weights(exe, weights,
return os.path.exists(os.path.join(weights, var.name)) fluid.default_main_program(), place)
fluid.io.load_vars(exe, weights, predicate=if_exist)
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds) infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
fetch_list = [x.name for x in infer_outputs] fetch_list = [x.name for x in infer_outputs]
......
...@@ -123,7 +123,7 @@ def obtain_arc(arc_type, video_length): ...@@ -123,7 +123,7 @@ def obtain_arc(arc_type, video_length):
def create_model(data, label, cfg, is_training=True, mode='train'): def create_model(data, label, cfg, is_training=True, mode='train'):
group = cfg.RESNETS.num_groups group = cfg.RESNETS.num_groups
width_per_group = cfg.RESNETS.width_per_group width_per_group = cfg.RESNETS.width_per_group
batch_size = int(cfg.TRAIN.batch_size / cfg.NUM_GPUS) batch_size = int(cfg.TRAIN.batch_size / cfg.TRAIN.num_gpus)
logger.info('--------------- ResNet-{} {}x{}d-{}, {} ---------------'. logger.info('--------------- ResNet-{} {}x{}d-{}, {} ---------------'.
format(cfg.MODEL.depth, group, width_per_group, format(cfg.MODEL.depth, group, width_per_group,
......
python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/infer.list \ python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/inferlist.txt \
--log_interval=10 --weights=./checkpoints/NONLOCAL_epoch0 --save_dir=./save --log_interval=10 --weights=./checkpoints/NONLOCAL_epoch0 --save_dir=./save
...@@ -68,6 +68,8 @@ def test(args): ...@@ -68,6 +68,8 @@ def test(args):
# parse config # parse config
config = parse_config(args.config) config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args)) test_config = merge_configs(config, 'test', vars(args))
logger.info("############### test config ###############")
print_configs(test_config)
# build model # build model
test_model = models.get_model(args.model_name, test_config, mode='test') test_model = models.get_model(args.model_name, test_config, mode='test')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册