Commit 4044f30c authored by lixuanyi, committed by lizz

save

Signed-off-by: JoannaLXY <lixuanyi199801@gmail.com>
Parent 9e341413
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet3d',
pretrained2d=True,
pretrained='torchvision://resnet50',
depth=50,
conv_cfg=dict(type='Conv3d'),
norm_eval=False,
inflate=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 1, 0)),
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'VideoDataset'
data_root = 's3://lizz.ssd/datasets/kinetics400_256/'
data_root_val = 's3://lizz.ssd/datasets/kinetics400_256/'
ann_file_train = 'data/kinetics400/k400_train.txt'
ann_file_val = 'data/kinetics400/k400_val.txt'
ann_file_test = 'data/kinetics400/k400_val.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')
train_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.8),
random_crop=False,
max_wh_scale_gap=0),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=1,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=32,
frame_interval=2,
num_clips=10,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[40, 80])
total_epochs = 100
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/i3d_video_32x2x1_r50_3d_kinetics400_100e/'
load_from = None
resume_from = None
workflow = [('train', 1)]
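For reference, a config like the one above is typically consumed through mmcv's Config and the registry builders. A minimal sketch, assuming mmaction's usual build_model/build_dataset helpers (the config path below is hypothetical):

from mmcv import Config
from mmaction.datasets import build_dataset
from mmaction.models import build_model

# parse the config, then build the recognizer and the training set from it
cfg = Config.fromfile('configs/recognition/i3d/my_i3d_config.py')  # hypothetical path
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = build_dataset(cfg.data.train)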
# model settings
model = dict(
type='Recognizer3D',
backbone=dict(
type='ResNet2Plus1d',
depth=34,
pretrained=None,
pretrained2d=False,
norm_eval=False,
conv_cfg=dict(type='Conv2plus1d'),
norm_cfg=dict(type='SyncBN', requires_grad=True, eps=1e-3),
act_cfg=dict(type='ReLU'),
conv1_kernel=(3, 7, 7),
conv1_stride_t=1,
pool1_stride_t=1,
inflate=(1, 1, 1, 1),
spatial_strides=(1, 2, 2, 2),
temporal_strides=(1, 2, 2, 2),
zero_init_residual=False),
cls_head=dict(
type='I3DHead',
num_classes=400,
in_channels=512,
spatial_type='avg',
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'VideoDataset'
data_root = 's3://lizz.ssd/datasets/kinetics400_256/'
data_root_val = 's3://lizz.ssd/datasets/kinetics400_256/'
ann_file_train = 'data/kinetics400/k400_train.txt'
ann_file_val = 'data/kinetics400/k400_val.txt'
ann_file_test = 'data/kinetics400/k400_val.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')
train_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(type='SampleFrames', clip_len=8, frame_interval=8, num_clips=1),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=8,
frame_interval=8,
num_clips=1,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=8,
frame_interval=8,
num_clips=10,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=24,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline,
test_mode=True),
test=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=test_pipeline,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.6, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(
policy='cosine',
warmup='linear',
warmup_ratio=0.1,
    warmup_by_epoch=True,
warmup_iters=40)
total_epochs = 180
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/r2plus1d_video_8x8x1_r34_3d_kinetics400_180e/'
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = False
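The lr_config above combines a linear warmup over the first 40 epochs (starting from warmup_ratio times the base LR) with cosine annealing across the 180-epoch run. A self-contained per-epoch sketch of that schedule; mmcv's hook actually updates per iteration, so treat this as an approximation:

import math

def lr_at_epoch(epoch, base_lr=0.6, warmup_ratio=0.1,
                warmup_epochs=40, total_epochs=180):
    if epoch < warmup_epochs:
        # linear warmup: ramp from warmup_ratio * base_lr up to base_lr
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * epoch / warmup_epochs)
    # cosine annealing: decay from base_lr towards 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))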
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTIN',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=4),
cls_head=dict(
type='TINHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'VideoDataset'
data_root = 's3://lizz.ssd/datasets/kinetics400_256/'
data_root_val = 's3://lizz.ssd/datasets/kinetics400_256/'
ann_file_train = 'data/kinetics400/k400_train.txt'
ann_file_val = 'data/kinetics400/k400_val.txt'
ann_file_test = 'data/kinetics400/k400_val.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')
train_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=4),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=4),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=4),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='MultiGroupCrop', crop_size=256, groups=1),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=6,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD',
constructor='TSMOptimizerConstructor',
paramwise_cfg=dict(fc_lr5=True),
lr=0.005,
momentum=0.9,
weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[10, 20, 30])
total_epochs = 35
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tin_video_1x1x8_r50_2d_kinetics400_35e/'
load_from = None
resume_from = None
workflow = [('train', 1)]
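The optimizer above uses TSMOptimizerConstructor with fc_lr5=True, i.e. the TSM fine-tuning recipe in which the classifier fc layer trains with a larger learning rate than the pretrained backbone. A rough sketch of building such parameter groups by hand; the parameter name 'cls_head.fc_cls' and the exact multipliers are assumptions based on the common TSM recipe (5x for the fc weight, 10x for the fc bias):

def fc_lr5_param_groups(model, base_lr=0.005, weight_decay=1e-4):
    fc_weight, fc_bias, rest = [], [], []
    for name, param in model.named_parameters():
        if 'cls_head.fc_cls' in name:  # assumed name of the classifier layer
            (fc_bias if name.endswith('bias') else fc_weight).append(param)
        else:
            rest.append(param)
    return [
        dict(params=rest, lr=base_lr, weight_decay=weight_decay),
        dict(params=fc_weight, lr=base_lr * 5, weight_decay=weight_decay),
        dict(params=fc_bias, lr=base_lr * 10, weight_decay=0),
    ]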
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNetTSM',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False,
shift_div=8),
cls_head=dict(
type='TSMHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'VideoDataset'
data_root = 's3://lizz.ssd/datasets/kinetics400_256/'
data_root_val = 's3://lizz.ssd/datasets/kinetics400_256/'
ann_file_train = 'data/kinetics400/k400_train.txt'
ann_file_val = 'data/kinetics400/k400_val.txt'
ann_file_test = 'data/kinetics400/k400_val.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')
train_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=4),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=4),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD',
constructor='TSMOptimizerConstructor',
paramwise_cfg=dict(fc_lr5=True),
lr=0.02,
momentum=0.9,
weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsm_video_1x1x8_r50_2d_kinetics400_100e/'
load_from = None
resume_from = None
workflow = [('train', 1)]
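For context, shift_div=8 in the ResNetTSM backbone above means one eighth of the channels are shifted one step backward in time, another eighth one step forward, and the rest stay put. A minimal sketch of that temporal-shift operation on a (N*T, C, H, W) feature map:

import torch

def temporal_shift(x, num_segments=8, shift_div=8):
    # reshape (N*T, C, H, W) -> (N, T, C, H, W) so we can shift along T
    nt, c, h, w = x.shape
    x = x.view(nt // num_segments, num_segments, c, h, w)
    fold = c // shift_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # shift backward in time
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift forward in time
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # untouched channels
    return out.view(nt, c, h, w)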
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='ResNet',
pretrained='torchvision://resnet50',
depth=50,
norm_eval=False),
cls_head=dict(
type='TSNHead',
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'VideoDataset'
data_root = 's3://lizz.ssd/datasets/kinetics400_256/'
data_root_val = 's3://lizz.ssd/datasets/kinetics400_256/'
ann_file_train = 'data/kinetics400/k400_train.txt'
ann_file_val = 'data/kinetics400/k400_val.txt'
ann_file_test = 'data/kinetics400/k400_val.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
mc_cfg = dict(
server_list_cfg='/mnt/lustre/share/memcached_client/server_list.conf',
client_cfg='/mnt/lustre/share/memcached_client/client.conf',
sys_path='/mnt/lustre/share/pymc/py3')
train_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
dict(type='DecordDecode'),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=3,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit', io_backend='petrel', num_threads=1),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=25,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='TenCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=32,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[40, 80])
total_epochs = 100
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_video_1x1x3_r50_2d_kinetics400_100e/'
load_from = None
resume_from = None
workflow = [('train', 1)]
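The TSNHead above scores each of the three sampled frames independently and fuses them with AvgConsensus, which is simply a mean over the segment dimension:

import torch

def avg_consensus(scores, dim=1):
    # scores: (batch, num_segments, num_classes) -> (batch, 1, num_classes)
    return scores.mean(dim=dim, keepdim=True)

scores = torch.randn(2, 3, 400)  # e.g. 3 segments of Kinetics-400 logits
fused = avg_consensus(scores)    # shape (2, 1, 400)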
@@ -4,9 +4,10 @@ from .augmentations import (CenterCrop, Flip, Fuse, MultiGroupCrop,
from .compose import Compose
from .formating import (Collect, FormatShape, ImageToTensor, ToDataContainer,
ToTensor, Transpose)
from .loading import (DecordDecode, DenseSampleFrames, FrameSelector,
GenerateLocalizationLabels, LoadLocalizationFeature,
LoadProposals, OpenCVDecode, PyAVDecode, SampleFrames)
from .loading import (DecordDecode, DecordInit, DenseSampleFrames,
FrameSelector, GenerateLocalizationLabels,
LoadLocalizationFeature, LoadProposals, OpenCVDecode,
OpenCVInit, PyAVDecode, PyAVInit, SampleFrames)
__all__ = [
'SampleFrames', 'PyAVDecode', 'DecordDecode', 'DenseSampleFrames',
@@ -14,5 +15,6 @@ __all__ = [
'RandomResizedCrop', 'RandomCrop', 'Resize', 'Flip', 'Fuse', 'Normalize',
'ThreeCrop', 'CenterCrop', 'TenCrop', 'ImageToTensor', 'Transpose',
'Collect', 'FormatShape', 'Compose', 'ToTensor', 'ToDataContainer',
'GenerateLocalizationLabels', 'LoadLocalizationFeature', 'LoadProposals'
'GenerateLocalizationLabels', 'LoadLocalizationFeature', 'LoadProposals',
'DecordInit', 'OpenCVInit', 'PyAVInit'
]
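With the new *Init steps exported above, opening a video and decoding its frames are now separate pipeline stages. A minimal sketch of driving the new pair by hand; the video path is hypothetical:

from mmaction.datasets.pipelines import DecordDecode, DecordInit, SampleFrames

results = dict(filename='demo/demo.mp4')  # hypothetical local video
results = DecordInit()(results)           # adds 'video_reader' and 'total_frames'
results = SampleFrames(clip_len=32, frame_interval=2, num_clips=1)(results)
results = DecordDecode()(results)         # adds 'imgs', 'img_shape', 'original_shape'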
import io
import os
import os.path as osp
import shutil
import mmcv
import numpy as np
from mmcv.fileio import FileClient
from ...utils import get_random_string, get_shm_dir, get_thread_id
from ..registry import PIPELINES
@@ -131,7 +134,6 @@ class SampleFrames(object):
frame_inds += perframe_offsets
frame_inds = np.mod(frame_inds, total_frames)
results['frame_inds'] = frame_inds.astype(np.int)
results['clip_len'] = self.clip_len
results['frame_interval'] = self.frame_interval
@@ -223,24 +225,21 @@ class DenseSampleFrames(SampleFrames):
@PIPELINES.register_module
class PyAVDecode(object):
"""Using pyav to decode the video.
class PyAVInit(object):
"""Using pyav to initialize the video.
PyAV: https://github.com/mikeboers/PyAV
Required keys are "filename" and "frame_inds",
added or modified keys are "imgs", "img_shape" and "original_shape".
Required keys are "filename",
added or modified keys are "video_reader", and "total_frames".
Args:
multi_thread (bool): If set to True, it will apply multi
thread processing. Default: False.
io_backend (str): io backend where frames are store.
Default: 'disk'.
kwargs (dict): Args for file client.
"""
def __init__(self, multi_thread=False, io_backend='disk', **kwargs):
self.multi_thread = multi_thread
def __init__(self, io_backend='disk', **kwargs):
self.io_backend = io_backend
self.kwargs = kwargs
self.file_client = None
@@ -258,6 +257,31 @@ class PyAVDecode(object):
file_obj = io.BytesIO(self.file_client.get(results['filename']))
container = av.open(file_obj)
results['video_reader'] = container
results['total_frames'] = container.streams.video[0].frames
return results
@PIPELINES.register_module
class PyAVDecode(object):
"""Using pyav to decode the video.
PyAV: https://github.com/mikeboers/PyAV
Required keys are "video_reader" and "frame_inds",
added or modified keys are "imgs", "img_shape" and "original_shape".
Args:
multi_thread (bool): If set to True, it will apply multi
thread processing. Default: False.
"""
def __init__(self, multi_thread=False):
self.multi_thread = multi_thread
def __call__(self, results):
container = results['video_reader']
imgs = list()
if self.multi_thread:
@@ -274,9 +298,13 @@ class PyAVDecode(object):
imgs.append(frame.to_rgb().to_ndarray())
i += 1
results['video_reader'] = None
del container
# pyav may decode fewer frames than the container reports, so index
# modulo the decoded length to avoid an out-of-range error
results['imgs'] = [imgs[i % len(imgs)] for i in results['frame_inds']]
results['original_shape'] = imgs[0].shape[:2]
results['img_shape'] = imgs[0].shape[:2]
@@ -289,15 +317,23 @@ class PyAVDecode(object):
@PIPELINES.register_module
class DecordDecode(object):
"""Using decord to decode the video.
class DecordInit(object):
"""Using decord to initialize the video_reader.
Decord: https://github.com/dmlc/decord
Required keys are "filename" and "frame_inds",
added or modified keys are "imgs", "img_shape" and "original_shape".
Required keys are "filename",
added or modified keys are "new_path", "video_reader" and "total_frames".
"""
def __init__(self, io_backend='disk', num_threads=1, **kwargs):
self.io_backend = io_backend
self.num_threads = num_threads
self.kwargs = kwargs
self.file_client = None
self.tmp_folder = osp.join(get_shm_dir(), get_random_string())
os.mkdir(self.tmp_folder)
def __call__(self, results):
try:
import decord
@@ -305,9 +341,44 @@ class DecordDecode(object):
raise ImportError(
'Please run "pip install decord" to install Decord first.')
container = decord.VideoReader(results['filename'])
imgs = list()
if self.io_backend == 'disk':
new_path = results['filename']
else:
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
thread_id = get_thread_id()
# files fetched by the same worker thread are written to the same path
new_path = osp.join(self.tmp_folder, f'tmp_{thread_id}.mp4')
with open(new_path, 'wb') as f:
f.write(self.file_client.get(results['filename']))
container = decord.VideoReader(new_path, num_threads=self.num_threads)
results['new_path'] = new_path
results['video_reader'] = container
results['total_frames'] = len(container)
return results
def __del__(self):
shutil.rmtree(self.tmp_folder)
@PIPELINES.register_module
class DecordDecode(object):
"""Using decord to decode the video.
Decord: https://github.com/dmlc/decord
Required keys are "video_reader", "filename" and "frame_inds",
added or modified keys are "imgs" and "original_shape".
"""
def __init__(self, **kwargs):
pass
def __call__(self, results):
container = results['video_reader']
imgs = list()
if results['frame_inds'].ndim != 1:
results['frame_inds'] = np.squeeze(results['frame_inds'])
@@ -315,6 +386,9 @@ class DecordDecode(object):
cur_frame = container[frame_idx].asnumpy()
imgs.append(cur_frame)
results['video_reader'] = None
del container
results['imgs'] = imgs
results['original_shape'] = imgs[0].shape[:2]
results['img_shape'] = imgs[0].shape[:2]
@@ -322,16 +396,58 @@ class DecordDecode(object):
return results
@PIPELINES.register_module
class OpenCVInit(object):
"""Using OpenCV to initalize the video_reader.
Required keys are "filename",
added or modified keys are "new_path", "video_reader" and "total_frames".
"""
def __init__(self, io_backend='disk', **kwargs):
self.io_backend = io_backend
self.kwargs = kwargs
self.file_client = None
self.tmp_folder = osp.join(get_shm_dir(), get_random_string())
os.mkdir(self.tmp_folder)
def __call__(self, results):
if self.io_backend == 'disk':
new_path = results['filename']
else:
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
thread_id = get_thread_id()
# files fetched by the same worker thread are written to the same path
new_path = osp.join(self.tmp_folder, f'tmp_{thread_id}.mp4')
with open(new_path, 'wb') as f:
f.write(self.file_client.get(results['filename']))
container = mmcv.VideoReader(new_path)
results['new_path'] = new_path
results['video_reader'] = container
results['total_frames'] = len(container)
return results
def __del__(self):
shutil.rmtree(self.tmp_folder)
@PIPELINES.register_module
class OpenCVDecode(object):
"""Using OpenCV to decode the video.
Required keys are "filename" and "frame_inds",
Required keys are "video_reader", "filename" and "frame_inds",
added or modified keys are "imgs", "img_shape" and "original_shape".
"""
def __init__(self):
pass
def __call__(self, results):
container = mmcv.VideoReader(results['filename'])
container = results['video_reader']
imgs = list()
if results['frame_inds'].ndim != 1:
@@ -345,6 +461,9 @@ class OpenCVDecode(object):
cur_frame = container[frame_ind]
imgs.append(cur_frame)
results['video_reader'] = None
del container
imgs = np.array(imgs)
# The default channel order of OpenCV is BGR, thus we change it to RGB
imgs = imgs[:, :, :, ::-1]
from .collect_env import collect_env
from .logger import get_root_logger
from .misc import get_random_string, get_shm_dir, get_thread_id
__all__ = ['get_root_logger', 'collect_env']
__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir'
]
import ctypes
import random
import string
def get_random_string(length=15):
"""Get random string with letters and digits.
Args:
length (int): Length of random string. Default: 15.
"""
return ''.join(
random.choice(string.ascii_letters + string.digits)
for _ in range(length))
def get_thread_id():
"""Get current thread id.
"""
# use ctypes to call gettid; 186 is the gettid syscall number on x86-64 Linux
thread_id = ctypes.CDLL('libc.so.6').syscall(186)
return thread_id
def get_shm_dir():
"""Get shm dir for temporary usage.
"""
return '/dev/shm'
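Taken together, these helpers give each dataloader worker thread its own temp file under shared memory; this is the pattern the new DecordInit and OpenCVInit steps rely on:

import os.path as osp
from mmaction.utils import get_random_string, get_shm_dir, get_thread_id

# one random folder per pipeline instance, one file per worker thread,
# so concurrent workers never overwrite each other's downloads
tmp_folder = osp.join(get_shm_dir(), get_random_string())
new_path = osp.join(tmp_folder, f'tmp_{get_thread_id()}.mp4')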
@@ -35,6 +35,7 @@ class TestDataset(object):
dict(type='FrameSelector', io_backend='disk')
]
cls.video_pipeline = [
dict(type='OpenCVInit'),
dict(
type='SampleFrames',
clip_len=32,
@@ -6,12 +6,13 @@ import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from mmaction.datasets.pipelines import (DecordDecode, DenseSampleFrames,
FrameSelector,
from mmaction.datasets.pipelines import (DecordDecode, DecordInit,
DenseSampleFrames, FrameSelector,
GenerateLocalizationLabels,
LoadLocalizationFeature,
LoadProposals, OpenCVDecode,
PyAVDecode, SampleFrames)
OpenCVInit, PyAVDecode, PyAVInit,
SampleFrames)
class TestLoading(object):
@@ -275,6 +276,14 @@ class TestLoading(object):
dense_sample_frames_results = dense_sample_frames(frame_result)
assert len(dense_sample_frames_results['frame_inds']) == 120
def test_pyav_init(self):
target_keys = ['video_reader', 'total_frames']
video_result = copy.deepcopy(self.video_results)
pyav_init = PyAVInit()
pyav_init_result = pyav_init(video_result)
assert self.check_keys_contain(pyav_init_result.keys(), target_keys)
assert pyav_init_result['total_frames'] == 300
def test_pyav_decode(self):
target_keys = ['frame_inds', 'imgs', 'original_shape']
@@ -282,6 +291,10 @@ class TestLoading(object):
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(0, self.total_frames,
2)[:, np.newaxis]
pyav_init = PyAVInit()
pyav_init_result = pyav_init(video_result)
video_result['video_reader'] = pyav_init_result['video_reader']
pyav_decode = PyAVDecode()
pyav_decode_result = pyav_decode(video_result)
assert self.check_keys_contain(pyav_decode_result.keys(), target_keys)
@@ -292,6 +305,10 @@ class TestLoading(object):
# test PyAV with 1 dim input
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(1, self.total_frames, 5)
pyav_init = PyAVInit()
pyav_init_result = pyav_init(video_result)
video_result['video_reader'] = pyav_init_result['video_reader']
pyav_decode = PyAVDecode()
pyav_decode_result = pyav_decode(video_result)
assert self.check_keys_contain(pyav_decode_result.keys(), target_keys)
@@ -302,6 +319,10 @@ class TestLoading(object):
# PyAV with multi thread
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(1, self.total_frames, 5)
pyav_init = PyAVInit()
pyav_init_result = pyav_init(video_result)
video_result['video_reader'] = pyav_init_result['video_reader']
pyav_decode = PyAVDecode(multi_thread=True)
pyav_decode_result = pyav_decode(video_result)
assert self.check_keys_contain(pyav_decode_result.keys(), target_keys)
@@ -312,6 +333,15 @@ class TestLoading(object):
assert repr(pyav_decode) == pyav_decode.__class__.__name__ + \
f'(multi_thread={True})'
def test_decord_init(self):
target_keys = ['new_path', 'video_reader', 'total_frames']
video_result = copy.deepcopy(self.video_results)
decord_init = DecordInit()
decord_init_result = decord_init(video_result)
assert self.check_keys_contain(decord_init_result.keys(), target_keys)
assert decord_init_result['total_frames'] == len(
decord_init_result['video_reader'])
def test_decord_decode(self):
target_keys = ['frame_inds', 'imgs', 'original_shape']
@@ -319,6 +349,10 @@ class TestLoading(object):
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(1, self.total_frames,
3)[:, np.newaxis]
decord_init = DecordInit()
decord_init_result = decord_init(video_result)
video_result['video_reader'] = decord_init_result['video_reader']
decord_decode = DecordDecode()
decord_decode_result = decord_decode(video_result)
assert self.check_keys_contain(decord_decode_result.keys(),
@@ -330,6 +364,10 @@ class TestLoading(object):
# test Decord with 1 dim input
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(1, self.total_frames, 3)
decord_init = DecordInit()
decord_init_result = decord_init(video_result)
video_result['video_reader'] = decord_init_result['video_reader']
decord_decode = DecordDecode()
decord_decode_result = decord_decode(video_result)
assert self.check_keys_contain(decord_decode_result.keys(),
@@ -338,6 +376,15 @@ class TestLoading(object):
assert np.shape(decord_decode_result['imgs']) == (len(
video_result['frame_inds']), 256, 340, 3)
def test_opencv_init(self):
target_keys = ['new_path', 'video_reader', 'total_frames']
video_result = copy.deepcopy(self.video_results)
opencv_init = OpenCVInit()
opencv_init_result = opencv_init(video_result)
assert self.check_keys_contain(opencv_init_result.keys(), target_keys)
assert opencv_init_result['total_frames'] == len(
opencv_init_result['video_reader'])
def test_opencv_decode(self):
target_keys = ['frame_inds', 'imgs', 'original_shape']
@@ -345,6 +392,10 @@ class TestLoading(object):
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(0, self.total_frames,
2)[:, np.newaxis]
opencv_init = OpenCVInit()
opencv_init_result = opencv_init(video_result)
video_result['video_reader'] = opencv_init_result['video_reader']
opencv_decode = OpenCVDecode()
opencv_decode_result = opencv_decode(video_result)
assert self.check_keys_contain(opencv_decode_result.keys(),
@@ -356,6 +407,10 @@ class TestLoading(object):
# test OpenCV with 1 dim input
video_result = copy.deepcopy(self.video_results)
video_result['frame_inds'] = np.arange(1, self.total_frames, 3)
opencv_init = OpenCVInit()
opencv_init_result = opencv_init(video_result)
video_result['video_reader'] = opencv_init_result['video_reader']
opencv_decode = OpenCVDecode()
opencv_decode_result = opencv_decode(video_result)
assert self.check_keys_contain(opencv_decode_result.keys(),