未验证 提交 362f62cc 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #1710 from SunGaofeng/video_classification

add video classification models in PaddleCV
dataset
checkpoints
output*
*.pyc
*.swp
*_result
# VideoClassification
Video Classification
To run train:
bash ./scripts/train/train_${model_name}.sh
To run test:
bash ./scripts/test/test_${model_name}.sh
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
try:
from configparser import ConfigParser
except:
from ConfigParser import ConfigParser
from utils import AttrDict
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
def parse_config(cfg_file):
parser = ConfigParser()
cfg = AttrDict()
parser.read(cfg_file)
for sec in parser.sections():
sec_dict = AttrDict()
for k, v in parser.items(sec):
try:
v = eval(v)
except:
pass
setattr(sec_dict, k, v)
setattr(cfg, sec.upper(), sec_dict)
return cfg
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
[MODEL]
name = "AttentionCluster"
dataset = "YouTube-8M"
bone_network = None
drop_rate = 0.5
feature_num = 2
feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
seg_num = 100
cluster_nums = [32, 32]
num_classes = 3862
topk = 20
[TRAIN]
epoch = 5
learning_rate = 0.001
pretrain_base = None
batch_size = 160
use_gpu = True
num_gpus = 4
filelist = "dataset/youtube8m/train.list"
[VALID]
batch_size = 160
filelist = "dataset/youtube8m/val.list"
[TEST]
batch_size = 40
filelist = "dataset/youtube8m/test.list"
[INFER]
batch_size = 1
filelist = "dataset/youtube8m/infer.list"
[MODEL]
name = "AttentionLSTM"
dataset = "YouTube-8M"
bone_nework = None
drop_rate = 0.5
feature_num = 2
feature_names = ['rgb', 'audio']
feature_dims = [1024, 128]
embedding_size = 512
lstm_size = 1024
num_classes = 3862
topk = 20
[TRAIN]
epoch = 10
learning_rate = 0.001
decay_epochs = [5]
decay_gamma = 0.1
weight_decay = 0.0008
num_samples = 5000000
pretrain_base = None
batch_size = 160
use_gpu = True
num_gpus = 4
filelist = "dataset/youtube8m/train.list"
[VALID]
batch_size = 160
filelist = "dataset/youtube8m/val.list"
[TEST]
batch_size = 40
filelist = "dataset/youtube8m/test.list"
[INFER]
batch_size = 1
filelist = "dataset/youtube8m/infer.list"
[MODEL]
name = "NEXTVLAD"
num_classes = 3862
topk = 20
video_feature_size = 1024
audio_feature_size = 128
cluster_size = 128
hidden_size = 2048
groups = 8
expansion = 2
drop_rate = 0.5
gating_reduction = 8
eigen_file = "./dataset/youtube8m/yt8m_pca/eigenvals.npy"
[TRAIN]
epoch = 6
learning_rate = 0.0002
lr_boundary_examples = 2000000
max_iter = 700000
learning_rate_decay = 0.8
l2_penalty = 1e-5
gradient_clip_norm = 1.0
use_gpu = True
num_gpus = 4
batch_size = 160
filelist = "./dataset/youtube8m/train.list"
[VALID]
batch_size = 160
filelist = "./dataset/youtube8m/val.list"
[TEST]
batch_size = 40
filelist = "./dataset/youtube8m/test.list"
[INFER]
batch_size = 1
filelist = "./dataset/youtube8m/infer.list"
[MODEL]
name = "STNET"
format = "pkl"
num_classes = 400
seg_num = 7
seglen = 5
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]
num_layers = 50
[TRAIN]
epoch = 60
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 128
num_gpus = 8
use_gpu = True
filelist = "./dataset/kinetics/train.list"
learning_rate = 0.01
learning_rate_decay = 0.1
l2_weight_decay = 1e-4
momentum = 0.9
total_videos = 224684
pretrain_base = "./dataset/pretrained/ResNet50_pretrained"
[VALID]
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 128
filelist = "./dataset/kinetics/val.list"
[TEST]
short_size = 256
target_size = 256
num_reader_threads = 12
buf_size = 1024
batch_size = 16
filelist = "./dataset/kinetics/test.list"
[INFER]
short_size = 256
target_size = 256
num_reader_threads = 12
buf_size = 1024
batch_size = 1
filelist = "./dataset/kinetics/infer.list"
[MODEL]
name = "TSN"
format = "pkl"
num_classes = 400
seg_num = 3
seglen = 1
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]
num_layers = 50
[TRAIN]
epoch = 45
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 256
use_gpu = True
num_gpus = 8
filelist = "./dataset/kinetics/train.list"
learning_rate = 0.01
learning_rate_decay = 0.1
l2_weight_decay = 1e-4
momentum = 0.9
total_videos = 224684
[VALID]
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 256
filelist = "./dataset/kinetics/val.list"
[TEST]
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 32
filelist = "./dataset/kinetics/test.list"
[INFER]
short_size = 256
target_size = 224
num_reader_threads = 12
buf_size = 1024
batch_size = 1
filelist = "./dataset/kinetics/infer.list"
from .reader_utils import regist_reader, get_reader
from .feature_reader import FeatureReader
from .kinetics_reader import KineticsReader
from .nonlocal_reader import NonlocalReader
regist_reader("ATTENTIONCLUSTER", FeatureReader)
regist_reader("NEXTVLAD", FeatureReader)
regist_reader("ATTENTIONLSTM", FeatureReader)
regist_reader("TSN", KineticsReader)
regist_reader("TSM", KineticsReader)
regist_reader("STNET", KineticsReader)
regist_reader("NONLOCAL", NonlocalReader)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import sys
from .reader_utils import DataReader
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
import numpy as np
import random
python_ver = sys.version_info
class FeatureReader(DataReader):
"""
Data reader for youtube-8M dataset, which was stored as features extracted by prior networks
This is for the three models: lstm, attention cluster, nextvlad
dataset cfg: num_classes
batch_size
list
NextVlad only: eigen_file
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.num_classes = cfg.MODEL.num_classes
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
self.eigen_file = cfg.MODEL.get('eigen_file', None)
self.seg_num = cfg.MODEL.get('seg_num', None)
def create_reader(self):
fl = open(self.filelist).readlines()
fl = [line.strip() for line in fl if line.strip() != '']
if self.mode == 'train':
random.shuffle(fl)
def reader():
batch_out = []
for filepath in fl:
if python_ver < (3, 0):
data = pickle.load(open(filepath, 'rb'))
else:
data = pickle.load(open(filepath, 'rb'), encoding='bytes')
indexes = list(range(len(data)))
if self.mode == 'train':
random.shuffle(indexes)
for i in indexes:
record = data[i]
nframes = record[b'nframes']
rgb = record[b'feature'].astype(float)
audio = record[b'audio'].astype(float)
if self.mode != 'infer':
label = record[b'label']
one_hot_label = make_one_hot(label, self.num_classes)
video = record[b'video']
rgb = rgb[0:nframes, :]
audio = audio[0:nframes, :]
rgb = dequantize(
rgb, max_quantized_value=2., min_quantized_value=-2.)
audio = dequantize(
audio, max_quantized_value=2, min_quantized_value=-2)
if self.name == 'NEXTVLAD':
# add the effect of eigen values
eigen_file = self.eigen_file
eigen_val = np.sqrt(np.load(eigen_file)
[:1024, 0]).astype(np.float32)
eigen_val = eigen_val + 1e-4
rgb = (rgb - 4. / 512) * eigen_val
if self.name == 'ATTENTIONCLUSTER':
sample_inds = generate_random_idx(rgb.shape[0],
self.seg_num)
rgb = rgb[sample_inds]
audio = audio[sample_inds]
if self.mode != 'infer':
batch_out.append((rgb, audio, one_hot_label))
else:
batch_out.append((rgb, audio, video))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def dequantize(feat_vector, max_quantized_value=2., min_quantized_value=-2.):
"""
Dequantize the feature from the byte format to the float format
"""
assert max_quantized_value > min_quantized_value
quantized_range = max_quantized_value - min_quantized_value
scalar = quantized_range / 255.0
bias = (quantized_range / 512.0) + min_quantized_value
return feat_vector * scalar + bias
def make_one_hot(label, dim=3862):
one_hot_label = np.zeros(dim)
one_hot_label = one_hot_label.astype(float)
for ind in label:
one_hot_label[int(ind)] = 1
return one_hot_label
def generate_random_idx(feature_len, seg_num):
idxs = []
stride = float(feature_len) / seg_num
for i in range(seg_num):
pos = (i + np.random.random()) * stride
idxs.append(min(feature_len - 1, int(pos)))
return idxs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import math
import random
import functools
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
import numpy as np
import paddle
from PIL import Image, ImageEnhance
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
class KineticsReader(DataReader):
"""
Data reader for kinetics dataset of two format mp4 and pkl.
1. mp4, the original format of kinetics400
2. pkl, the mp4 was decoded previously and stored as pkl
In both case, load the data, and then get the frame data in the form of numpy and label as an integer.
dataset cfg: format
num_classes
seg_num
short_size
target_size
num_reader_threads
buf_size
image_mean
image_std
batch_size
list
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.format = cfg.MODEL.format
self.num_classes = cfg.MODEL.num_classes
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg[mode.upper()]['short_size']
self.target_size = cfg[mode.upper()]['target_size']
self.num_reader_threads = cfg[mode.upper()]['num_reader_threads']
self.buf_size = cfg[mode.upper()]['buf_size']
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32)
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.filelist = cfg[mode.upper()]['filelist']
def create_reader(self):
_reader = _reader_creator(self.filelist, self.mode, seg_num=self.seg_num, seglen = self.seglen, \
short_size = self.short_size, target_size = self.target_size, \
img_mean = self.img_mean, img_std = self.img_std, \
shuffle = (self.mode == 'train'), \
num_threads = self.num_reader_threads, \
buf_size = self.buf_size, format = self.format)
def _batch_reader():
batch_out = []
for imgs, label in _reader():
if imgs is None:
continue
batch_out.append((imgs, label))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return _batch_reader
def _reader_creator(pickle_list,
mode,
seg_num,
seglen,
short_size,
target_size,
img_mean,
img_std,
shuffle=False,
num_threads=1,
buf_size=1024,
format='pkl'):
def reader():
with open(pickle_list) as flist:
lines = [line.strip() for line in flist]
if shuffle:
random.shuffle(lines)
for line in lines:
pickle_path = line.strip()
yield [pickle_path]
if format == 'pkl':
decode_func = decode_pickle
elif format == 'mp4':
decode_func = decode_mp4
else:
raise "Not implemented format {}".format(format)
mapper = functools.partial(
decode_func,
mode=mode,
seg_num=seg_num,
seglen=seglen,
short_size=short_size,
target_size=target_size,
img_mean=img_mean,
img_std=img_std)
return paddle.reader.xmap_readers(mapper, reader, num_threads, buf_size)
def decode_mp4(sample, mode, seg_num, seglen, short_size, target_size, img_mean,
img_std):
sample = sample[0].split(' ')
mp4_path = sample[0]
# when infer, we store vid as label
label = int(sample[1])
try:
imgs = mp4_loader(mp4_path, seg_num, seglen, mode)
if len(imgs) < 1:
logger.error('{} frame length {} less than 1.'.format(mp4_path,
len(imgs)))
return None, None
except:
logger.error('Error when loading {}'.format(mp4_path))
return None, None
return imgs_transform(imgs, label, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std)
def decode_pickle(sample, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
pickle_path = sample[0]
try:
if python_ver < (3, 0):
data_loaded = pickle.load(open(pickle_path, 'rb'))
else:
data_loaded = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data_loaded
if len(frames) < 1:
logger.error('{} frame length {} less than 1.'.format(pickle_path,
len(frames)))
return None, None
except:
logger.info('Error when loading {}'.format(pickle_path))
return None, None
if mode == 'train' or mode == 'valid' or mode == 'test':
ret_label = label
elif mode == 'infer':
ret_label = vid
imgs = video_loader(frames, seg_num, seglen, mode)
return imgs_transform(imgs, ret_label, mode, seg_num, seglen, \
short_size, target_size, img_mean, img_std)
def imgs_transform(imgs, label, mode, seg_num, seglen, short_size, target_size,
img_mean, img_std):
imgs = group_scale(imgs, short_size)
if mode == 'train':
imgs = group_random_crop(imgs, target_size)
imgs = group_random_flip(imgs)
else:
imgs = group_center_crop(imgs, target_size)
np_imgs = (np.array(imgs[0]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
for i in range(len(imgs) - 1):
img = (np.array(imgs[i + 1]).astype('float32').transpose(
(2, 0, 1))).reshape(1, 3, target_size, target_size) / 255
np_imgs = np.concatenate((np_imgs, img))
imgs = np_imgs
imgs -= img_mean
imgs /= img_std
imgs = np.reshape(imgs, (seg_num, seglen * 3, target_size, target_size))
return imgs, label
def group_random_crop(img_group, target_size):
w, h = img_group[0].size
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(w, h, target_size)
out_images = []
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in img_group:
if w == tw and h == th:
out_images.append(img)
else:
out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return out_images
def group_random_flip(img_group):
v = random.random()
if v < 0.5:
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
return ret
else:
return img_group
def group_center_crop(img_group, target_size):
img_crop = []
for img in img_group:
w, h = img.size
th, tw = target_size, target_size
assert (w >= target_size) and (h >= target_size), \
"image width({}) and height({}) should be larger than crop size".format(w, h, target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
img_crop.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return img_crop
def group_scale(imgs, target_size):
resized_imgs = []
for i in range(len(imgs)):
img = imgs[i]
w, h = img.size
if (w <= h and w == target_size) or (h <= w and h == target_size):
resized_imgs.append(img)
continue
if w < h:
ow = target_size
oh = int(target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
else:
oh = target_size
ow = int(target_size * 4.0 / 3.0)
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
return resized_imgs
def imageloader(buf):
if isinstance(buf, str):
img = Image.open(StringIO(buf))
else:
img = Image.open(BytesIO(buf))
return img.convert('RGB')
def video_loader(frames, nsample, seglen, mode):
videolen = len(frames)
average_dur = int(videolen / nsample)
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - seglen) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
imgbuf = frames[int(jj % videolen)]
img = imageloader(imgbuf)
imgs.append(img)
return imgs
def mp4_loader(filepath, nsample, seglen, mode):
cap = cv2.VideoCapture(filepath)
videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
average_dur = int(videolen / nsample)
sampledFrames = []
for i in range(videolen):
ret, frame = cap.read()
# maybe first frame is empty
if ret == False:
continue
img = frame[:, :, ::-1]
sampledFrames.append(img)
imgs = []
for i in range(nsample):
idx = 0
if mode == 'train':
if average_dur >= seglen:
idx = random.randint(0, average_dur - seglen)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= seglen:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + seglen):
imgbuf = sampledFrames[int(jj % videolen)]
img = Image.fromarray(imgbuf, mode='RGB')
imgs.append(img)
return imgs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import time
import multiprocessing
import numpy as np
import cv2
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
class NonlocalReader(DataReader):
"""
Data reader for kinetics dataset, which read mp4 file and decode into numpy.
This is for nonlocal neural network model.
cfg: num_classes
num_reader_threads
image_mean
image_std
batch_size
list
crop_size
sample_rate
video_length
jitter_scales
Test only cfg: num_test_clips
use_multi_crop
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.cfg = cfg
def create_reader(self):
cfg = self.cfg
mode = self.mode
num_reader_threads = cfg[mode.upper()]['num_reader_threads']
assert num_reader_threads >=1, \
"number of reader threads({}) should be a positive integer".format(num_reader_threads)
if num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
dataset_args = {}
dataset_args['image_mean'] = cfg.MODEL.image_mean
dataset_args['image_std'] = cfg.MODEL.image_std
dataset_args['crop_size'] = cfg[mode.upper()]['crop_size']
dataset_args['sample_rate'] = cfg[mode.upper()]['sample_rate']
dataset_args['video_length'] = cfg[mode.upper()]['video_length']
dataset_args['min_size'] = cfg[mode.upper()]['jitter_scales'][0]
dataset_args['max_size'] = cfg[mode.upper()]['jitter_scales'][1]
dataset_args['num_reader_threads'] = num_reader_threads
filelist = cfg[mode.upper()]['list']
batch_size = cfg[mode.upper()]['batch_size']
if self.mode == 'train':
sample_times = 1
return reader_func(filelist, batch_size, sample_times, True, True,
**dataset_args)
elif self.mode == 'valid':
sample_times = 1
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
elif self.mode == 'test':
sample_times = cfg['TEST']['num_test_clips']
if cfg['TEST']['use_multi_crop'] == 1:
sample_times = int(sample_times / 3)
if cfg['TEST']['use_multi_crop'] == 2:
sample_times = int(sample_times / 6)
return reader_func(filelist, batch_size, sample_times, False, False,
**dataset_args)
else:
logger.info('Not implemented')
raise NotImplementedError
def video_fast_get_frame(video_path,
sampling_rate=1,
length=64,
start_frm=-1,
sample_times=1):
cap = cv2.VideoCapture(video_path)
frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
sampledFrames = []
video_output = np.ndarray(shape=[length, height, width, 3], dtype=np.uint8)
use_start_frm = start_frm
if start_frm < 0:
if (frame_cnt - length * sampling_rate > 0):
use_start_frm = random.randint(0,
frame_cnt - length * sampling_rate)
else:
use_start_frm = 0
else:
frame_gaps = float(frame_cnt) / float(sample_times)
use_start_frm = int(frame_gaps * start_frm) % frame_cnt
for i in range(frame_cnt):
ret, frame = cap.read()
# maybe first frame is empty
if ret == False:
continue
img = frame[:, :, ::-1]
sampledFrames.append(img)
for idx in range(length):
i = use_start_frm + idx * sampling_rate
i = i % len(sampledFrames)
video_output[idx] = sampledFrames[i]
cap.release()
return video_output
def apply_resize(rgbdata, min_size, max_size):
length, height, width, channel = rgbdata.shape
ratio = 1.0
# generate random scale between [min_size, max_size]
if min_size == max_size:
side_length = min_size
else:
side_length = np.random.randint(min_size, max_size)
if height > width:
ratio = float(side_length) / float(width)
else:
ratio = float(side_length) / float(height)
out_height = int(height * ratio)
out_width = int(width * ratio)
outdata = np.zeros(
(length, out_height, out_width, channel), dtype=rgbdata.dtype)
for i in range(length):
outdata[i] = cv2.resize(rgbdata[i], (out_width, out_height))
return outdata
def crop_mirror_transform(rgbdata,
mean,
std,
cropsize=224,
use_mirror=True,
center_crop=False,
spatial_pos=-1):
channel, length, height, width = rgbdata.shape
assert height >= cropsize, "crop size should not be larger than video height"
assert width >= cropsize, "crop size should not be larger than video width"
# crop to specific scale
if center_crop:
h_off = int((height - cropsize) / 2)
w_off = int((width - cropsize) / 2)
if spatial_pos >= 0:
now_pos = spatial_pos % 3
if h_off > 0:
h_off = h_off * now_pos
else:
w_off = w_off * now_pos
else:
h_off = np.random.randint(0, height - cropsize)
w_off = np.random.randint(0, width - cropsize)
outdata = np.zeros(
(channel, length, cropsize, cropsize), dtype=rgbdata.dtype)
outdata[:, :, :, :] = rgbdata[:, :, h_off:h_off + cropsize, w_off:w_off +
cropsize]
# apply mirror
mirror_indicator = (np.random.rand() > 0.5)
mirror_me = use_mirror and mirror_indicator
if spatial_pos > 0:
mirror_me = (int(spatial_pos / 3) > 0)
if mirror_me:
outdata = outdata[:, :, :, ::-1]
# substract mean and divide std
outdata = outdata.astype(np.float32)
outdata = (outdata - mean) / std
return outdata
def make_reader(filelist, batch_size, sample_times, is_training, shuffle,
**dataset_args):
# should add smaple_times param
fl = open(filelist).readlines()
fl = [line.strip() for line in fl if line.strip() != '']
if shuffle:
random.shuffle(fl)
def reader():
batch_out = []
for line in fl:
# start_time = time.time()
line_items = line.split(' ')
fn = line_items[0]
label = int(line_items[1])
if len(line_items) > 2:
start_frm = int(line_items[2])
spatial_pos = int(line_items[3])
in_sample_times = sample_times
else:
start_frm = -1
spatial_pos = -1
in_sample_times = 1
label = np.array([label]).astype(np.int64)
# 1, get rgb data for fixed length of frames
try:
rgbdata = video_fast_get_frame(fn, \
sampling_rate = dataset_args['sample_rate'], length = dataset_args['video_length'], \
start_frm = start_frm, sample_times = in_sample_times)
except:
logger.info('Error when loading {}, just skip this file'.format(
fn))
continue
# add prepocessing
# 2, reszie to randomly scale between [min_size, max_size] when training, or cgf.TEST.SCALE when inference
min_size = dataset_args['min_size']
max_size = dataset_args['max_size']
rgbdata = apply_resize(rgbdata, min_size, max_size)
# transform [length, height, width, channel] to [channel, length, height, width]
rgbdata = np.transpose(rgbdata, [3, 0, 1, 2])
# 3 crop, mirror and transform
rgbdata = crop_mirror_transform(rgbdata, mean = dataset_args['image_mean'], \
std = dataset_args['image_std'], cropsize = dataset_args['crop_size'], \
use_mirror = is_training, center_crop = (not is_training), \
spatial_pos = spatial_pos)
batch_out.append((rgbdata, label))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader
def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle,
**dataset_args):
fl = open(filelist).readlines()
fl = [line.strip() for line in fl if line.strip() != '']
if shuffle:
random.shuffle(fl)
n = dataset_args['num_reader_threads']
queue_size = 20
reader_lists = [None] * n
file_num = int(len(fl) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = fl[i * file_num:(i + 1) * file_num]
else:
tmp_list = fl[i * file_num:]
reader_lists[i] = tmp_list
def read_into_queue(flq, queue):
batch_out = []
for line in flq:
line_items = line.split(' ')
fn = line_items[0]
label = int(line_items[1])
if len(line_items) > 2:
start_frm = int(line_items[2])
spatial_pos = int(line_items[3])
in_sample_times = sample_times
else:
start_frm = -1
spatial_pos = -1
in_sample_times = 1
label = np.array([label]).astype(np.int64)
# 1, get rgb data for fixed length of frames
try:
rgbdata = video_fast_get_frame(fn, \
sampling_rate = dataset_args['sample_rate'], length = dataset_args['video_length'], \
start_frm = start_frm, sample_times = in_sample_times)
except:
logger.info('Error when loading {}, just skip this file'.format(
fn))
continue
# add prepocessing
# 2, reszie to randomly scale between [min_size, max_size] when training, or cgf.TEST.SCALE when inference
min_size = dataset_args['min_size']
max_size = dataset_args['max_size']
rgbdata = apply_resize(rgbdata, min_size, max_size)
# transform [length, height, width, channel] to [channel, length, height, width]
rgbdata = np.transpose(rgbdata, [3, 0, 1, 2])
# 3 crop, mirror and transform
rgbdata = crop_mirror_transform(rgbdata, mean = dataset_args['image_mean'], \
std = dataset_args['image_std'], cropsize = dataset_args['crop_size'], \
use_mirror = is_training, center_crop = (not is_training), \
spatial_pos = spatial_pos)
batch_out.append((rgbdata, label))
if len(batch_out) == batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
p_list[i].terminate()
p_list[i].join()
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import pickle
import cv2
import numpy as np
import random
class ReaderNotFoundError(Exception):
"Error: reader not found"
def __init__(self, reader_name, avail_readers):
super(ReaderNotFoundError, self).__init__()
self.reader_name = reader_name
self.avail_readers = avail_readers
def __str__(self):
msg = "Reader {} Not Found.\nAvailiable readers:\n".format(
self.reader_name)
for reader in self.avail_readers:
msg += " {}\n".format(reader)
return msg
class DataReader(object):
"""data reader for video input"""
def __init__(self, model_name, mode, cfg):
"""Not implemented"""
pass
def create_reader(self):
"""Not implemented"""
pass
class ReaderZoo(object):
def __init__(self):
self.reader_zoo = {}
def regist(self, name, reader):
assert reader.__base__ == DataReader, "Unknow model type {}".format(
type(reader))
self.reader_zoo[name] = reader
def get(self, name, mode, cfg):
for k, v in self.reader_zoo.items():
if k == name:
return v(name, mode, cfg)
raise ReaderNotFoundError(name, self.reader_zoo.keys())
# singleton reader_zoo
reader_zoo = ReaderZoo()
def regist_reader(name, reader):
reader_zoo.regist(name, reader)
def get_reader(name, mode, cfg):
reader_model = reader_zoo.get(name, mode, cfg)
return reader_model.create_reader()
1. download kinetics-400_train.csv and kinetics-400_val.csv
2. ffmpeg is required to decode mp4
3. transfer mp4 video to pkl file, with each pkl stores [video_id, images, label]
python generate_label.py kinetics-400_train.csv kinetics400_label.txt # generate label file
python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir $NUM_THREADS
import sys
# kinetics-400_train.csv should be down loaded first and set as sys.argv[1]
# sys.argv[2] can be set as kinetics400_label.txt
# python generate_label.py kinetics-400_train.csv kinetics400_label.txt
num_classes = 400
fname = sys.argv[1]
outname = sys.argv[2]
fl = open(fname).readlines()
fl = fl[1:]
outf = open(outname, 'w')
label_list = []
for line in fl:
label = line.strip().split(',')[0].strip('"')
if label in label_list:
continue
else:
label_list.append(label)
assert len(label_list
) == num_classes, "there should be {} labels in list, but ".format(
num_classes, len(label_list))
label_list.sort()
for i in range(num_classes):
outf.write('{} {}'.format(label_list[i], i) + '\n')
outf.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import glob
import cPickle
from multiprocessing import Pool
# example command line: python generate_k400_pkl.py kinetics-400_train.csv 8
#
# kinetics-400_train.csv is the training set file of K400 official release
# each line contains laebl,youtube_id,time_start,time_end,split,is_cc
assert (len(sys.argv) == 5)
f = open(sys.argv[1])
source_dir = sys.argv[2]
target_dir = sys.argv[3]
num_threads = sys.argv[4]
all_video_entries = [x.strip().split(',') for x in f.readlines()]
all_video_entries = all_video_entries[1:]
f.close()
category_label_map = {}
f = open('kinetics400_label.txt')
for line in f:
ens = line.strip().split(' ')
category = " ".join(ens[0:-1])
label = int(ens[-1])
category_label_map[category] = label
f.close()
def generate_pkl(entry):
mode = entry[4]
category = entry[0].strip('"')
category_dir = category
video_path = os.path.join(
'./',
entry[1] + "_%06d" % int(entry[2]) + "_%06d" % int(entry[3]) + ".mp4")
video_path = os.path.join(source_dir, category_dir, video_path)
label = category_label_map[category]
vid = './' + video_path.split('/')[-1].split('.')[0]
if os.path.exists(video_path):
if not os.path.exists(vid):
os.makedirs(vid)
os.system('ffmpeg -i ' + video_path + ' -q 0 ' + vid + '/%06d.jpg')
else:
print("File not exists {}".format(video_path))
return
images = sorted(glob.glob(vid + '/*.jpg'))
ims = []
for img in images:
f = open(img)
ims.append(f.read())
f.close()
output_pkl = vid + ".pkl"
output_pkl = os.path.join(target_dir, output_pkl)
f = open(output_pkl, 'w')
cPickle.dump((vid, label, ims), f, -1)
f.close()
os.system('rm -rf %s' % vid)
pool = Pool(processes=int(sys.argv[4]))
pool.map(generate_pkl, all_video_entries)
pool.close()
pool.join()
1. Tensorflow is required to process tfrecords
2. python tf2pkl.py $Source_dir $Target_dir
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""Provides readers configured for different datasets."""
import os, sys
import numpy as np
import tensorflow as tf
from tensorflow import logging
import cPickle
from tensorflow.python.platform import gfile
assert (len(sys.argv) == 3)
source_dir = sys.argv[1]
target_dir = sys.argv[2]
def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
"""Dequantize the feature from the byte format to the float format.
Args:
feat_vector: the input 1-d vector.
max_quantized_value: the maximum of the quantized value.
min_quantized_value: the minimum of the quantized value.
Returns:
A float vector which has the same shape as feat_vector.
"""
assert max_quantized_value > min_quantized_value
quantized_range = max_quantized_value - min_quantized_value
scalar = quantized_range / 255.0
bias = (quantized_range / 512.0) + min_quantized_value
return feat_vector * scalar + bias
def resize_axis(tensor, axis, new_size, fill_value=0):
"""Truncates or pads a tensor to new_size on on a given axis.
Truncate or extend tensor such that tensor.shape[axis] == new_size. If the
size increases, the padding will be performed at the end, using fill_value.
Args:
tensor: The tensor to be resized.
axis: An integer representing the dimension to be sliced.
new_size: An integer or 0d tensor representing the new value for
tensor.shape[axis].
fill_value: Value to use to fill any new entries in the tensor. Will be
cast to the type of tensor.
Returns:
The resized tensor.
"""
tensor = tf.convert_to_tensor(tensor)
shape = tf.unstack(tf.shape(tensor))
pad_shape = shape[:]
pad_shape[axis] = tf.maximum(0, new_size - shape[axis])
shape[axis] = tf.minimum(shape[axis], new_size)
shape = tf.stack(shape)
resized = tf.concat([
tf.slice(tensor, tf.zeros_like(shape), shape),
tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype))
], axis)
# Update shape.
new_shape = tensor.get_shape().as_list() # A copy is being made.
new_shape[axis] = new_size
resized.set_shape(new_shape)
return resized
class BaseReader(object):
"""Inherit from this class when implementing new readers."""
def prepare_reader(self, unused_filename_queue):
"""Create a thread for generating prediction and label tensors."""
raise NotImplementedError()
class YT8MFrameFeatureReader(BaseReader):
"""Reads TFRecords of SequenceExamples.
The TFRecords must contain SequenceExamples with the sparse in64 'labels'
context feature and a fixed length byte-quantized feature vector, obtained
from the features in 'feature_names'. The quantized features will be mapped
back into a range between min_quantized_value and max_quantized_value.
"""
def __init__(self,
num_classes=3862,
feature_sizes=[1024],
feature_names=["inc3"],
max_frames=300):
"""Construct a YT8MFrameFeatureReader.
Args:
num_classes: a positive integer for the number of classes.
feature_sizes: positive integer(s) for the feature dimensions as a list.
feature_names: the feature name(s) in the tensorflow record as a list.
max_frames: the maximum number of frames to process.
"""
assert len(feature_names) == len(feature_sizes), \
"length of feature_names (={}) != length of feature_sizes (={})".format( \
len(feature_names), len(feature_sizes))
self.num_classes = num_classes
self.feature_sizes = feature_sizes
self.feature_names = feature_names
self.max_frames = max_frames
def get_video_matrix(self, features, feature_size, max_frames,
max_quantized_value, min_quantized_value):
"""Decodes features from an input string and quantizes it.
Args:
features: raw feature values
feature_size: length of each frame feature vector
max_frames: number of frames (rows) in the output feature_matrix
max_quantized_value: the maximum of the quantized value.
min_quantized_value: the minimum of the quantized value.
Returns:
feature_matrix: matrix of all frame-features
num_frames: number of frames in the sequence
"""
decoded_features = tf.reshape(
tf.cast(tf.decode_raw(features, tf.uint8), tf.float32),
[-1, feature_size])
num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames)
feature_matrix = decoded_features
return feature_matrix, num_frames
def prepare_reader(self,
filename_queue,
max_quantized_value=2,
min_quantized_value=-2):
"""Creates a single reader thread for YouTube8M SequenceExamples.
Args:
filename_queue: A tensorflow queue of filename locations.
max_quantized_value: the maximum of the quantized value.
min_quantized_value: the minimum of the quantized value.
Returns:
A tuple of video indexes, video features, labels, and padding data.
"""
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
contexts, features = tf.parse_single_sequence_example(
serialized_example,
context_features={
"id": tf.FixedLenFeature([], tf.string),
"labels": tf.VarLenFeature(tf.int64)
},
sequence_features={
feature_name: tf.FixedLenSequenceFeature(
[], dtype=tf.string)
for feature_name in self.feature_names
})
# read ground truth labels
labels = (tf.cast(
tf.sparse_to_dense(
contexts["labels"].values, (self.num_classes, ),
1,
validate_indices=False),
tf.bool))
# loads (potentially) different types of features and concatenates them
num_features = len(self.feature_names)
assert num_features > 0, "No feature selected: feature_names is empty!"
assert len(self.feature_names) == len(self.feature_sizes), \
"length of feature_names (={}) != length of feature_sizes (={})".format( \
len(self.feature_names), len(self.feature_sizes))
num_frames = -1 # the number of frames in the video
feature_matrices = [None
] * num_features # an array of different features
for feature_index in range(num_features):
feature_matrix, num_frames_in_this_feature = self.get_video_matrix(
features[self.feature_names[feature_index]],
self.feature_sizes[feature_index], self.max_frames,
max_quantized_value, min_quantized_value)
if num_frames == -1:
num_frames = num_frames_in_this_feature
#else:
# tf.assert_equal(num_frames, num_frames_in_this_feature)
feature_matrices[feature_index] = feature_matrix
# cap the number of frames at self.max_frames
num_frames = tf.minimum(num_frames, self.max_frames)
# concatenate different features
video_matrix = feature_matrices[0]
audio_matrix = feature_matrices[1]
return contexts["id"], video_matrix, audio_matrix, labels, num_frames
def main(files_pattern):
data_files = gfile.Glob(files_pattern)
filename_queue = tf.train.string_input_producer(
data_files, num_epochs=1, shuffle=False)
reader = YT8MFrameFeatureReader(
feature_sizes=[1024, 128], feature_names=["rgb", "audio"])
vals = reader.prepare_reader(filename_queue)
with tf.Session() as sess:
sess.run(tf.initialize_local_variables())
sess.run(tf.initialize_all_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
vid_num = 0
all_data = []
try:
while not coord.should_stop():
vid, features, audios, labels, nframes = sess.run(vals)
label_index = np.where(labels == True)[0].tolist()
vid_num += 1
#print vid, features.shape, audios.shape, label_index, nframes
features_int = features.astype(np.uint8)
audios_int = audios.astype(np.uint8)
value_dict = {}
value_dict['video'] = vid
value_dict['feature'] = features_int
value_dict['audio'] = audios_int
value_dict['label'] = label_index
value_dict['nframes'] = nframes
all_data.append(value_dict)
except tf.errors.OutOfRangeError:
print('Finished extracting.')
finally:
coord.request_stop()
coord.join(threads)
print vid_num
record_name = files_pattern.split('/')[-1].split('.')[0]
outputdir = target_dir
fn = '%s.pkl' % record_name
outp = open(os.path.join(outputdir, fn), 'wb')
cPickle.dump(all_data, outp, protocol=cPickle.HIGHEST_PROTOCOL)
outp.close()
if __name__ == '__main__':
record_dir = source_dir
record_files = os.listdir(record_dir)
for f in record_files:
record_path = os.path.join(record_dir, f)
main(record_path)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import logging
import argparse
import numpy as np
try:
import cPickle as pickle
except:
import pickle
import paddle.fluid as fluid
from config import *
import models
from datareader import get_reader
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model-name',
type=str,
default='AttentionCluster',
help='name of model to train.')
parser.add_argument(
'--config',
type=str,
default='configs/attention_cluster.txt',
help='path to config file of model')
parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='weight path, None to use weights from Paddle.')
parser.add_argument(
'--batch-size',
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument(
'--filelist',
type=str,
default=None,
help='path to inferenece data file lists file.')
parser.add_argument(
'--log-interval',
type=int,
default=1,
help='mini-batch interval to log.')
parser.add_argument(
'--infer-topk',
type=int,
default=20,
help='topk predictions to restore.')
parser.add_argument(
'--save-dir', type=str, default='./', help='directory to store results')
args = parser.parse_args()
return args
def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_pyreader=False)
infer_model.build_model()
infer_feeds = infer_model.feeds()
infer_outputs = infer_model.outputs()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
filelist = args.filelist or infer_config.INFER.filelist
assert os.path.exists(filelist), "{} not exist.".format(args.filelist)
# get infer reader
infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
if args.weights:
assert os.path.exists(
args.weights), "Given weight dir {} not exist.".format(args.weights)
# if no weight files specified, download weights from paddle
weights = args.weights or infer_model.get_weights()
def if_exist(var):
return os.path.exists(os.path.join(weights, var.name))
fluid.io.load_vars(exe, weights, predicate=if_exist)
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
fetch_list = [x.name for x in infer_outputs]
periods = []
results = []
cur_time = time.time()
for infer_iter, data in enumerate(infer_reader()):
data_feed_in = [items[:-1] for items in data]
video_id = [items[-1] for items in data]
infer_outs = exe.run(fetch_list=fetch_list,
feed=infer_feeder.feed(data_feed_in))
predictions = np.array(infer_outs[0])
for i in range(len(predictions)):
topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
topk_inds = topk_inds[::-1]
preds = predictions[i][topk_inds]
results.append(
(video_id[i], preds.tolist(), topk_inds.tolist()))
prev_time = cur_time
cur_time = time.time()
period = cur_time - prev_time
periods.append(period)
if args.log_interval > 0 and infer_iter % args.log_interval == 0:
logger.info('Processed {} samples'.format((infer_iter) * len(
predictions)))
logger.info('[INFER] infer finished. average time: {}'.format(
np.mean(periods)))
if not os.path.isdir(args.save_dir):
os.mkdir(args.save_dir)
result_file_name = os.path.join(args.save_dir,
"{}_infer_result".format(args.model_name))
pickle.dump(results, open(result_file_name, 'wb'))
if __name__ == "__main__":
args = parse_args()
logger.info(args)
infer(args)
from .metrics_util import get_metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import numpy as np
import datetime
import logging
logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(self, name, mode):
self.name = name
self.mode = mode # 'train', 'val', 'test'
self.reset()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_acc1 = 0.0
self.aggr_acc5 = 0.0
self.aggr_loss = 0.0
self.aggr_batch_size = 0
def finalize_metrics(self):
self.avg_acc1 = self.aggr_acc1 / self.aggr_batch_size
self.avg_acc5 = self.aggr_acc5 / self.aggr_batch_size
self.avg_loss = self.aggr_loss / self.aggr_batch_size
def get_computed_metrics(self):
json_stats = {}
json_stats['avg_loss'] = self.avg_loss
json_stats['avg_acc1'] = self.avg_acc1
json_stats['avg_acc5'] = self.avg_acc5
return json_stats
def calculate_metrics(self, loss, softmax, labels):
accuracy1 = compute_topk_accuracy(softmax, labels, top_k=1) * 100.
accuracy5 = compute_topk_accuracy(softmax, labels, top_k=5) * 100.
return accuracy1, accuracy5
def accumulate(self, loss, softmax, labels):
cur_batch_size = softmax.shape[0]
# if returned loss is None for e.g. test, just set loss to be 0.
if loss is None:
cur_loss = 0.
else:
cur_loss = np.mean(np.array(loss)) #
self.aggr_batch_size += cur_batch_size
self.aggr_loss += cur_loss * cur_batch_size
accuracy1 = compute_topk_accuracy(softmax, labels, top_k=1) * 100.
accuracy5 = compute_topk_accuracy(softmax, labels, top_k=5) * 100.
self.aggr_acc1 += accuracy1 * cur_batch_size
self.aggr_acc5 += accuracy5 * cur_batch_size
return
# ----------------------------------------------
# other utils
# ----------------------------------------------
def compute_topk_correct_hits(top_k, preds, labels):
'''Compute the number of corret hits'''
batch_size = preds.shape[0]
top_k_preds = np.zeros((batch_size, top_k), dtype=np.float32)
for i in range(batch_size):
top_k_preds[i, :] = np.argsort(-preds[i, :])[:top_k]
correctness = np.zeros(batch_size, dtype=np.int32)
for i in range(batch_size):
if labels[i] in top_k_preds[i, :].astype(np.int32).tolist():
correctness[i] = 1
correct_hits = sum(correctness)
return correct_hits
def compute_topk_accuracy(softmax, labels, top_k):
computed_metrics = {}
assert labels.shape[0] == softmax.shape[0], "Batch size mismatch."
aggr_batch_size = labels.shape[0]
aggr_top_k_correct_hits = compute_topk_correct_hits(top_k, softmax, labels)
# normalize results
computed_metrics = \
float(aggr_top_k_correct_hits) / aggr_batch_size
return computed_metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import logging
import numpy as np
from metrics.youtube8m import eval_util as youtube8m_metrics
from metrics.kinetics import accuracy_metrics as kinetics_metrics
from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics
logger = logging.getLogger(__name__)
class Metrics(object):
def __init__(self, name, mode, metrics_args):
"""Not implemented"""
pass
def calculate_and_log_out(self, loss, pred, label, info=''):
"""Not implemented"""
pass
def accumulate(self, loss, pred, label, info=''):
"""Not implemented"""
pass
def finalize_and_log_out(self, info=''):
"""Not implemented"""
pass
def reset(self):
"""Not implemented"""
pass
class Youtube8mMetrics(Metrics):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.num_classes = metrics_args['MODEL']['num_classes']
self.topk = metrics_args['MODEL']['topk']
self.calculator = youtube8m_metrics.EvaluationMetrics(self.num_classes,
self.topk)
def calculate_and_log_out(self, loss, pred, label, info=''):
loss = np.mean(np.array(loss))
hit_at_one = youtube8m_metrics.calculate_hit_at_one(pred, label)
perr = youtube8m_metrics.calculate_precision_at_equal_recall_rate(pred,
label)
gap = youtube8m_metrics.calculate_gap(pred, label)
logger.info(info + ' , loss = {0}, Hit@1 = {1}, PERR = {2}, GAP = {3}'.format(\
'%.6f' % loss, '%.2f' % hit_at_one, '%.2f' % perr, '%.2f' % gap))
def accumulate(self, loss, pred, label, info=''):
self.calculator.accumulate(loss, pred, label)
def finalize_and_log_out(self, info=''):
epoch_info_dict = self.calculator.get()
logger.info(info + '\tavg_hit_at_one: {0},\tavg_perr: {1},\tavg_loss :{2},\taps: {3},\tgap:{4}'\
.format(epoch_info_dict['avg_hit_at_one'], epoch_info_dict['avg_perr'], \
epoch_info_dict['avg_loss'], epoch_info_dict['aps'], epoch_info_dict['gap']))
def reset(self):
self.calculator.clear()
class Kinetics400Metrics(Metrics):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
self.calculator = kinetics_metrics.MetricsCalculator(name, mode.lower())
def calculate_and_log_out(self, loss, pred, label, info=''):
if loss is not None:
loss = np.mean(np.array(loss))
else:
loss = 0.
acc1, acc5 = self.calculator.calculate_metrics(loss, pred, label)
logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format('%.6f' % loss, \
'%.2f' % acc1, '%.2f' % acc5))
def accumulate(self, loss, pred, label, info=''):
self.calculator.accumulate(loss, pred, label)
def finalize_and_log_out(self, info=''):
self.calculator.finalize_metrics()
metrics_dict = self.calculator.get_computed_metrics()
loss = metrics_dict['avg_loss']
acc1 = metrics_dict['avg_acc1']
acc5 = metrics_dict['avg_acc5']
logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format('%.6f' % loss, \
'%.2f' % acc1, '%.2f' % acc5))
def reset(self):
self.calculator.reset()
class MulticropMetrics(Metrics):
def __init__(self, name, mode, metrics_args):
self.name = name
self.mode = mode
if mode == 'test':
args = {}
args['num_test_clips'] = metrics_args.TEST.num_test_clips
args['dataset_size'] = metrics_args.TEST.dataset_size
args['filename_gt'] = metrics_args.TEST.filename_gt
args['checkpoint_dir'] = metrics_args.TEST.checkpoint_dir
args['num_classes'] = metrics_args.MODEL.num_classes
self.calculator = multicrop_test_metrics.MetricsCalculator(
name, mode.lower(), **args)
else:
self.calculator = kinetics_metrics.MetricsCalculator(name,
mode.lower())
def calculate_and_log_out(self, loss, pred, label, info=''):
if self.mode == 'test':
pass
else:
if loss is not None:
loss = np.mean(np.array(loss))
else:
loss = 0.
acc1, acc5 = self.calculator.calculate_metrics(loss, pred, label)
logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format('%.6f' % loss, \
'%.2f' % acc1, '%.2f' % acc5))
def accumulate(self, loss, pred, label):
self.calculator.accumulate(loss, pred, label)
def finalize_and_log_out(self, info=''):
if self.mode == 'test':
self.calculator.finalize_metrics()
else:
self.calculator.finalize_metrics()
metrics_dict = self.calculator.get_computed_metrics()
loss = metrics_dict['avg_loss']
acc1 = metrics_dict['avg_acc1']
acc5 = metrics_dict['avg_acc5']
logger.info(info + '\tLoss: {},\ttop1_acc: {}, \ttop5_acc: {}'.format('%.6f' % loss, \
'%.2f' % acc1, '%.2f' % acc5))
def reset(self):
self.calculator.reset()
class MetricsZoo(object):
def __init__(self):
self.metrics_zoo = {}
def regist(self, name, metrics):
assert metrics.__base__ == Metrics, "Unknow model type {}".format(
type(metrics))
self.metrics_zoo[name] = metrics
def get(self, name, mode, cfg):
for k, v in self.metrics_zoo.items():
if k == name:
return v(name, mode, cfg)
raise MetricsNotFoundError(name, self.metrics_zoo.keys())
# singleton metrics_zoo
metrics_zoo = MetricsZoo()
def regist_metrics(name, metrics):
metrics_zoo.regist(name, metrics)
def get_metrics(name, mode, cfg):
return metrics_zoo.get(name, mode, cfg)
regist_metrics("NEXTVLAD", Youtube8mMetrics)
regist_metrics("ATTENTIONLSTM", Youtube8mMetrics)
regist_metrics("ATTENTIONCLUSTER", Youtube8mMetrics)
regist_metrics("TSN", Kinetics400Metrics)
regist_metrics("TSM", Kinetics400Metrics)
regist_metrics("STNET", Kinetics400Metrics)
regist_metrics("NONLOCAL", MulticropMetrics)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import sys
import os
import numpy as np
import datetime
import logging
from collections import defaultdict
import pickle
logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(self, name, mode, **metrics_args):
"""
metrics args:
num_test_clips, number of clips of each video when test
dataset_size, total number of videos in the dataset
filename_gt, a file with each line stores the groud truth of each video
checkpoint_dir, dir where to store the test results
num_classes, number of classes of the dataset
"""
self.name = name
self.mode = mode # 'train', 'val', 'test'
self.metrics_args = metrics_args
self.num_test_clips = metrics_args['num_test_clips']
self.dataset_size = metrics_args['dataset_size']
self.filename_gt = metrics_args['filename_gt']
self.checkpoint_dir = metrics_args['checkpoint_dir']
self.num_classes = metrics_args['num_classes']
self.reset()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_acc1 = 0.0
self.aggr_acc5 = 0.0
self.aggr_loss = 0.0
self.aggr_batch_size = 0
self.seen_inds = defaultdict(int)
self.results = []
def calculate_metrics(self, loss, pred, labels):
pass
def accumulate(self, loss, pred, labels):
labels = labels.astype(int)
for i in range(pred.shape[0]):
probs = pred[i, :].tolist()
vid = labels[i]
self.seen_inds[vid] += 1
if self.seen_inds[vid] > self.num_test_clips:
logger.warning('Video id {} have been seen. Skip.'.format(vid,
))
continue
save_pairs = [vid, probs]
self.results.append(save_pairs)
logger.info("({0} / {1}) videos".format(\
len(self.seen_inds), self.dataset_size))
def finalize_metrics(self):
if self.filename_gt is not None:
evaluate_results(self.results, self.filename_gt, self.dataset_size, \
self.num_classes, self.num_test_clips)
# save temporary file
pkl_path = os.path.join(self.checkpoint_dir, "results_probs.pkl")
with open(pkl_path, 'w') as f:
pickle.dump(self.results, f)
logger.info('Temporary file saved to: {}'.format(pkl_path))
def read_groundtruth(filename_gt):
f = open(filename_gt, 'r')
labels = []
for line in f:
rows = line.split()
labels.append(int(rows[1]))
f.close()
return labels
def evaluate_results(results, filename_gt, test_dataset_size, num_classes,
num_test_clips):
gt_labels = read_groundtruth(filename_gt)
sample_num = test_dataset_size
class_num = num_classes
sample_video_times = num_test_clips
counts = np.zeros(sample_num, dtype=np.int32)
probs = np.zeros((sample_num, class_num))
assert (len(gt_labels) == sample_num)
"""
clip_accuracy: the (e.g.) 10*19761 clips' average accuracy
clip1_accuracy: the 1st clip's accuracy (starting from frame 0)
"""
clip_accuracy = 0
clip1_accuracy = 0
clip1_count = 0
seen_inds = defaultdict(int)
# evaluate
for entry in results:
vid = entry[0]
prob = np.array(entry[1])
probs[vid] += prob[0:class_num]
counts[vid] += 1
idx = prob.argmax()
if idx == gt_labels[vid]:
# clip accuracy
clip_accuracy += 1
# clip1 accuracy
seen_inds[vid] += 1
if seen_inds[vid] == 1:
clip1_count += 1
if idx == gt_labels[vid]:
clip1_accuracy += 1
# sanity checkcnt = 0
max_clips = 0
min_clips = sys.maxsize
count_empty = 0
count_corrupted = 0
for i in range(sample_num):
max_clips = max(max_clips, counts[i])
min_clips = min(min_clips, counts[i])
if counts[i] != sample_video_times:
count_corrupted += 1
logger.warning('Id: {} count: {}'.format(i, counts[i]))
if counts[i] == 0:
count_empty += 1
logger.info('Num of empty videos: {}'.format(count_empty))
logger.info('Num of corrupted videos: {}'.format(count_corrupted))
logger.info('Max num of clips in a video: {}'.format(max_clips))
logger.info('Min num of clips in a video: {}'.format(min_clips))
# clip1 accuracy for sanity (# print clip1 first as it is lowest)
logger.info('Clip1 accuracy: {:.2f} percent ({}/{})'.format(
100. * clip1_accuracy / clip1_count, clip1_accuracy, clip1_count))
# clip accuracy for sanity
logger.info('Clip accuracy: {:.2f} percent ({}/{})'.format(
100. * clip_accuracy / len(results), clip_accuracy, len(results)))
# compute accuracy
accuracy = 0
accuracy_top5 = 0
for i in range(sample_num):
prob = probs[i]
# top-1
idx = prob.argmax()
if idx == gt_labels[i] and counts[i] > 0:
accuracy = accuracy + 1
ids = np.argsort(prob)[::-1]
for j in range(5):
if ids[j] == gt_labels[i] and counts[i] > 0:
accuracy_top5 = accuracy_top5 + 1
break
accuracy = float(accuracy) / float(sample_num)
accuracy_top5 = float(accuracy_top5) / float(sample_num)
logger.info('-' * 80)
logger.info('top-1 accuracy: {:.2f} percent'.format(accuracy * 100))
logger.info('top-5 accuracy: {:.2f} percent'.format(accuracy_top5 * 100))
logger.info('-' * 80)
for i in range(sample_num):
prob = probs[i]
# top-1
idx = prob.argmax()
if idx == gt_labels[i] and counts[i] > 0:
accuracy = accuracy + 1
ids = np.argsort(prob)[::-1]
for j in range(5):
if ids[j] == gt_labels[i] and counts[i] > 0:
accuracy_top5 = accuracy_top5 + 1
break
accuracy = float(accuracy) / float(sample_num)
accuracy_top5 = float(accuracy_top5) / float(sample_num)
logger.info('-' * 80)
logger.info('top-1 accuracy: {:.2f} percent'.format(accuracy * 100))
logger.info('top-5 accuracy: {:.2f} percent'.format(accuracy_top5 * 100))
logger.info('-' * 80)
return
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate or keep track of the interpolated average precision.
It provides an interface for calculating interpolated average precision for an
entire list or the top-n ranked items. For the definition of the
(non-)interpolated average precision:
http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
Example usages:
1) Use it as a static function call to directly calculate average precision for
a short ranked list in the memory.
```
import random
p = np.array([random.random() for _ in xrange(10)])
a = np.array([random.choice([0, 1]) for _ in xrange(10)])
ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
```
2) Use it as an object for long ranked list that cannot be stored in memory or
the case where partial predictions can be observed at a time (Tensorflow
predictions). In this case, we first call the function accumulate many times
to process parts of the ranked list. After processing all the parts, we call
peek_interpolated_ap_at_n.
```
p1 = np.array([random.random() for _ in xrange(5)])
a1 = np.array([random.choice([0, 1]) for _ in xrange(5)])
p2 = np.array([random.random() for _ in xrange(5)])
a2 = np.array([random.choice([0, 1]) for _ in xrange(5)])
# interpolated average precision at 10 using 1000 break points
calculator = average_precision_calculator.AveragePrecisionCalculator(10)
calculator.accumulate(p1, a1)
calculator.accumulate(p2, a2)
ap3 = calculator.peek_ap_at_n()
```
"""
import heapq
import random
import numbers
import numpy
class AveragePrecisionCalculator(object):
"""Calculate the average precision and average precision at n."""
def __init__(self, top_n=None):
"""Construct an AveragePrecisionCalculator to calculate average precision.
This class is used to calculate the average precision for a single label.
Args:
top_n: A positive Integer specifying the average precision at n, or
None to use all provided data points.
Raises:
ValueError: An error occurred when the top_n is not a positive integer.
"""
if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None):
raise ValueError("top_n must be a positive integer or None.")
self._top_n = top_n # average precision at n
self._total_positives = 0 # total number of positives have seen
self._heap = [] # max heap of (prediction, actual)
@property
def heap_size(self):
"""Gets the heap size maintained in the class."""
return len(self._heap)
@property
def num_accumulated_positives(self):
"""Gets the number of positive samples that have been accumulated."""
return self._total_positives
def accumulate(self, predictions, actuals, num_positives=None):
"""Accumulate the predictions and their ground truth labels.
After the function call, we may call peek_ap_at_n to actually calculate
the average precision.
Note predictions and actuals must have the same shape.
Args:
predictions: a list storing the prediction scores.
actuals: a list storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
num_positives = If the 'predictions' and 'actuals' inputs aren't complete,
then it's possible some true positives were missed in them. In that case,
you can provide 'num_positives' in order to accurately track recall.
Raises:
ValueError: An error occurred when the format of the input is not the
numpy 1-D array or the shape of predictions and actuals does not match.
"""
if len(predictions) != len(actuals):
raise ValueError(
"the shape of predictions and actuals does not match.")
if not num_positives is None:
if not isinstance(num_positives,
numbers.Number) or num_positives < 0:
raise ValueError(
"'num_positives' was provided but it wan't a nonzero number."
)
if not num_positives is None:
self._total_positives += num_positives
else:
self._total_positives += numpy.size(numpy.where(actuals > 0))
topk = self._top_n
heap = self._heap
for i in range(numpy.size(predictions)):
if topk is None or len(heap) < topk:
heapq.heappush(heap, (predictions[i], actuals[i]))
else:
if predictions[i] > heap[0][0]: # heap[0] is the smallest
heapq.heappop(heap)
heapq.heappush(heap, (predictions[i], actuals[i]))
def clear(self):
"""Clear the accumulated predictions."""
self._heap = []
self._total_positives = 0
def peek_ap_at_n(self):
"""Peek the non-interpolated average precision at n.
Returns:
The non-interpolated average precision at n (default 0).
If n is larger than the length of the ranked list,
the average precision will be returned.
"""
if self.heap_size <= 0:
return 0
predlists = numpy.array(list(zip(*self._heap)))
ap = self.ap_at_n(
predlists[0],
predlists[1],
n=self._top_n,
total_num_positives=self._total_positives)
return ap
@staticmethod
def ap(predictions, actuals):
"""Calculate the non-interpolated average precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
actuals: a numpy 1-D array storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
Returns:
The non-interpolated average precision at n.
If n is larger than the length of the ranked list,
the average precision will be returned.
Raises:
ValueError: An error occurred when the format of the input is not the
numpy 1-D array or the shape of predictions and actuals does not match.
"""
return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)
@staticmethod
def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
"""Calculate the non-interpolated average precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
actuals: a numpy 1-D array storing the ground truth labels. Any value
larger than 0 will be treated as positives, otherwise as negatives.
n: the top n items to be considered in ap@n.
total_num_positives : (optionally) you can specify the number of total
positive
in the list. If specified, it will be used in calculation.
Returns:
The non-interpolated average precision at n.
If n is larger than the length of the ranked list,
the average precision will be returned.
Raises:
ValueError: An error occurred when
1) the format of the input is not the numpy 1-D array;
2) the shape of predictions and actuals does not match;
3) the input n is not a positive integer.
"""
if len(predictions) != len(actuals):
raise ValueError(
"the shape of predictions and actuals does not match.")
if n is not None:
if not isinstance(n, int) or n <= 0:
raise ValueError("n must be 'None' or a positive integer."
" It was '%s'." % n)
ap = 0.0
predictions = numpy.array(predictions)
actuals = numpy.array(actuals)
# add a shuffler to avoid overestimating the ap
predictions, actuals = AveragePrecisionCalculator._shuffle(predictions,
actuals)
sortidx = sorted(
range(len(predictions)), key=lambda k: predictions[k], reverse=True)
if total_num_positives is None:
numpos = numpy.size(numpy.where(actuals > 0))
else:
numpos = total_num_positives
if numpos == 0:
return 0
if n is not None:
numpos = min(numpos, n)
delta_recall = 1.0 / numpos
poscount = 0.0
# calculate the ap
r = len(sortidx)
if n is not None:
r = min(r, n)
for i in range(r):
if actuals[sortidx[i]] > 0:
poscount += 1
ap += poscount / (i + 1) * delta_recall
return ap
@staticmethod
def _shuffle(predictions, actuals):
random.seed(0)
suffidx = random.sample(range(len(predictions)), len(predictions))
predictions = predictions[suffidx]
actuals = actuals[suffidx]
return predictions, actuals
@staticmethod
def _zero_one_normalize(predictions, epsilon=1e-7):
"""Normalize the predictions to the range between 0.0 and 1.0.
For some predictions like SVM predictions, we need to normalize them before
calculate the interpolated average precision. The normalization will not
change the rank in the original list and thus won't change the average
precision.
Args:
predictions: a numpy 1-D array storing the sparse prediction scores.
epsilon: a small constant to avoid denominator being zero.
Returns:
The normalized prediction.
"""
denominator = numpy.max(predictions) - numpy.min(predictions)
ret = (predictions - numpy.min(predictions)) / numpy.max(denominator,
epsilon)
return ret
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides functions to help with evaluating models."""
import datetime
import numpy
from . import mean_average_precision_calculator as map_calculator
from . import average_precision_calculator as ap_calculator
def flatten(l):
""" Merges a list of lists into a single list. """
return [item for sublist in l for item in sublist]
def calculate_hit_at_one(predictions, actuals):
"""Performs a local (numpy) calculation of the hit at one.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
Returns:
float: The average hit at one across the entire batch.
"""
top_prediction = numpy.argmax(predictions, 1)
hits = actuals[numpy.arange(actuals.shape[0]), top_prediction]
return numpy.average(hits)
def calculate_precision_at_equal_recall_rate(predictions, actuals):
"""Performs a local (numpy) calculation of the PERR.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
Returns:
float: The average precision at equal recall rate across the entire batch.
"""
aggregated_precision = 0.0
num_videos = actuals.shape[0]
for row in numpy.arange(num_videos):
num_labels = int(numpy.sum(actuals[row]))
top_indices = numpy.argpartition(predictions[row],
-num_labels)[-num_labels:]
item_precision = 0.0
for label_index in top_indices:
if predictions[row][label_index] > 0:
item_precision += actuals[row][label_index]
item_precision /= top_indices.size
aggregated_precision += item_precision
aggregated_precision /= num_videos
return aggregated_precision
def calculate_gap(predictions, actuals, top_k=20):
"""Performs a local (numpy) calculation of the global average precision.
Only the top_k predictions are taken for each of the videos.
Args:
predictions: Matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
actuals: Matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
top_k: How many predictions to use per video.
Returns:
float: The global average precision.
"""
gap_calculator = ap_calculator.AveragePrecisionCalculator()
sparse_predictions, sparse_labels, num_positives = top_k_by_class(
predictions, actuals, top_k)
gap_calculator.accumulate(
flatten(sparse_predictions), flatten(sparse_labels), sum(num_positives))
return gap_calculator.peek_ap_at_n()
def top_k_by_class(predictions, labels, k=20):
"""Extracts the top k predictions for each video, sorted by class.
Args:
predictions: A numpy matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
k: the top k non-zero entries to preserve in each prediction.
Returns:
A tuple (predictions,labels, true_positives). 'predictions' and 'labels'
are lists of lists of floats. 'true_positives' is a list of scalars. The
length of the lists are equal to the number of classes. The entries in the
predictions variable are probability predictions, and
the corresponding entries in the labels variable are the ground truth for
those predictions. The entries in 'true_positives' are the number of true
positives for each class in the ground truth.
Raises:
ValueError: An error occurred when the k is not a positive integer.
"""
if k <= 0:
raise ValueError("k must be a positive integer.")
k = min(k, predictions.shape[1])
num_classes = predictions.shape[1]
prediction_triplets = []
for video_index in range(predictions.shape[0]):
prediction_triplets.extend(
top_k_triplets(predictions[video_index], labels[video_index], k))
out_predictions = [[] for v in range(num_classes)]
out_labels = [[] for v in range(num_classes)]
for triplet in prediction_triplets:
out_predictions[triplet[0]].append(triplet[1])
out_labels[triplet[0]].append(triplet[2])
out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]
return out_predictions, out_labels, out_true_positives
def top_k_triplets(predictions, labels, k=20):
"""Get the top_k for a 1-d numpy array. Returns a sparse list of tuples in
(prediction, class) format"""
m = len(predictions)
k = min(k, m)
indices = numpy.argpartition(predictions, -k)[-k:]
return [(index, predictions[index], labels[index]) for index in indices]
class EvaluationMetrics(object):
"""A class to store the evaluation metrics."""
def __init__(self, num_class, top_k):
"""Construct an EvaluationMetrics object to store the evaluation metrics.
Args:
num_class: A positive integer specifying the number of classes.
top_k: A positive integer specifying how many predictions are considered per video.
Raises:
ValueError: An error occurred when MeanAveragePrecisionCalculator cannot
not be constructed.
"""
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.sum_loss = 0.0
self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
num_class)
self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
self.top_k = top_k
self.num_examples = 0
#def accumulate(self, predictions, labels, loss):
def accumulate(self, loss, predictions, labels):
"""Accumulate the metrics calculated locally for this mini-batch.
Args:
predictions: A numpy matrix containing the outputs of the model.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
loss: A numpy array containing the loss for each sample.
Returns:
dictionary: A dictionary storing the metrics for the mini-batch.
Raises:
ValueError: An error occurred when the shape of predictions and actuals
does not match.
"""
batch_size = labels.shape[0]
mean_hit_at_one = calculate_hit_at_one(predictions, labels)
mean_perr = calculate_precision_at_equal_recall_rate(predictions,
labels)
mean_loss = numpy.mean(loss)
# Take the top 20 predictions.
sparse_predictions, sparse_labels, num_positives = top_k_by_class(
predictions, labels, self.top_k)
self.map_calculator.accumulate(sparse_predictions, sparse_labels,
num_positives)
self.global_ap_calculator.accumulate(
flatten(sparse_predictions),
flatten(sparse_labels), sum(num_positives))
self.num_examples += batch_size
self.sum_hit_at_one += mean_hit_at_one * batch_size
self.sum_perr += mean_perr * batch_size
self.sum_loss += mean_loss * batch_size
return {
"hit_at_one": mean_hit_at_one,
"perr": mean_perr,
"loss": mean_loss
}
def get(self):
"""Calculate the evaluation metrics for the whole epoch.
Raises:
ValueError: If no examples were accumulated.
Returns:
dictionary: a dictionary storing the evaluation metrics for the epoch. The
dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and
aps (default nan).
"""
if self.num_examples <= 0:
raise ValueError("total_sample must be positive.")
avg_hit_at_one = self.sum_hit_at_one / self.num_examples
avg_perr = self.sum_perr / self.num_examples
avg_loss = self.sum_loss / self.num_examples
aps = self.map_calculator.peek_map_at_n()
gap = self.global_ap_calculator.peek_ap_at_n()
epoch_info_dict = {}
return {
"avg_hit_at_one": avg_hit_at_one,
"avg_perr": avg_perr,
"avg_loss": avg_loss,
"aps": aps,
"gap": gap
}
def clear(self):
"""Clear the evaluation metrics and reset the EvaluationMetrics object."""
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.sum_loss = 0.0
self.map_calculator.clear()
self.global_ap_calculator.clear()
self.num_examples = 0
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate the mean average precision.
It provides an interface for calculating mean average precision
for an entire list or the top-n ranked items.
Example usages:
We first call the function accumulate many times to process parts of the ranked
list. After processing all the parts, we call peek_map_at_n
to calculate the mean average precision.
```
import random
p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
for _ in xrange(1000)])
# mean average precision for 50 classes.
calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
num_class=50)
calculator.accumulate(p, a)
aps = calculator.peek_map_at_n()
```
"""
import numpy
from . import average_precision_calculator
class MeanAveragePrecisionCalculator(object):
"""This class is to calculate mean average precision.
"""
def __init__(self, num_class):
"""Construct a calculator to calculate the (macro) average precision.
Args:
num_class: A positive Integer specifying the number of classes.
top_n_array: A list of positive integers specifying the top n for each
class. The top n in each class will be used to calculate its average
precision at n.
The size of the array must be num_class.
Raises:
ValueError: An error occurred when num_class is not a positive integer;
or the top_n_array is not a list of positive integers.
"""
if not isinstance(num_class, int) or num_class <= 1:
raise ValueError("num_class must be a positive integer.")
self._ap_calculators = [] # member of AveragePrecisionCalculator
self._num_class = num_class # total number of classes
for i in range(num_class):
self._ap_calculators.append(
average_precision_calculator.AveragePrecisionCalculator())
def accumulate(self, predictions, actuals, num_positives=None):
"""Accumulate the predictions and their ground truth labels.
Args:
predictions: A list of lists storing the prediction scores. The outer
dimension corresponds to classes.
actuals: A list of lists storing the ground truth labels. The dimensions
should correspond to the predictions input. Any value
larger than 0 will be treated as positives, otherwise as negatives.
num_positives: If provided, it is a list of numbers representing the
number of true positives for each class. If not provided, the number of
true positives will be inferred from the 'actuals' array.
Raises:
ValueError: An error occurred when the shape of predictions and actuals
does not match.
"""
if not num_positives:
num_positives = [None for i in predictions.shape[1]]
calculators = self._ap_calculators
for i in range(len(predictions)):
calculators[i].accumulate(predictions[i], actuals[i],
num_positives[i])
def clear(self):
for calculator in self._ap_calculators:
calculator.clear()
def is_empty(self):
return ([calculator.heap_size for calculator in self._ap_calculators] ==
[0 for _ in range(self._num_class)])
def peek_map_at_n(self):
"""Peek the non-interpolated mean average precision at n.
Returns:
An array of non-interpolated average precision at n (default 0) for each
class.
"""
aps = [
self._ap_calculators[i].peek_ap_at_n()
for i in range(self._num_class)
]
return aps
from .model import regist_model, get_model
from .attention_cluster import AttentionCluster
from .nextvlad import NEXTVLAD
from .tsn import TSN
from .stnet import STNET
from .attention_lstm import AttentionLSTM
# regist models
regist_model("AttentionCluster", AttentionCluster)
regist_model("NEXTVLAD", NEXTVLAD)
regist_model("TSN", TSN)
regist_model("STNET", STNET)
regist_model("AttentionLSTM", AttentionLSTM)
from __future__ import absolute_import
from .attention_cluster import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .shifting_attention import ShiftingAttentionModel
from .logistic_model import LogisticModel
__all__ = ["AttentionCluster"]
class AttentionCluster(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(AttentionCluster, self).__init__(name, cfg, mode)
self.get_config()
def get_config(self):
# get model configs
self.feature_num = self.cfg.MODEL.feature_num
self.feature_names = self.cfg.MODEL.feature_names
self.feature_dims = self.cfg.MODEL.feature_dims
self.cluster_nums = self.cfg.MODEL.cluster_nums
self.seg_num = self.cfg.MODEL.seg_num
self.class_num = self.cfg.MODEL.num_classes
self.drop_rate = self.cfg.MODEL.drop_rate
if self.mode == 'train':
self.learning_rate = self.get_config_from_sec('train',
'learning_rate', 1e-3)
def build_input(self, use_pyreader):
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
shapes = []
for dim in self.feature_dims:
shapes.append([-1, self.seg_num, dim])
shapes.append([-1, self.class_num]) # label
self.py_reader = fluid.layers.py_reader(
capacity=1024,
shapes=shapes,
lod_levels=[0] * (self.feature_num + 1),
dtypes=['float32'] * (self.feature_num + 1),
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
inputs = fluid.layers.read_file(self.py_reader)
self.feature_input = inputs[:self.feature_num]
self.label_input = inputs[-1]
else:
self.feature_input = []
for name, dim in zip(self.feature_names, self.feature_dims):
self.feature_input.append(
fluid.layers.data(
shape=[self.seg_num, dim], dtype='float32', name=name))
if self.mode == 'infer':
self.label_input = None
else:
self.label_input = fluid.layers.data(
shape=[self.class_num], dtype='float32', name='label')
def build_model(self):
att_outs = []
for i, (input_dim, cluster_num, feature) in enumerate(
zip(self.feature_dims, self.cluster_nums, self.feature_input)):
att = ShiftingAttentionModel(input_dim, self.seg_num, cluster_num,
"satt{}".format(i))
att_out = att.forward(feature)
att_outs.append(att_out)
out = fluid.layers.concat(att_outs, axis=1)
if self.drop_rate > 0.:
out = fluid.layers.dropout(
out, self.drop_rate, is_test=(not self.is_training))
fc1 = fluid.layers.fc(
out,
size=1024,
act='tanh',
param_attr=ParamAttr(
name="fc1.weights",
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=ParamAttr(
name="fc1.bias", initializer=fluid.initializer.MSRA()))
fc2 = fluid.layers.fc(
fc1,
size=4096,
act='tanh',
param_attr=ParamAttr(
name="fc2.weights",
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=ParamAttr(
name="fc2.bias", initializer=fluid.initializer.MSRA()))
aggregate_model = LogisticModel()
self.output, self.logit = aggregate_model.build_model(
model_input=fc2,
vocab_size=self.class_num,
is_training=self.is_training)
def optimizer(self):
assert self.mode == 'train', "optimizer only can be get in train mode"
return fluid.optimizer.AdamOptimizer(self.learning_rate)
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
cost = fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.logit, label=self.label_input)
cost = fluid.layers.reduce_sum(cost, dim=-1)
self.loss_ = fluid.layers.mean(x=cost)
return self.loss_
def outputs(self):
return [self.output, self.logit]
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def weights_info(self):
return (
"attention_cluster_youtube8m",
"https://paddlemodels.bj.bcebos.com/video_classification/attention_cluster_youtube8m.tar.gz"
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
class LogisticModel(object):
"""Logistic model."""
def build_model(self,
model_input,
vocab_size,
**unused_params):
"""Creates a logistic model.
Args:
model_input: 'batch' x 'num_features' matrix of input features.
vocab_size: The number of classes in the dataset.
Returns:
A dictionary with a tensor containing the probability predictions of the
model in the 'predictions' key. The dimensions of the tensor are
batch_size x num_classes."""
logit = fluid.layers.fc(
input=model_input,
size=vocab_size,
act=None,
name='logits_clf',
param_attr=fluid.ParamAttr(
name='logistic.weights',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=fluid.ParamAttr(
name='logistic.bias',
initializer=fluid.initializer.MSRA(uniform=False)))
output = fluid.layers.sigmoid(logit)
return output, logit
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
class ShiftingAttentionModel(object):
"""Shifting Attention Model"""
def __init__(self, input_dim, seg_num, n_att, name):
self.n_att = n_att
self.input_dim = input_dim
self.seg_num = seg_num
self.name = name
self.gnorm = np.sqrt(n_att)
def softmax_m1(self, x):
x_shape = fluid.layers.shape(x)
x_shape.stop_gradient = True
flat_x = fluid.layers.reshape(x, shape=(-1, self.seg_num))
flat_softmax = fluid.layers.softmax(flat_x)
return fluid.layers.reshape(
flat_softmax, shape=x.shape, actual_shape=x_shape)
def glorot(self, n):
return np.sqrt(1.0 / np.sqrt(n))
def forward(self, x):
"""Forward shifting attention model.
Args:
x: input features in shape of [N, L, F].
Returns:
out: output features in shape of [N, F * C]
"""
trans_x = fluid.layers.transpose(x, perm=[0, 2, 1])
# scores and weight in shape [N, C, L], sum(weights, -1) = 1
trans_x = fluid.layers.unsqueeze(trans_x, [-1])
scores = fluid.layers.conv2d(
trans_x,
self.n_att,
filter_size=1,
param_attr=ParamAttr(
name=self.name + ".conv.weight",
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=ParamAttr(
name=self.name + ".conv.bias",
initializer=fluid.initializer.MSRA()))
scores = fluid.layers.squeeze(scores, [-1])
weights = self.softmax_m1(scores)
glrt = self.glorot(self.n_att)
self.w = fluid.layers.create_parameter(
shape=(self.n_att, ),
dtype=x.dtype,
attr=ParamAttr(self.name + ".shift_w"),
default_initializer=fluid.initializer.Normal(0.0, glrt))
self.b = fluid.layers.create_parameter(
shape=(self.n_att, ),
dtype=x.dtype,
attr=ParamAttr(name=self.name + ".shift_b"),
default_initializer=fluid.initializer.Normal(0.0, glrt))
outs = []
for i in range(self.n_att):
# slice weight and expand to shape [N, L, C]
weight = fluid.layers.slice(
weights, axes=[1], starts=[i], ends=[i + 1])
weight = fluid.layers.transpose(weight, perm=[0, 2, 1])
weight = fluid.layers.expand(weight, [1, 1, self.input_dim])
w_i = fluid.layers.slice(self.w, axes=[0], starts=[i], ends=[i + 1])
b_i = fluid.layers.slice(self.b, axes=[0], starts=[i], ends=[i + 1])
shift = fluid.layers.reduce_sum(x * weight, dim=1) * w_i + b_i
l2_norm = fluid.layers.l2_normalize(shift, axis=-1)
outs.append(l2_norm / self.gnorm)
out = fluid.layers.concat(outs, axis=1)
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .lstm_attention import LSTMAttentionModel
__all__ = ["AttentionLSTM"]
class AttentionLSTM(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(AttentionLSTM, self).__init__(name, cfg, mode)
self.get_config()
def get_config(self):
# get model configs
self.feature_num = self.cfg.MODEL.feature_num
self.feature_names = self.cfg.MODEL.feature_names
self.feature_dims = self.cfg.MODEL.feature_dims
self.num_classes = self.cfg.MODEL.num_classes
self.embedding_size = self.cfg.MODEL.embedding_size
self.lstm_size = self.cfg.MODEL.lstm_size
self.drop_rate = self.cfg.MODEL.drop_rate
# get mode configs
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size', 1)
self.num_gpus = self.get_config_from_sec(self.mode, 'num_gpus', 1)
if self.mode == 'train':
self.learning_rate = self.get_config_from_sec('train',
'learning_rate', 1e-3)
self.weight_decay = self.get_config_from_sec('train',
'weight_decay', 8e-4)
self.num_samples = self.get_config_from_sec('train', 'num_samples',
5000000)
self.decay_epochs = self.get_config_from_sec('train',
'decay_epochs', [5])
self.decay_gamma = self.get_config_from_sec('train', 'decay_gamma',
0.1)
def build_input(self, use_pyreader):
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
shapes = []
for dim in self.feature_dims:
shapes.append([-1, dim])
shapes.append([-1, self.num_classes]) # label
self.py_reader = fluid.layers.py_reader(
capacity=1024,
shapes=shapes,
lod_levels=[1] * self.feature_num + [0],
dtypes=['float32'] * (self.feature_num + 1),
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
inputs = fluid.layers.read_file(self.py_reader)
self.feature_input = inputs[:self.feature_num]
self.label_input = inputs[-1]
else:
self.feature_input = []
for name, dim in zip(self.feature_names, self.feature_dims):
self.feature_input.append(
fluid.layers.data(
shape=[dim], lod_level=1, dtype='float32', name=name))
if self.mode == 'infer':
self.label_input = None
else:
self.label_input = fluid.layers.data(
shape=[self.num_classes], dtype='float32', name='label')
def build_model(self):
att_outs = []
for i, (input_dim, feature
) in enumerate(zip(self.feature_dims, self.feature_input)):
att = LSTMAttentionModel(input_dim, self.embedding_size,
self.lstm_size, self.drop_rate)
att_out = att.forward(feature, is_training=(self.mode == 'train'))
att_outs.append(att_out)
out = fluid.layers.concat(att_outs, axis=1)
fc1 = fluid.layers.fc(
input=out,
size=8192,
act='relu',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
fc2 = fluid.layers.fc(
input=fc1,
size=4096,
act='tanh',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
self.logit = fluid.layers.fc(input=fc2, size=self.num_classes, act=None, \
bias_attr=ParamAttr(regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
self.output = fluid.layers.sigmoid(self.logit)
def optimizer(self):
assert self.mode == 'train', "optimizer only can be get in train mode"
values = [
self.learning_rate * (self.decay_gamma**i)
for i in range(len(self.decay_epochs) + 1)
]
iter_per_epoch = self.num_samples / self.batch_size
boundaries = [e * iter_per_epoch for e in self.decay_epochs]
return fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(
values=values, boundaries=boundaries),
centered=True,
regularization=fluid.regularizer.L2Decay(self.weight_decay))
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
cost = fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.logit, label=self.label_input)
cost = fluid.layers.reduce_sum(cost, dim=-1)
sum_cost = fluid.layers.reduce_sum(cost)
self.loss_ = fluid.layers.scale(
sum_cost, scale=self.num_gpus, bias_after_scale=False)
return self.loss_
def outputs(self):
return [self.output, self.logit]
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def weights_info(self):
return (None, None)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
class LSTMAttentionModel(object):
"""LSTM Attention Model"""
def __init__(self,
bias_attr,
embedding_size=512,
lstm_size=1024,
drop_rate=0.5):
self.lstm_size = lstm_size
self.embedding_size = embedding_size
self.drop_rate = drop_rate
def forward(self, input, is_training):
input_fc = fluid.layers.fc(
input=input,
size=self.embedding_size,
act='tanh',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
lstm_forward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
act=None,
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
lstm_forward, _ = fluid.layers.dynamic_lstm(
input=lstm_forward_fc, size=self.lstm_size * 4, is_reverse=False)
lsmt_backward_fc = fluid.layers.fc(
input=input_fc,
size=self.lstm_size * 4,
act=None,
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
lstm_backward, _ = fluid.layers.dynamic_lstm(
input=lsmt_backward_fc, size=self.lstm_size * 4, is_reverse=True)
lstm_concat = fluid.layers.concat(
input=[lstm_forward, lstm_backward], axis=1)
lstm_dropout = fluid.layers.dropout(
x=lstm_concat, dropout_prob=self.drop_rate, is_test=(not is_training))
lstm_weight = fluid.layers.fc(
input=lstm_dropout,
size=1,
act='sequence_softmax',
bias_attr=ParamAttr(
regularizer=fluid.regularizer.L2Decay(0.0),
initializer=fluid.initializer.NormalInitializer(scale=0.0)))
scaled = fluid.layers.elementwise_mul(
x=lstm_dropout, y=lstm_weight, axis=0)
lstm_pool = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
return lstm_pool
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import logging
try:
from configparser import ConfigParser
except:
from ConfigParser import ConfigParser
import paddle.fluid as fluid
from datareader import get_reader
from metrics import get_metrics
from .utils import download, AttrDict
WEIGHT_DIR = os.path.expanduser("~/.paddle/weights")
logger = logging.getLogger(__name__)
class NotImplementError(Exception):
"Error: model function not implement"
def __init__(self, model, function):
super(NotImplementError, self).__init__()
self.model = model.__class__.__name__
self.function = function.__name__
def __str__(self):
return "Function {}() is not implemented in model {}".format(
self.function, self.model)
class ModelNotFoundError(Exception):
"Error: model not found"
def __init__(self, model_name, avail_models):
super(ModelNotFoundError, self).__init__()
self.model_name = model_name
self.avail_models = avail_models
def __str__(self):
msg = "Model {} Not Found.\nAvailiable models:\n".format(
self.model_name)
for model in self.avail_models:
msg += " {}\n".format(model)
return msg
class ModelBase(object):
def __init__(self, name, cfg, mode='train'):
assert mode in ['train', 'valid', 'test', 'infer'], \
"Unknown mode type {}".format(mode)
self.name = name
self.is_training = (mode == 'train')
self.mode = mode
self.py_reader = None
# parse config
# assert os.path.exists(cfg), \
# "Config file {} not exists".format(cfg)
# self._config = ModelConfig(cfg)
# self._config.parse()
# if args and isinstance(args, dict):
# self._config.merge_configs(mode, args)
# self.cfg = self._config.get_configs()
self.cfg = cfg
def build_model(self):
"build model struct"
raise NotImplementError(self, self.build_model)
def build_input(self, use_pyreader):
"build input Variable"
raise NotImplementError(self, self.build_input)
def optimizer(self):
"get model optimizer"
raise NotImplementError(self, self.optimizer)
def outputs():
"get output variable"
raise notimplementerror(self, self.outputs)
def loss(self):
"get loss variable"
raise notimplementerror(self, self.loss)
def feeds(self):
"get feed inputs list"
raise NotImplementError(self, self.feeds)
def weights_info(self):
"get model weight default path and download url"
raise NotImplementError(self, self.weights_info)
def get_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.weights_info()
path = os.path.join(WEIGHT_DIR, path)
if os.path.exists(path):
return path
logger.info("Download weights of {} from {}".format(self.name, url))
download(url, path)
return path
def pyreader(self):
return self.py_reader
def epoch_num(self):
"get train epoch num"
return self.cfg.TRAIN.epoch
def pretrain_info(self):
"get pretrain base model directory"
return (None, None)
def get_pretrain_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.pretrain_info()
if not path:
return None
path = os.path.join(WEIGHT_DIR, path)
if os.path.exists(path):
return path
logger.info("Download pretrain weights of {} from {}".format(
self.name, url))
download(url, path)
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name))
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ModelZoo(object):
def __init__(self):
self.model_zoo = {}
def regist(self, name, model):
assert model.__base__ == ModelBase, "Unknow model type {}".format(
type(model))
self.model_zoo[name] = model
def get(self, name, cfg, mode='train'):
for k, v in self.model_zoo.items():
if k == name:
return v(name, cfg, mode)
raise ModelNotFoundError(name, self.model_zoo.keys())
# singleton model_zoo
model_zoo = ModelZoo()
def regist_model(name, model):
model_zoo.regist(name, model)
def get_model(name, cfg, mode='train'):
return model_zoo.get(name, cfg, mode)
from __future__ import absolute_import
from .nextvlad import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
class LogisticModel(object):
"""Logistic model with L2 regularization."""
def create_model(self,
model_input,
vocab_size,
l2_penalty=None,
**unused_params):
"""Creates a logistic model.
Args:
model_input: 'batch' x 'num_features' matrix of input features.
vocab_size: The number of classes in the dataset.
Returns:
A dictionary with a tensor containing the probability predictions of the
model in the 'predictions' key. The dimensions of the tensor are
batch_size x num_classes."""
logits = fluid.layers.fc(
input=model_input,
size=vocab_size,
act=None,
name='logits_clf',
param_attr=fluid.ParamAttr(
name='logits_clf_weights',
initializer=fluid.initializer.MSRA(uniform=False),
regularizer=fluid.regularizer.L2DecayRegularizer(l2_penalty)),
bias_attr=fluid.ParamAttr(
name='logits_clf_bias',
regularizer=fluid.regularizer.L2DecayRegularizer(l2_penalty)))
output = fluid.layers.sigmoid(logits)
return {'predictions': output, 'logits': logits}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .clf_model import LogisticModel
from . import nextvlad_model
__all__ = ["NEXTVLAD"]
class NEXTVLAD(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(NEXTVLAD, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
# model params
self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.video_feature_size = self.get_config_from_sec('model',
'video_feature_size')
self.audio_feature_size = self.get_config_from_sec('model',
'audio_feature_size')
self.cluster_size = self.get_config_from_sec('model', 'cluster_size')
self.hidden_size = self.get_config_from_sec('model', 'hidden_size')
self.groups = self.get_config_from_sec('model', 'groups')
self.expansion = self.get_config_from_sec('model', 'expansion')
self.drop_rate = self.get_config_from_sec('model', 'drop_rate')
self.gating_reduction = self.get_config_from_sec('model',
'gating_reduction')
self.eigen_file = self.get_config_from_sec('model', 'eigen_file')
# training params
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.lr_boundary_examples = self.get_config_from_sec(
'train', 'lr_boundary_examples')
self.max_iter = self.get_config_from_sec('train', 'max_iter')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_penalty = self.get_config_from_sec('train', 'l2_penalty')
self.gradient_clip_norm = self.get_config_from_sec('train',
'gradient_clip_norm')
self.use_gpu = self.get_config_from_sec('train', 'use_gpu')
self.num_gpus = self.get_config_from_sec('train', 'num_gpus')
# other params
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
def build_input(self, use_pyreader=True):
rgb_shape = [self.video_feature_size]
audio_shape = [self.audio_feature_size]
label_shape = [self.num_classes]
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
py_reader = fluid.layers.py_reader(
capacity=100,
shapes=[[-1] + rgb_shape, [-1] + audio_shape,
[-1] + label_shape],
lod_levels=[1, 1, 0],
dtypes=['float32', 'float32', 'float32'],
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
rgb, audio, label = fluid.layers.read_file(py_reader)
self.py_reader = py_reader
else:
rgb = fluid.layers.data(
name='train_rgb' if self.is_training else 'test_rgb',
shape=rgb_shape,
dtype='float32',
lod_level=1)
audio = fluid.layers.data(
name='train_audio' if self.is_training else 'test_audio',
shape=audio_shape,
dtype='float32',
lod_level=1)
if self.mode == 'infer':
label = None
else:
label = fluid.layers.data(
name='train_label' if self.is_training else 'test_label',
shape=label_shape,
dtype='float32')
self.feature_input = [rgb, audio]
self.label_input = label
def create_model_args(self):
model_args = {}
model_args['class_dim'] = self.num_classes
model_args['cluster_size'] = self.cluster_size
model_args['hidden_size'] = self.hidden_size
model_args['groups'] = self.groups
model_args['expansion'] = self.expansion
model_args['drop_rate'] = self.drop_rate
model_args['gating_reduction'] = self.gating_reduction
model_args['l2_penalty'] = self.l2_penalty
return model_args
def build_model(self):
model_args = self.create_model_args()
videomodel = nextvlad_model.NeXtVLADModel()
rgb = self.feature_input[0]
audio = self.feature_input[1]
out = videomodel.create_model(
rgb, audio, is_training=(self.mode == 'train'), **model_args)
self.logits = out['logits']
self.predictions = out['predictions']
self.network_outputs = [out['predictions']]
def optimizer(self):
assert self.mode == 'train', "optimizer only can be get in train mode"
im_per_batch = self.batch_size
lr_bounds, lr_values = get_learning_rate_decay_list(
self.base_learning_rate, self.learning_rate_decay, self.max_iter,
self.lr_boundary_examples, im_per_batch)
return fluid.optimizer.AdamOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=lr_bounds, values=lr_values))
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
cost = fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.logits, label=self.label_input)
cost = fluid.layers.reduce_sum(cost, dim=-1)
self.loss_ = fluid.layers.mean(x=cost)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def weights_info(self):
return ('nextvlad_youtube8m',
'https://paddlemodels.bj.bcebos.com/video_classification/nextvlad_youtube8m.tar.gz')
def get_learning_rate_decay_list(base_learning_rate, decay, max_iter,
decay_examples, total_batch_size):
decay_step = decay_examples // total_batch_size
lr_bounds = []
lr_values = [base_learning_rate]
i = 1
while True:
if i * decay_step >= max_iter:
break
lr_bounds.append(i * decay_step)
lr_values.append(base_learning_rate * (decay**i))
i += 1
return lr_bounds, lr_values
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from . import clf_model
class NeXtVLAD(object):
"""
This is a paddlepaddle implementation of the NeXtVLAD model. For more
information, please refer to the paper,
https://static.googleusercontent.com/media/research.google.com/zh-CN//youtube8m/workshop2018/p_c03.pdf
"""
def __init__(self,
feature_size,
cluster_size,
is_training=True,
expansion=2,
groups=None,
inputname='video'):
self.feature_size = feature_size
self.cluster_size = cluster_size
self.is_training = is_training
self.expansion = expansion
self.groups = groups
self.name = inputname + '_'
def forward(self, input):
input = fluid.layers.fc(
input=input,
size=self.expansion * self.feature_size,
act=None,
name=self.name + 'fc_expansion',
param_attr=fluid.ParamAttr(
name=self.name + 'fc_expansion_w',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=fluid.ParamAttr(
name=self.name + 'fc_expansion_b',
initializer=fluid.initializer.Constant(value=0.)))
# attention factor of per group
attention = fluid.layers.fc(
input=input,
size=self.groups,
act='sigmoid',
name=self.name + 'fc_group_attention',
param_attr=fluid.ParamAttr(
name=self.name + 'fc_group_attention_w',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=fluid.ParamAttr(
name=self.name + 'fc_group_attention_b',
initializer=fluid.initializer.Constant(value=0.)))
# calculate activation factor of per group per cluster
feature_size = self.feature_size * self.expansion // self.groups
cluster_weights = fluid.layers.create_parameter(
shape=[
self.expansion * self.feature_size,
self.groups * self.cluster_size
],
dtype=input.dtype,
attr=fluid.ParamAttr(name=self.name + 'cluster_weights'),
default_initializer=fluid.initializer.MSRA(uniform=False))
activation = fluid.layers.matmul(input, cluster_weights)
activation = fluid.layers.batch_norm(
activation, is_test=(not self.is_training))
# reshape of activation
activation = fluid.layers.reshape(activation,
[-1, self.groups, self.cluster_size])
# softmax on per cluster
activation = fluid.layers.softmax(activation)
activation = fluid.layers.elementwise_mul(activation, attention, axis=0)
a_sum = fluid.layers.sequence_pool(activation, 'sum')
a_sum = fluid.layers.reduce_sum(a_sum, dim=1)
# create cluster_weights2
cluster_weights2 = fluid.layers.create_parameter(
shape=[self.cluster_size, feature_size],
dtype=input.dtype,
attr=fluid.ParamAttr(name=self.name + 'cluster_weights2'),
default_initializer=fluid.initializer.MSRA(uniform=False))
# expand a_sum dimension from [-1, self.cluster_size] to be [-1, self.cluster_size, feature_size]
a_sum = fluid.layers.reshape(a_sum, [-1, self.cluster_size, 1])
a_sum = fluid.layers.expand(a_sum, [1, 1, feature_size])
# element wise multiply a_sum and cluster_weights2
a = fluid.layers.elementwise_mul(
a_sum, cluster_weights2,
axis=1) # output shape [-1, self.cluster_size, feature_size]
# transpose activation from [-1, self.groups, self.cluster_size] to [-1, self.cluster_size, self.groups]
activation2 = fluid.layers.transpose(activation, perm=[0, 2, 1])
# transpose op will clear the lod infomation, so it should be reset
activation = fluid.layers.lod_reset(activation2, activation)
# reshape input from [-1, self.expansion * self.feature_size] to [-1, self.groups, feature_size]
reshaped_input = fluid.layers.reshape(input,
[-1, self.groups, feature_size])
# mat multiply activation and reshaped_input
vlad = fluid.layers.matmul(
activation,
reshaped_input) # output shape [-1, self.cluster_size, feature_size]
vlad = fluid.layers.sequence_pool(vlad, 'sum')
vlad = fluid.layers.elementwise_sub(vlad, a)
# l2_normalization
vlad = fluid.layers.transpose(vlad, [0, 2, 1])
vlad = fluid.layers.l2_normalize(vlad, axis=1)
# reshape and batch norm
vlad = fluid.layers.reshape(vlad,
[-1, self.cluster_size * feature_size])
vlad = fluid.layers.batch_norm(vlad, is_test=(not self.is_training))
return vlad
class NeXtVLADModel(object):
"""
Creates a NeXtVLAD based model.
Args:
model_input: A LoDTensor of [-1, N] for the input video frames.
vocab_size: The number of classes in the dataset.
"""
def __init__(self):
pass
def create_model(self,
video_input,
audio_input,
is_training=True,
class_dim=None,
cluster_size=None,
hidden_size=None,
groups=None,
expansion=None,
drop_rate=None,
gating_reduction=None,
l2_penalty=None,
**unused_params):
# calcluate vlad of video and audio
video_nextvlad = NeXtVLAD(
1024,
cluster_size,
is_training,
expansion=expansion,
groups=groups,
inputname='video')
audio_nextvlad = NeXtVLAD(
128,
cluster_size,
is_training,
expansion=expansion,
groups=groups,
inputname='audio')
vlad_video = video_nextvlad.forward(video_input)
vlad_audio = audio_nextvlad.forward(audio_input)
# concat video and audio
vlad = fluid.layers.concat([vlad_video, vlad_audio], axis=1)
# drop out
if drop_rate > 0.:
vlad = fluid.layers.dropout(
vlad, drop_rate, is_test=(not is_training))
# add fc
activation = fluid.layers.fc(
input=vlad,
size=hidden_size,
act=None,
name='hidden1_fc',
param_attr=fluid.ParamAttr(
name='hidden1_fc_weights',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=False)
activation = fluid.layers.batch_norm(
activation, is_test=(not is_training))
# add fc, gate 1
gates = fluid.layers.fc(
input=activation,
size=hidden_size // gating_reduction,
act=None,
name='gating_fc1',
param_attr=fluid.ParamAttr(
name='gating_fc1_weights',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=False)
gates = fluid.layers.batch_norm(
gates, is_test=(not is_training), act='relu')
# add fc, gate 2
gates = fluid.layers.fc(
input=gates,
size=hidden_size,
act='sigmoid',
name='gating_fc2',
param_attr=fluid.ParamAttr(
name='gating_fc2_weights',
initializer=fluid.initializer.MSRA(uniform=False)),
bias_attr=False)
activation = fluid.layers.elementwise_mul(activation, gates)
aggregate_model = clf_model.LogisticModel # set classification model
return aggregate_model().create_model(
model_input=activation,
vocab_size=class_dim,
l2_penalty=l2_penalty,
is_training=is_training,
**unused_params)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle.fluid as fluid
from ..model import ModelBase
from .stnet_res_model import StNet_ResNet
__all__ = ["STNET"]
class STNET(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(STNET, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.seg_num = self.get_config_from_sec('model', 'seg_num')
self.seglen = self.get_config_from_sec('model', 'seglen')
self.image_mean = self.get_config_from_sec('model', 'image_mean')
self.image_std = self.get_config_from_sec('model', 'image_std')
self.num_layers = self.get_config_from_sec('model', 'num_layers')
self.num_epochs = self.get_config_from_sec('train', 'epoch')
self.total_videos = self.get_config_from_sec('train', 'total_videos')
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_weight_decay = self.get_config_from_sec('train',
'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum')
self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
def build_input(self, use_pyreader=True):
image_shape = [3, self.target_size, self.target_size]
image_shape[0] = image_shape[0] * self.seglen
image_shape = [self.seg_num] + image_shape
self.use_pyreader = use_pyreader
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
py_reader = fluid.layers.py_reader(
capacity=100,
shapes=[[-1] + image_shape, [-1] + [1]],
dtypes=['float32', 'int64'],
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
image, label = fluid.layers.read_file(py_reader)
self.py_reader = py_reader
else:
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
if self.mode != 'infer':
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
else:
label = None
self.feature_input = [image]
self.label_input = label
def create_model_args(self):
cfg = {}
cfg['layers'] = self.num_layers
cfg['class_dim'] = self.num_classes
cfg['seg_num'] = self.seg_num
cfg['seglen'] = self.seglen
return cfg
def build_model(self):
cfg = self.create_model_args()
videomodel = StNet_ResNet(layers = cfg['layers'], seg_num = cfg['seg_num'], \
seglen = cfg['seglen'], is_training = (self.mode == 'train'))
out = videomodel.net(input=self.feature_input[0],
class_dim=cfg['class_dim'])
self.network_outputs = [out]
def optimizer(self):
epoch_points = [self.num_epochs / 3, self.num_epochs * 2 / 3]
total_videos = self.total_videos
step = int(total_videos / self.batch_size + 1)
bd = [e * step for e in epoch_points]
base_lr = self.base_learning_rate
lr_decay = self.learning_rate_decay
lr = [base_lr, base_lr * lr_decay, base_lr * lr_decay * lr_decay]
l2_weight_decay = self.l2_weight_decay
momentum = self.momentum
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=momentum,
regularization=fluid.regularizer.L2Decay(l2_weight_decay))
return optimizer
def loss(self):
cost = fluid.layers.cross_entropy(input=self.network_outputs[0], \
label=self.label_input, ignore_index=-1)
self.loss_ = fluid.layers.mean(x=cost)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor()
param_numpy = np.array(param_tensor)
param_numpy = np.mean(param_numpy, axis=1, keepdims=True) / self.seglen
param_numpy = np.repeat(param_numpy, 3 * self.seglen, axis=1)
param_tensor.set(param_numpy.astype(np.float32), place)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import time
import sys
import paddle.fluid as fluid
import math
class StNet_ResNet():
def __init__(self, layers=50, seg_num=7, seglen=5, is_training=True):
self.layers = layers
self.seglen = seglen
self.seg_num = seg_num
self.is_training = is_training
def temporal_conv_bn(
self,
input, #(B*seg_num, c, h, w)
num_filters,
filter_size=(3, 1, 1),
padding=(1, 0, 0)):
#(B, seg_num, c, h, w)
in_reshape = fluid.layers.reshape(
x=input,
shape=[
-1, self.seg_num, input.shape[-3], input.shape[-2],
input.shape[-1]
])
in_transpose = fluid.layers.transpose(in_reshape, perm=[0, 2, 1, 3, 4])
conv = fluid.layers.conv3d(
input=in_transpose,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
groups=1,
padding=padding,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0)))
out = fluid.layers.batch_norm(
input=conv,
act=None,
is_test=(not self.is_training),
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=1.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0)))
out = out + in_transpose
out = fluid.layers.transpose(out, perm=[0, 2, 1, 3, 4])
out = fluid.layers.reshape(x=out, shape=input.shape)
return out
def xception(self, input): #(B, C, seg_num,1)
bn = fluid.layers.batch_norm(
input=input,
act=None,
name="xception_bn",
is_test=(not self.is_training),
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=1.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0)))
att_conv = fluid.layers.conv2d(
input=bn,
num_filters=2048,
filter_size=[3, 1],
stride=[1, 1],
padding=[1, 0],
groups=2048,
name="xception_att_conv",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0)))
att_2 = fluid.layers.conv2d(
input=att_conv,
num_filters=1024,
filter_size=[1, 1],
stride=[1, 1],
name="xception_att_2",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0)))
bndw = fluid.layers.batch_norm(
input=att_2,
act="relu",
name="xception_bndw",
is_test=(not self.is_training),
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=1.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0)))
att1 = fluid.layers.conv2d(
input=bndw,
num_filters=1024,
filter_size=[3, 1],
stride=[1, 1],
padding=[1, 0],
groups=1024,
name="xception_att1",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0)))
att1_2 = fluid.layers.conv2d(
input=att1,
num_filters=1024,
filter_size=[1, 1],
stride=[1, 1],
name="xception_att1_2",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0)))
dw = fluid.layers.conv2d(
input=bn,
num_filters=1024,
filter_size=[1, 1],
stride=[1, 1],
name="xception_dw",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.MSRAInitializer()),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0)))
add_to = dw + att1_2
bn2 = fluid.layers.batch_norm(
input=add_to,
act=None,
name='xception_bn2',
is_test=(not self.is_training),
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=1.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer(value=0.0)))
return fluid.layers.relu(bn2)
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
bias_attr=False,
#name = name+".conv2d.output.1"
)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
is_test=(not self.is_training),
#name=bn_name+'.output.1',
param_attr=fluid.param_attr.ParamAttr(name=bn_name + "_scale"),
bias_attr=fluid.param_attr.ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input, num_filters * 4, stride, name=name + "_branch1")
return fluid.layers.elementwise_add(
x=short,
y=conv2,
act='relu',
#name=".add.output.5"
)
def net(self, input, class_dim=101):
layers = self.layers
seg_num = self.seg_num
seglen = self.seglen
supported_layers = [50, 101, 152]
if layers not in supported_layers:
print("supported layers are", supported_layers, \
"but input layer is ", layers)
exit()
# reshape input
# [B, seg_num, seglen*c, H, W] --> [B*seg_num, seglen*c, H, W]
channels = input.shape[2]
short_size = input.shape[3]
input = fluid.layers.reshape(
x=input, shape=[-1, channels, short_size, short_size])
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name='conv1')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
name=conv_name)
if block == 1:
#insert the first temporal modeling block
conv = self.temporal_conv_bn(input=conv, num_filters=512)
if block == 2:
#insert the second temporal modeling block
conv = self.temporal_conv_bn(input=conv, num_filters=1024)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
feature = fluid.layers.reshape(
x=pool, shape=[-1, seg_num, pool.shape[1], 1])
feature = fluid.layers.transpose(feature, perm=[0, 2, 1, 3])
#append the temporal Xception block
xfeat = self.xception(feature) #(B, 1024, seg_num, 1)
out = fluid.layers.pool2d(
input=xfeat,
pool_size=(seg_num, 1),
pool_type='max',
global_pooling=True)
out = fluid.layers.reshape(x=out, shape=[-1, 1024])
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=out,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .tsn_res_model import TSN_ResNet
__all__ = ["TSN"]
class TSN(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(TSN, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
self.num_classes = self.get_config_from_sec('model', 'num_classes')
self.seg_num = self.get_config_from_sec('model', 'seg_num')
self.seglen = self.get_config_from_sec('model', 'seglen')
self.image_mean = self.get_config_from_sec('model', 'image_mean')
self.image_std = self.get_config_from_sec('model', 'image_std')
self.num_layers = self.get_config_from_sec('model', 'num_layers')
self.num_epochs = self.get_config_from_sec('train', 'epoch')
self.total_videos = self.get_config_from_sec('train', 'total_videos')
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_weight_decay = self.get_config_from_sec('train',
'l2_weight_decay')
self.momentum = self.get_config_from_sec('train', 'momentum')
self.target_size = self.get_config_from_sec(self.mode, 'target_size')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
def build_input(self, use_pyreader=True):
image_shape = [3, self.target_size, self.target_size]
image_shape[0] = image_shape[0] * self.seglen
image_shape = [self.seg_num] + image_shape
self.use_pyreader = use_pyreader
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
py_reader = fluid.layers.py_reader(
capacity=100,
shapes=[[-1] + image_shape, [-1] + [1]],
dtypes=['float32', 'int64'],
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
image, label = fluid.layers.read_file(py_reader)
self.py_reader = py_reader
else:
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
if self.mode != 'infer':
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
else:
label = None
self.feature_input = [image]
self.label_input = label
def create_model_args(self):
cfg = {}
cfg['layers'] = self.num_layers
cfg['class_dim'] = self.num_classes
cfg['seg_num'] = self.seg_num
return cfg
def build_model(self):
cfg = self.create_model_args()
videomodel = TSN_ResNet(
layers=cfg['layers'],
seg_num=cfg['seg_num'],
is_training=(self.mode == 'train'))
out = videomodel.net(input=self.feature_input[0],
class_dim=cfg['class_dim'])
self.network_outputs = [out]
def optimizer(self):
assert self.mode == 'train', "optimizer only can be get in train mode"
epoch_points = [self.num_epochs / 3, self.num_epochs * 2 / 3]
total_videos = self.total_videos
step = int(total_videos / self.batch_size + 1)
bd = [e * step for e in epoch_points]
base_lr = self.base_learning_rate
lr_decay = self.learning_rate_decay
lr = [base_lr, base_lr * lr_decay, base_lr * lr_decay * lr_decay]
l2_weight_decay = self.l2_weight_decay
momentum = self.momentum
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=momentum,
regularization=fluid.regularizer.L2Decay(l2_weight_decay))
return optimizer
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
cost = fluid.layers.cross_entropy(input=self.network_outputs[0], \
label=self.label_input, ignore_index=-1)
self.loss_ = fluid.layers.mean(x=cost)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import time
import sys
import paddle.fluid as fluid
import math
class TSN_ResNet():
def __init__(self, layers=50, seg_num=7, is_training=True):
self.layers = layers
self.seg_num = seg_num
self.is_training = is_training
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
is_test=(not self.is_training),
param_attr=fluid.param_attr.ParamAttr(name=bn_name + "_scale"),
bias_attr=fluid.param_attr.ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input, num_filters * 4, stride, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def net(self, input, class_dim=101):
layers = self.layers
seg_num = self.seg_num
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
# reshape input
channels = input.shape[2]
short_size = input.shape[3]
input = fluid.layers.reshape(
x=input, shape=[-1, channels, short_size, short_size])
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name='conv1')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
feature = fluid.layers.reshape(
x=pool, shape=[-1, seg_num, pool.shape[1]])
out = fluid.layers.reduce_mean(feature, dim=1)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=out,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import wget
import tarfile
__all__ = ['decompress', 'download', 'AttrDict']
def decompress(path):
t = tarfile.open(path)
t.extractall(path='/'.join(path.split('/')[:-1]))
t.close()
os.remove(path)
def download(url, path):
weight_dir = '/'.join(path.split('/')[:-1])
if not os.path.exists(weight_dir):
os.makedirs(weight_dir)
path = path + ".tar.gz"
wget.download(url, path)
decompress(path)
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
python3 infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--filelist=./data/youtube8m/infer.list \
--weights=./checkpoints/AttentionCluster_epoch0 \
--save-dir="./save"
python infer.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt \
--filelist=./data/youtube8m/infer.list \
--weights=./checkpoints/AttentionLSTM_epoch0 \
--save-dir="./save"
python infer.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \
--weights=./checkpoints/NEXTVLAD_epoch0 \
--save-dir="./save"
python infer.py --model-name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \
--log-interval=10 --weights=./checkpoints/STNET_epoch0 --save-dir=./save
python infer.py --model-name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \
--log-interval=10 --weights=./checkpoints/TSN_epoch0 --save-dir=./save
python3 test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \
--log-interval=5 --weights=./checkpoints/AttentionCluster_epoch0
python test.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt \
--log-interval=5 --weights=./checkpoints/AttentionLSTM_epoch0
python test.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt \
--log-interval=10 --weights=./checkpoints/NEXTVLAD_epoch0
python test.py --model-name="STNET" --config=./configs/stnet.txt \
--log-interval=10 --weights=./checkpoints/STNET_epoch0
python test.py --model-name="TSN" --config=./configs/tsn.txt \
--log-interval=10 --weights=./checkpoints/TSN_epoch0
python3 train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \
--valid-interval=1 --save-interval=1 --log-interval=10
python3 train.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt --epoch-num=10 \
--valid-interval=1 --save-interval=1 --log-interval=10
python train.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt --epoch-num=6 \
--valid-interval=1 --save-interval=1 --log-interval=10
python train.py --model-name="STNET" --config=./configs/stnet.txt --epoch-num=60 \
--valid-interval=1 --save-interval=1 --log-interval=10
python train.py --model-name="TSN" --config=./configs/tsn.txt --epoch-num=45 \
--valid-interval=1 --save-interval=1 --log-interval=10
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
__all__ = ['AttrDict']
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册