提交 ebb8f1ee 编写于 作者: S SunGaofeng

fix test and infer for new reader and metrics

上级 c5e356c9
......@@ -17,7 +17,7 @@ learning_rate = 0.001
pretrain_base = None
batch_size = 160
use_gpu = True
gpu_num = 4
num_gpus = 4
filelist = "data/youtube8m/train.list"
[VALID]
......
......@@ -21,7 +21,7 @@ num_samples = 5000000
pretrain_base = None
batch_size = 160
use_gpu = True
gpu_num = 4
num_gpus = 4
filelist = "data/youtube8m/train.list"
[VALID]
......
......@@ -56,12 +56,12 @@ class KineticsReader(DataReader):
def __init__(self, name, phase, cfg):
self.name = name
self.phase = phase
self.format = cfg.MODEL.format #cfg['format']
self.num_classes = cfg.MODEL.num_classes #cfg['num_classes']
self.seg_num = cfg.MODEL.segnum #['seg_num']
self.seglen = cfg.MODEL.seglen #['seglen']
self.short_size = cfg[phase.upper()]['short_size'] # ['short_size']
self.target_size = cfg[phase.upper()]['target_size'] #['target_size']
self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg[phase.upper()]['short_size']
self.target_size = cfg[phase.upper()]['target_size']
self.num_reader_threads = cfg[phase.upper()]['num_reader_threads']
self.buf_size = cfg[phase.upper()]['buf_size']
......
......@@ -26,6 +26,7 @@ import paddle.fluid as fluid
from config import *
import models
from datareader import get_reader
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
......@@ -81,8 +82,7 @@ def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
infer_model = models.get_model(
args.model_name, infer_config, mode='infer')
infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False)
infer_model.build_model()
......@@ -97,7 +97,8 @@ def infer(args):
logger.error("[INFER] --filelist unset.")
return
assert os.path.exists(args.filelist), "{} not exist.".format(args.filelist)
infer_reader = infer_model.reader()
infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
if args.weights:
assert os.path.exists(
......
python infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
python3 infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--filelist=./data/youtube8m/infer.list \
--weights=./checkpoints/AttentionCluster_epoch0 \
--save-dir="./save"
python test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
python3 test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--log-interval=5 --weights=./checkpoints/AttentionCluster_epoch0
python train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \
python3 train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \
--valid-interval=1 --save-interval=1 --log-interval=10
......@@ -22,6 +22,8 @@ import paddle.fluid as fluid
from config import *
import models
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
......@@ -68,8 +70,7 @@ def test(args):
test_config = merge_configs(config, 'test', vars(args))
# build model
test_model = models.get_model(
args.model_name, test_config, mode='test')
test_model = models.get_model(args.model_name, test_config, mode='test')
test_model.build_input(use_pyreader=False)
test_model.build_model()
test_feeds = test_model.feeds()
......@@ -90,8 +91,8 @@ def test(args):
fluid.io.load_vars(exe, weights, predicate=if_exist)
# get reader and metrics
test_reader = test_model.reader()
test_metrics = test_model.metrics()
test_reader = get_reader(args.model_name.upper(), 'test', test_config)
test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)
test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
fetch_list = [loss.name] + [x.name
......
......@@ -165,16 +165,19 @@ def train(args):
main_program=valid_prog)
# get reader
bs_denominator = 1
if (not args.no_use_pyreader) and args.use_gpu:
bs_denominator = train_config.TRAIN.num_gpus
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
bs_denominator)
valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
bs_denominator)
train_reader = get_reader(args.model_name.upper(), 'train', train_config)
valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)
#train_reader = train_model.reader()
#valid_reader = valid_model.reader()
# get metrics
train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)
#train_metrics = train_model.metrics()
#train_metrics = train_model.metrics()
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册