提交 ebb8f1ee 编写于 作者: S SunGaofeng

fix test and infer for new reader and metrics

上级 c5e356c9
...@@ -17,7 +17,7 @@ learning_rate = 0.001 ...@@ -17,7 +17,7 @@ learning_rate = 0.001
pretrain_base = None pretrain_base = None
batch_size = 160 batch_size = 160
use_gpu = True use_gpu = True
gpu_num = 4 num_gpus = 4
filelist = "data/youtube8m/train.list" filelist = "data/youtube8m/train.list"
[VALID] [VALID]
......
...@@ -21,7 +21,7 @@ num_samples = 5000000 ...@@ -21,7 +21,7 @@ num_samples = 5000000
pretrain_base = None pretrain_base = None
batch_size = 160 batch_size = 160
use_gpu = True use_gpu = True
gpu_num = 4 num_gpus = 4
filelist = "data/youtube8m/train.list" filelist = "data/youtube8m/train.list"
[VALID] [VALID]
......
...@@ -56,12 +56,12 @@ class KineticsReader(DataReader): ...@@ -56,12 +56,12 @@ class KineticsReader(DataReader):
def __init__(self, name, phase, cfg): def __init__(self, name, phase, cfg):
self.name = name self.name = name
self.phase = phase self.phase = phase
self.format = cfg.MODEL.format #cfg['format'] self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes #cfg['num_classes'] self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.segnum #['seg_num'] self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen #['seglen'] self.seglen = cfg.MODEL.seglen
self.short_size = cfg[phase.upper()]['short_size'] # ['short_size'] self.short_size = cfg[phase.upper()]['short_size']
self.target_size = cfg[phase.upper()]['target_size'] #['target_size'] self.target_size = cfg[phase.upper()]['target_size']
self.num_reader_threads = cfg[phase.upper()]['num_reader_threads'] self.num_reader_threads = cfg[phase.upper()]['num_reader_threads']
self.buf_size = cfg[phase.upper()]['buf_size'] self.buf_size = cfg[phase.upper()]['buf_size']
......
...@@ -26,6 +26,7 @@ import paddle.fluid as fluid ...@@ -26,6 +26,7 @@ import paddle.fluid as fluid
from config import * from config import *
import models import models
from datareader import get_reader
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout) logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
...@@ -81,8 +82,7 @@ def infer(args): ...@@ -81,8 +82,7 @@ def infer(args):
# parse config # parse config
config = parse_config(args.config) config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args)) infer_config = merge_configs(config, 'infer', vars(args))
infer_model = models.get_model( infer_model = models.get_model(args.model_name, infer_config, mode='infer')
args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False) infer_model.build_input(use_pyreader=False)
infer_model.build_model() infer_model.build_model()
...@@ -97,7 +97,8 @@ def infer(args): ...@@ -97,7 +97,8 @@ def infer(args):
logger.error("[INFER] --filelist unset.") logger.error("[INFER] --filelist unset.")
return return
assert os.path.exists(args.filelist), "{} not exist.".format(args.filelist) assert os.path.exists(args.filelist), "{} not exist.".format(args.filelist)
infer_reader = infer_model.reader()
infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
if args.weights: if args.weights:
assert os.path.exists( assert os.path.exists(
......
python infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \ python3 infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--filelist=./data/youtube8m/infer.list \ --filelist=./data/youtube8m/infer.list \
--weights=./checkpoints/AttentionCluster_epoch0 \ --weights=./checkpoints/AttentionCluster_epoch0 \
--save-dir="./save" --save-dir="./save"
python test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \ python3 test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--log-interval=5 --weights=./checkpoints/AttentionCluster_epoch0 --log-interval=5 --weights=./checkpoints/AttentionCluster_epoch0
python train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \ python3 train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \
--valid-interval=1 --save-interval=1 --log-interval=10 --valid-interval=1 --save-interval=1 --log-interval=10
...@@ -22,6 +22,8 @@ import paddle.fluid as fluid ...@@ -22,6 +22,8 @@ import paddle.fluid as fluid
from config import * from config import *
import models import models
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = [] logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
...@@ -68,8 +70,7 @@ def test(args): ...@@ -68,8 +70,7 @@ def test(args):
test_config = merge_configs(config, 'test', vars(args)) test_config = merge_configs(config, 'test', vars(args))
# build model # build model
test_model = models.get_model( test_model = models.get_model(args.model_name, test_config, mode='test')
args.model_name, test_config, mode='test')
test_model.build_input(use_pyreader=False) test_model.build_input(use_pyreader=False)
test_model.build_model() test_model.build_model()
test_feeds = test_model.feeds() test_feeds = test_model.feeds()
...@@ -90,8 +91,8 @@ def test(args): ...@@ -90,8 +91,8 @@ def test(args):
fluid.io.load_vars(exe, weights, predicate=if_exist) fluid.io.load_vars(exe, weights, predicate=if_exist)
# get reader and metrics # get reader and metrics
test_reader = test_model.reader() test_reader = get_reader(args.model_name.upper(), 'test', test_config)
test_metrics = test_model.metrics() test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)
test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
fetch_list = [loss.name] + [x.name fetch_list = [loss.name] + [x.name
......
...@@ -165,16 +165,19 @@ def train(args): ...@@ -165,16 +165,19 @@ def train(args):
main_program=valid_prog) main_program=valid_prog)
# get reader # get reader
bs_denominator = 1
if (not args.no_use_pyreader) and args.use_gpu:
bs_denominator = train_config.TRAIN.num_gpus
train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
bs_denominator)
valid_config.VALID.batch_size = int(valid_config.VALID.batch_size /
bs_denominator)
train_reader = get_reader(args.model_name.upper(), 'train', train_config) train_reader = get_reader(args.model_name.upper(), 'train', train_config)
valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config) valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)
#train_reader = train_model.reader()
#valid_reader = valid_model.reader()
# get metrics # get metrics
train_metrics = get_metrics(args.model_name.upper(), 'train', train_config) train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config) valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)
#train_metrics = train_model.metrics()
#train_metrics = train_model.metrics()
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name] ] + [train_feeds[-1].name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册