Commit c3976c83 authored by SunGaofeng

move reader and metrics out of models

Parent 9ba1db46
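In outline, this commit decouples data readers and metrics from the model classes: config keys are unified on num_classes, per-phase settings (batch size, file list, crop sizes, reader threads) move into the [TRAIN]/[VALID]/[TEST] sections, and train.py now builds readers and metrics through the standalone get_reader/get_metrics factories instead of calling model.reader()/model.metrics(). A minimal sketch of the resulting call pattern; the config path and the empty extra-args dict are illustrative, and the actual train.py below passes vars(args) to merge_configs:

from config import parse_config, merge_configs  # exposed by the repo's config module
from datareader import get_reader
from metrics import get_metrics

config = parse_config('configs/tsn.txt')            # path is illustrative
train_config = merge_configs(config, 'train', {})   # train.py passes vars(args) here

# Readers and metrics are looked up by model name and phase; they are no
# longer created by the model object itself.
train_reader = get_reader('TSN', 'train', train_config)
train_metrics = get_metrics('TSN', 'train', train_config)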
......@@ -8,7 +8,8 @@ feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
seg_num = 100
cluster_nums = [32, 32]
class_num = 3862
num_classes = 3862
topk = 20
[TRAIN]
epoch = 5
......
......@@ -8,7 +8,8 @@ feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
embedding_size = 512
lstm_size = 1024
class_num = 3862
num_classes = 3862
topk = 20
[TRAIN]
epoch = 10
......
[MODEL]
name = "NEXTVLAD"
num_classes = 3862
topk = 20
video_feature_size = 1024
audio_feature_size = 128
cluster_size = 128
......
......@@ -40,15 +40,13 @@ class FeatureReader(DataReader):
def __init__(self, name, phase, cfg):
self.name = name
self.phase = phase
self.num_classes = cfg['num_classes']
self.num_classes = cfg.MODEL.num_classes
# set batch size and file list
self.batch_size = cfg['batch_size']
self.filelist = cfg['list']
if 'eigen_file' in cfg.keys():
self.eigen_file = cfg['eigen_file']
if 'seg_num' in cfg.keys():
self.seg_num = cfg['seg_num']
self.batch_size = cfg[phase.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist']
self.eigen_file = cfg.MODEL.get('eigen_file', None)
self.seg_num = cfg.MODEL.get('seg_num', None)
def create_reader(self):
fl = open(self.filelist).readlines()
......
......@@ -56,22 +56,22 @@ class KineticsReader(DataReader):
def __init__(self, name, phase, cfg):
self.name = name
self.phase = phase
self.format = cfg['format']
self.num_classes = cfg['num_classes']
self.seg_num = cfg['seg_num']
self.seglen = cfg['seglen']
self.short_size = cfg['short_size']
self.target_size = cfg['target_size']
self.num_reader_threads = cfg['num_reader_threads']
self.buf_size = cfg['buf_size']
self.img_mean = np.array(cfg['image_mean']).reshape(
self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg[phase.upper()]['short_size']
self.target_size = cfg[phase.upper()]['target_size']
self.num_reader_threads = cfg[phase.upper()]['num_reader_threads']
self.buf_size = cfg[phase.upper()]['buf_size']
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg['image_std']).reshape(
self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32)
# set batch size and file list
self.batch_size = cfg['batch_size']
self.filelist = cfg['list']
self.batch_size = cfg[phase.upper()]['batch_size']
self.filelist = cfg[phase.upper()]['filelist']
def create_reader(self):
_reader = _reader_creator(self.filelist, self.phase, seg_num=self.seg_num, seglen = self.seglen, \
......
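The access pattern above mixes attribute-style lookups for model-level settings (cfg.MODEL.seg_num) with section lookups keyed on the upper-cased phase (cfg[phase.upper()]['batch_size']). A minimal stand-in for the parsed config illustrates this; the repo's real parse_config presumably returns a similar dict/attribute hybrid, and every value below is illustrative:

class AttrDict(dict):
    # minimal attribute access on top of dict; the repo's config class is assumed richer
    def __getattr__(self, key):
        return self[key]

cfg = AttrDict(
    MODEL=AttrDict(num_classes=400, seg_num=3, seglen=1, format='pkl',
                   image_mean=[0.485, 0.456, 0.406],
                   image_std=[0.229, 0.224, 0.225]),
    TRAIN=AttrDict(short_size=256, target_size=224, batch_size=256,
                   filelist='data/train.list', num_reader_threads=12,
                   buf_size=1024),
)

phase = 'train'
assert cfg.MODEL.seg_num == 3                   # model-level, attribute access
assert cfg[phase.upper()]['batch_size'] == 256  # phase-level, section access
assert cfg.MODEL.get('seg_num', None) == 3      # optional keys via dict.get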
......@@ -20,7 +20,6 @@ import numpy as np
import cv2
import logging
from . import nonlocal_video_io
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
......@@ -51,43 +50,91 @@ class NonlocalReader(DataReader):
def create_reader(self):
cfg = self.cfg
assert cfg['num_reader_threads'] >=1, \
"number of reader threads({}) should be a positive integer".format(cfg['num_reader_threads'])
if cfg['num_reader_threads'] == 1:
phase = self.phase
num_reader_threads = cfg[phase.upper()]['num_reader_threads']
assert num_reader_threads >=1, \
"number of reader threads({}) should be a positive integer".format(num_reader_threads)
if num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
dataset_args = {}
dataset_args['image_mean'] = cfg['image_mean']
dataset_args['image_std'] = cfg['image_std']
dataset_args['crop_size'] = cfg['crop_size']
dataset_args['sample_rate'] = cfg['sample_rate']
dataset_args['video_length'] = cfg['video_length']
dataset_args['min_size'] = cfg['jitter_scales'][0]
dataset_args['max_size'] = cfg['jitter_scales'][1]
dataset_args['num_reader_threads'] = cfg['num_reader_threads']
dataset_args['image_mean'] = cfg.MODEL.image_mean
dataset_args['image_std'] = cfg.MODEL.image_std
dataset_args['crop_size'] = cfg[phase.upper()]['crop_size']
dataset_args['sample_rate'] = cfg[phase.upper()]['sample_rate']
dataset_args['video_length'] = cfg[phase.upper()]['video_length']
dataset_args['min_size'] = cfg[phase.upper()]['jitter_scales'][0]
dataset_args['max_size'] = cfg[phase.upper()]['jitter_scales'][1]
dataset_args['num_reader_threads'] = num_reader_threads
filelist = cfg[phase.upper()]['list']
batch_size = cfg[phase.upper()]['batch_size']
if self.phase == 'train':
sample_times = 1
return reader_func(cfg['list'], cfg['batch_size'], sample_times,
True, True, **dataset_args)
return reader_func(filelist, batch_size, sample_times, True, True,
**dataset_args)
elif self.phase == 'valid':
sample_times = 1
return reader_func(cfg['list'], cfg['batch_size'], sample_times,
False, False, **dataset_args)
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
elif self.phase == 'test':
sample_times = cfg['num_test_clips']
if cfg['use_multi_crop'] == 1:
sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3)
if cfg['use_multi_crop'] == 2:
if cfg['TEST']['use_multi_crop'] == 2:
sample_times = int(sample_times / 6)
return reader_func(cfg['list'], cfg['batch_size'], sample_times,
False, False, **dataset_args)
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
else:
logger.info('Not implemented')
raise
def video_fast_get_frame(video_path,
sampling_rate=1,
length=64,
start_frm=-1,
sample_times=1):
cap = cv2.VideoCapture(video_path)
frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
sampledFrames = []
# n_frame < sample area
video_output = np.ndarray(shape=[length, height, width, 3], dtype=np.uint8)
use_start_frm = start_frm
if start_frm < 0:
if (frame_cnt - length * sampling_rate > 0):
use_start_frm = random.randint(0,
frame_cnt - length * sampling_rate)
else:
use_start_frm = 0
else:
frame_gaps = float(frame_cnt) / float(sample_times)
use_start_frm = int(frame_gaps * start_frm) % frame_cnt
for i in range(frame_cnt):
ret, frame = cap.read()
# maybe first frame is empty
if ret == False:
continue
img = frame[:, :, ::-1]
sampledFrames.append(img)
for idx in range(length):
i = use_start_frm + idx * sampling_rate
i = i % len(sampledFrames)
video_output[idx] = sampledFrames[i]
cap.release()
return video_output
def apply_resize(rgbdata, min_size, max_size):
length, height, width, channel = rgbdata.shape
ratio = 1.0
......@@ -177,7 +224,7 @@ def make_reader(filelist, batch_size, sample_times, is_training, shuffle,
label = np.array([label]).astype(np.int64)
# 1, get rgb data for fixed length of frames
try:
rgbdata = nonlocal_video_io.video_fast_get_frame(fn, \
rgbdata = video_fast_get_frame(fn, \
sampling_rate = dataset_args['sample_rate'], length = dataset_args['video_length'], \
start_frm = start_frm, sample_times = in_sample_times)
except:
......@@ -244,7 +291,7 @@ def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle,
label = np.array([label]).astype(np.int64)
# 1, get rgb data for fixed length of frames
try:
rgbdata = nonlocal_video_io.video_fast_get_frame(fn, \
rgbdata = video_fast_get_frame(fn, \
sampling_rate = dataset_args['sample_rate'], length = dataset_args['video_length'], \
start_frm = start_frm, sample_times = in_sample_times)
except:
......
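With the nonlocal_video_io module deleted below, video_fast_get_frame now lives directly in the nonlocal reader. A short usage sketch, assuming the reader module is importable as datareader.nonlocal_reader and using an illustrative video path:

import numpy as np
from datareader.nonlocal_reader import video_fast_get_frame  # module path assumed

# Decode a clip of 32 frames, taking every 2nd frame; start_frm=-1 picks a
# random start offset when the video is long enough. Frames come back as a
# uint8 array of shape [length, height, width, 3] in RGB order.
clip = video_fast_get_frame('data/example.avi', sampling_rate=2,
                            length=32, start_frm=-1)
assert clip.dtype == np.uint8 and clip.shape[0] == 32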
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import cv2
import numpy as np
import random
def video_fast_get_frame(video_path,
sampling_rate=1,
length=64,
start_frm=-1,
sample_times=1):
cap = cv2.VideoCapture(video_path)
frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
sampledFrames = []
# n_frame < sample area
video_output = np.ndarray(shape=[length, height, width, 3], dtype=np.uint8)
use_start_frm = start_frm
if start_frm < 0:
if (frame_cnt - length * sampling_rate > 0):
use_start_frm = random.randint(0,
frame_cnt - length * sampling_rate)
else:
use_start_frm = 0
else:
frame_gaps = float(frame_cnt) / float(sample_times)
use_start_frm = int(frame_gaps * start_frm) % frame_cnt
for i in range(frame_cnt):
ret, frame = cap.read()
# maybe first frame is empty
if ret == False:
continue
img = frame[:, :, ::-1]
sampledFrames.append(img)
for idx in range(length):
i = use_start_frm + idx * sampling_rate
i = i % len(sampledFrames)
video_output[idx] = sampledFrames[i]
cap.release()
return video_output
......@@ -70,6 +70,6 @@ def regist_reader(name, reader):
reader_zoo.regist(name, reader)
def get_reader(name, mode='train', **cfg):
def get_reader(name, mode, cfg):
reader_model = reader_zoo.get(name, mode, cfg)
return reader_model.create_reader()
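get_reader now takes the phase and the full parsed config positionally instead of unpacked keyword arguments. Registering and fetching a reader under the new interface looks roughly like this; the import paths for regist_reader and KineticsReader are assumed, and train_config comes from parse_config/merge_configs as in the train.py hunk further down:

from datareader import get_reader
from datareader.reader_utils import regist_reader      # module path assumed
from datareader.kinetics_reader import KineticsReader  # module path assumed

regist_reader("TSN", KineticsReader)
# The zoo instantiates KineticsReader("TSN", "train", train_config) and
# returns its create_reader() generator.
train_reader = get_reader("TSN", "train", train_config)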
......@@ -22,13 +22,13 @@ import logging
import numpy as np
from metrics.youtube8m import eval_util as youtube8m_metrics
from metrics.kinetics import accuracy_metrics as kinetics_metrics
from metrics.non_local import nonlocal_test_metrics as nonlocal_test_metrics
from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics
logger = logging.getLogger(__name__)
class Metrics(object):
def __init__(self, name, mode, **metrics_args):
def __init__(self, name, mode, metrics_args):
"""Not implemented"""
pass
......@@ -50,12 +50,11 @@ class Metrics(object):
class Youtube8mMetrics(Metrics):
def __init__(self, name, mode, **metrics_args):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.metrics_args = metrics_args
self.num_classes = metrics_args['num_classes']
self.topk = metrics_args['topk']
self.num_classes = metrics_args['MODEL']['num_classes']
self.topk = metrics_args['MODEL']['topk']
self.calculator = youtube8m_metrics.EvaluationMetrics(self.num_classes,
self.topk)
......@@ -82,12 +81,10 @@ class Youtube8mMetrics(Metrics):
class Kinetics400Metrics(Metrics):
def __init__(self, name, mode, **metrics_args):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.metrics_args = metrics_args
self.calculator = kinetics_metrics.MetricsCalculator(name,
mode.lower())
self.calculator = kinetics_metrics.MetricsCalculator(name, mode.lower())
def calculate_and_log_out(self, loss, pred, label, info=''):
if loss is not None:
......@@ -114,14 +111,19 @@ class Kinetics400Metrics(Metrics):
self.calculator.reset()
class NonlocalMetrics(Metrics):
def __init__(self, name, mode, **metrics_args):
class MulticropMetrics(Metrics):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.metrics_args = metrics_args
if mode == 'test':
self.calculator = nonlocal_test_metrics.MetricsCalculator(
name, mode.lower(), **metrics_args)
args = {}
args['num_test_clips'] = metrics_args.TEST.num_test_clips
args['dataset_size'] = metrics_args.TEST.dataset_size
args['filename_gt'] = metrics_args.TEST.filename_gt
args['checkpoint_dir'] = metrics_args.TEST.checkpoint_dir
args['num_classes'] = metrics_args.MODEL.num_classes
self.calculator = multicrop_test_metrics.MetricsCalculator(
name, mode.lower(), **args)
else:
self.calculator = kinetics_metrics.MetricsCalculator(name,
mode.lower())
......@@ -166,10 +168,10 @@ class MetricsZoo(object):
type(metrics))
self.metrics_zoo[name] = metrics
def get(self, name, mode, **cfg):
def get(self, name, mode, cfg):
for k, v in self.metrics_zoo.items():
if k == name:
return v(name, mode, **cfg)
return v(name, mode, cfg)
raise MetricsNotFoundError(name, self.metrics_zoo.keys())
......@@ -181,8 +183,8 @@ def regist_metrics(name, metrics):
metrics_zoo.regist(name, metrics)
def get_metrics(name, mode='train', **cfg):
return metrics_zoo.get(name, mode, **cfg)
def get_metrics(name, mode, cfg):
return metrics_zoo.get(name, mode, cfg)
regist_metrics("NEXTVLAD", Youtube8mMetrics)
......@@ -191,4 +193,4 @@ regist_metrics("ATTENTIONCLUSTER", Youtube8mMetrics)
regist_metrics("TSN", Kinetics400Metrics)
regist_metrics("TSM", Kinetics400Metrics)
regist_metrics("STNET", Kinetics400Metrics)
regist_metrics("NONLOCAL", NonlocalMetrics)
regist_metrics("NONLOCAL", MulticropMetrics)
......@@ -34,7 +34,7 @@ class AttentionCluster(ModelBase):
self.feature_dims = self.cfg.MODEL.feature_dims
self.cluster_nums = self.cfg.MODEL.cluster_nums
self.seg_num = self.cfg.MODEL.seg_num
self.class_num = self.cfg.MODEL.class_num
self.class_num = self.cfg.MODEL.num_classes
self.drop_rate = self.cfg.MODEL.drop_rate
# get mode configs
......
......@@ -154,13 +154,13 @@ class STNET(ModelBase):
return {}
def load_pretrain_params(self, exe, pretrain, prog):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor()
......
......@@ -21,8 +21,10 @@ import numpy as np
import paddle.fluid as fluid
from tools.train_utils import train_with_pyreader, train_without_pyreader
from config import *
import models
from config import *
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
......@@ -104,10 +106,8 @@ def train(args):
config = parse_config(args.config)
train_config = merge_configs(config, 'train', vars(args))
valid_config = merge_configs(config, 'valid', vars(args))
train_model = models.get_model(
args.model_name, train_config, mode='train')
valid_model = models.get_model(
args.model_name, valid_config, mode='valid')
train_model = models.get_model(args.model_name, train_config, mode='train')
valid_model = models.get_model(args.model_name, valid_config, mode='valid')
# build model
startup = fluid.Program()
......@@ -141,7 +141,7 @@ def train(args):
valid_feeds = valid_model.feeds()
valid_outputs = valid_model.outputs()
valid_loss = valid_model.loss()
valid_metrics = valid_model.metrics()
#valid_metrics = valid_model.metrics()
valid_pyreader = valid_model.pyreader()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......@@ -165,16 +165,16 @@ def train(args):
main_program=valid_prog)
# get reader
# train_reader = get_reader(train_config)
# valid_reader = get_reader(valid_config)
train_reader = train_model.reader()
valid_reader = valid_model.reader()
train_reader = get_reader(args.model_name.upper(), 'train', train_config)
valid_reader = get_reader(args.model_name.upper(), 'valid', valid_config)
#train_reader = train_model.reader()
#valid_reader = valid_model.reader()
# get metrics
# train_metrics = get_metrics(train_config)
# valid_metrics = get_metrics(valid_config)
train_metrics = train_model.metrics()
train_metrics = train_model.metrics()
train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)
#train_metrics = train_model.metrics()
#train_metrics = train_model.metrics()
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name]
......