未验证 提交 25fbd3be 编写于 作者: J Jiaqi Tang 提交者: GitHub

Add SSN dataset. (#37)

上级 1ae01225
from .accuracy import (average_recall_at_avg_proposals, confusion_matrix,
from .accuracy import (average_precision_at_temporal_iou,
average_recall_at_avg_proposals, confusion_matrix,
get_weighted_score, mean_average_precision,
mean_class_accuracy, pairwise_temporal_iou,
mean_class_accuracy, pairwise_temporal_iou, softmax,
top_k_accuracy)
from .eval_hooks import DistEvalHook, EvalHook
__all__ = [
'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy',
'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
'average_recall_at_avg_proposals', 'pairwise_temporal_iou'
'average_recall_at_avg_proposals', 'pairwise_temporal_iou', 'softmax',
'average_precision_at_temporal_iou'
]
......@@ -171,19 +171,23 @@ def pairwise_temporal_iou(candidate_segments, target_segments):
"""Compute intersection over union between segments.
Args:
candidate_segments (np.ndarray): 2-dim array in format
[m x 2:=[init, end]].
candidate_segments (np.ndarray): 1-dim/2-dim array in format
[init, end]/[m x 2:=[init, end]].
target_segments (np.ndarray): 2-dim array in format
[n x 2:=[init, end]].
Returns:
temporal_iou (np.ndarray): 2-dim array [n x m] with IoU ratio.
t_iou (np.ndarray): 1-dim array [n] /
2-dim array [n x m] with IoU ratio.
"""
if target_segments.ndim != 2 or candidate_segments.ndim != 2:
if target_segments.ndim != 2 or candidate_segments.ndim not in [1, 2]:
raise ValueError('Dimension of arguments is incorrect')
if candidate_segments.ndim == 1:
candidate_segments = candidate_segments[np.newaxis, :]
n, m = target_segments.shape[0], candidate_segments.shape[0]
temporal_iou = np.empty((n, m))
t_iou = np.empty((n, m), dtype=np.float32)
for i in range(m):
candidate_segment = candidate_segments[i, :]
tt1 = np.maximum(candidate_segment[0], target_segments[:, 0])
......@@ -196,10 +200,9 @@ def pairwise_temporal_iou(candidate_segments, target_segments):
segments_intersection)
# Compute overlap as the ratio of the intersection
# over union of two segments.
temporal_iou[:, i] = (
segments_intersection.astype(float) / segments_union)
t_iou[:, i] = (segments_intersection.astype(float) / segments_union)
return temporal_iou
return t_iou
def average_recall_at_avg_proposals(ground_truth,
......@@ -275,9 +278,9 @@ def average_recall_at_avg_proposals(ground_truth,
num_retrieved_proposals, :]
# Compute temporal_iou scores.
temporal_iou = pairwise_temporal_iou(this_video_proposals,
this_video_ground_truth)
score_list.append(temporal_iou)
t_iou = pairwise_temporal_iou(this_video_proposals,
this_video_ground_truth)
score_list.append(t_iou)
# Given that the length of the videos is really varied, we
# compute the number of proposals in terms of a ratio of the total
......@@ -351,3 +354,114 @@ def get_weighted_score(score_list, coeff_list):
coeff = np.array(coeff_list) # (num_coeff, )
weighted_scores = list(np.dot(scores.T, coeff).T)
return weighted_scores
def softmax(x, dim=1):
"""Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
return e_x / e_x.sum(axis=dim, keepdims=True)
def interpolated_precision_recall(precision, recall):
"""Interpolated AP - VOCdevkit from VOC 2011.
Args:
precision (np.ndarray): The precision of different thresholds.
recall (np.ndarray): The recall of different thresholds.
Returns:
float: Average precision score.
"""
mprecision = np.hstack([[0], precision, [0]])
mrecall = np.hstack([[0], recall, [1]])
for i in range(len(mprecision) - 1)[::-1]:
mprecision[i] = max(mprecision[i], mprecision[i + 1])
idx = np.where(mrecall[1::] != mrecall[0:-1])[0] + 1
ap = np.sum((mrecall[idx] - mrecall[idx - 1]) * mprecision[idx])
return ap
def average_precision_at_temporal_iou(ground_truth,
prediction,
temporal_iou_thresholds=(np.linspace(
0.5, 0.95, 10))):
"""Compute average precision (in detection task) between ground truth and
predicted data frames. If multiple predictions match the same predicted
segment, only the one with highest score is matched as true positive. This
code is greatly inspired by Pascal VOC devkit.
Args:
ground_truth (dict): Dict containing the ground truth instances.
Key: 'video_id'
Value (np.ndarry): 1D array of 't-start' and 't-end'.
proposals (np.ndarray): 2D array containing the information of proposal
instances, including 'video_id', 'class_id', 't-start', 't-end' and
'score'.
temporal_iou_thresholds (np.ndarray): 1D array with temporal_iou
thresholds. Default: np.linspace(0.5, 0.95, 10).
Returns:
np.ndarray: 1D array of average precision score.
"""
ap = np.zeros(len(temporal_iou_thresholds), dtype=np.float32)
if len(prediction) < 1:
return ap
num_gts = 0.
lock_gt = dict()
for key in ground_truth:
lock_gt[key] = np.ones(
(len(temporal_iou_thresholds), len(ground_truth[key]))) * -1
num_gts += len(ground_truth[key])
# Sort predictions by decreasing score order.
prediction = np.array(prediction)
scores = prediction[:, 4].astype(float)
sort_idx = np.argsort(scores)[::-1]
prediction = prediction[sort_idx]
# Initialize true positive and false positive vectors.
tp = np.zeros((len(temporal_iou_thresholds), len(prediction)),
dtype=np.int32)
fp = np.zeros((len(temporal_iou_thresholds), len(prediction)),
dtype=np.int32)
# Assigning true positive to truly grount truth instances.
for idx, this_pred in enumerate(prediction):
# Check if there is at least one ground truth in the video.
if (this_pred[0] in ground_truth):
this_gt = np.array(ground_truth[this_pred[0]], dtype=float)
else:
fp[:, idx] = 1
continue
t_iou = pairwise_temporal_iou(this_pred[2:4].astype(float), this_gt)
# We would like to retrieve the predictions with highest t_iou score.
t_iou_sorted_idx = t_iou.argsort()[::-1]
for t_idx, t_iou_threshold in enumerate(temporal_iou_thresholds):
for jdx in t_iou_sorted_idx:
if t_iou[jdx] < t_iou_threshold:
fp[t_idx, idx] = 1
break
if lock_gt[this_pred[0]][t_idx, jdx] >= 0:
continue
# Assign as true positive after the filters above.
tp[t_idx, idx] = 1
lock_gt[this_pred[0]][t_idx, jdx] = idx
break
if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0:
fp[t_idx, idx] = 1
tp_cumsum = np.cumsum(tp, axis=1).astype(np.float32)
fp_cumsum = np.cumsum(fp, axis=1).astype(np.float32)
recall_cumsum = tp_cumsum / num_gts
precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)
for t_idx in range(len(temporal_iou_thresholds)):
ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :],
recall_cumsum[t_idx, :])
return ap
......@@ -3,9 +3,10 @@ from .base import BaseDataset
from .builder import build_dataloader, build_dataset
from .dataset_wrappers import RepeatDataset
from .rawframe_dataset import RawframeDataset
from .ssn_dataset import SSNDataset
from .video_dataset import VideoDataset
__all__ = [
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset'
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset'
]
......@@ -443,7 +443,7 @@ class SampleProposalFrames(SampleFrames):
valid_ending = min(num_frames - ori_clip_len + 1,
end_frame - 1 + int(duration * self.aug_ratio[1]))
valid_starting_length = (start_frame - valid_starting - ori_clip_len)
valid_starting_length = start_frame - valid_starting - ori_clip_len
valid_ending_length = (valid_ending - end_frame + 1) - ori_clip_len
if self.mode == 'train':
......@@ -533,9 +533,6 @@ class SampleProposalFrames(SampleFrames):
"""
total_frames = results['total_frames']
assert 'out_props' not in results, (
"'out_props' is out of date, please use 'out_proposals'")
out_proposals = results.get('out_proposals', None)
clip_offsets = self._sample_clips(total_frames, out_proposals)
frame_inds = clip_offsets[:, None] + np.arange(
......
此差异已折叠。
from .bsn_utils import generate_bsp_feature, generate_candidate_proposals
from .proposal_utils import soft_nms, temporal_iop, temporal_iou
from .ssn_utils import (eval_ap, load_localize_proposal_file,
perform_regression, temporal_nms)
__all__ = [
'generate_candidate_proposals', 'generate_bsp_feature', 'temporal_iop',
'temporal_iou', 'soft_nms'
'temporal_iou', 'soft_nms', 'load_localize_proposal_file',
'perform_regression', 'temporal_nms', 'eval_ap'
]
from itertools import groupby
import numpy as np
from ..core import average_precision_at_temporal_iou
from . import temporal_iou
def load_localize_proposal_file(filename):
"""Load the proposal file and split it into many parts which contain one
video's information separately.
Args:
filename(str): Path to the proposal file.
Returns:
list: List of all videos' information.
"""
lines = list(open(filename))
# Split the proposal file into many parts which contain one video's
# information separately.
groups = groupby(lines, lambda x: x.startswith('#'))
video_infos = [[x.strip() for x in list(g)] for k, g in groups if not k]
def parse_group(video_info):
"""Parse the video's information.
Template information of a video in a standard file:
# index
video_id
num_frames
fps
num_gts
label, start_frame, end_frame
label, start_frame, end_frame
...
num_proposals
label, best_iou, overlap_self, start_frame, end_frame
label, best_iou, overlap_self, start_frame, end_frame
...
Example of a standard annotation file:
.. code-block:: txt
# 0
video_validation_0000202
5666
1
3
8 130 185
8 832 1136
8 1303 1381
5
8 0.0620 0.0620 790 5671
8 0.1656 0.1656 790 2619
8 0.0833 0.0833 3945 5671
8 0.0960 0.0960 4173 5671
8 0.0614 0.0614 3327 5671
Args:
video_info (list): Information of the video.
Returns:
tuple[str, int, list, list]:
video_id (str): Name of the video.
num_frames (int): Number of frames in the video.
gt_boxes (list): List of the information of gt boxes.
proposal_boxes (list): List of the information of
proposal boxes.
"""
offset = 0
video_id = video_info[offset]
offset += 1
num_frames = int(float(video_info[1]) * float(video_info[2]))
num_gts = int(video_info[3])
offset = 4
gt_boxes = [x.split() for x in video_info[offset:offset + num_gts]]
offset += num_gts
num_proposals = int(video_info[offset])
offset += 1
proposal_boxes = [
x.split() for x in video_info[offset:offset + num_proposals]
]
return video_id, num_frames, gt_boxes, proposal_boxes
return [parse_group(video_info) for video_info in video_infos]
def perform_regression(detections):
"""Perform regression on detection results.
Args:
detections (list): Detection results before regression.
Returns:
list: Detection results after regression.
"""
starts = detections[:, 0]
ends = detections[:, 1]
centers = (starts + ends) / 2
durations = ends - starts
new_centers = centers + durations * detections[:, 3]
new_durations = durations * np.exp(detections[:, 4])
new_detections = np.concatenate(
(np.clip(new_centers - new_durations / 2, 0,
1)[:, None], np.clip(new_centers + new_durations / 2, 0,
1)[:, None], detections[:, 2:]),
axis=1)
return new_detections
def temporal_nms(detections, threshold):
"""Parse the video's information.
Args:
detections (list): Detection results before NMS.
threshold (float): Threshold of NMS.
Returns:
list: Detection results after NMS.
"""
starts = detections[:, 0]
ends = detections[:, 1]
scores = detections[:, 2]
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
ious = temporal_iou(starts[order[1:]], ends[order[1:]], starts[i],
ends[i])
idxs = np.where(ious <= threshold)[0]
order = order[idxs + 1]
return detections[keep, :]
def eval_ap(detections, gt_by_cls, iou_range):
"""Evaluate average precisions.
Args:
detections (dict): Results of detections.
gt_by_cls (dict): Information of groudtruth.
iou_range (list): Ranges of iou.
Returns:
list: Average precision values of classes at ious.
"""
ap_values = np.zeros((len(detections), len(iou_range)))
for iou_idx, min_overlap in enumerate(iou_range):
for class_idx in range(len(detections)):
ap = average_precision_at_temporal_iou(gt_by_cls[class_idx],
detections[class_idx],
[min_overlap])
ap_values[class_idx, iou_idx] = ap
return ap_values
# 0
test_imgs
5
1
2
3 0.2000 0.4000
3 0.6000 1.0000
10
3 1.0000 1.0000 0.2000 0.4000
3 0.5000 0.5000 0.2000 0.6000
3 0.3333 0.3333 0.2000 0.8000
3 0.5000 0.5000 0.2000 1.0000
3 0.0000 0.0000 0.4000 0.6000
3 0.3333 0.5000 0.4000 0.8000
3 0.6666 0.6666 0.4000 1.0000
3 0.5000 1.0000 0.6000 0.8000
3 1.0000 1.0000 0.6000 1.0000
3 0.5000 1.0000 0.8000 1.0000
# 0
test_imgs
5
1
2
3 1 2
3 3 5
10
3 1.0000 1.0000 1 2
3 0.5000 0.5000 1 3
3 0.3333 0.3333 1 4
3 0.5000 0.5000 1 5
3 0.0000 0.0000 2 3
3 0.3333 0.5000 2 4
3 0.6666 0.6666 2 5
3 0.5000 1.0000 3 4
3 1.0000 1.0000 3 5
3 0.5000 1.0000 4 5
......@@ -6,10 +6,11 @@ import mmcv
import numpy as np
import pytest
import torch
from mmcv import ConfigDict
from numpy.testing import assert_array_equal
from mmaction.datasets import (ActivityNetDataset, RawframeDataset,
RepeatDataset, VideoDataset)
RepeatDataset, SSNDataset, VideoDataset)
class TestDataset(object):
......@@ -30,6 +31,10 @@ class TestDataset(object):
cls.video_ann_file = osp.join(cls.data_prefix, 'video_test_list.txt')
cls.action_ann_file = osp.join(cls.data_prefix,
'action_test_anno.json')
cls.proposal_ann_file = osp.join(cls.data_prefix,
'proposal_test_list.txt')
cls.proposal_norm_ann_file = osp.join(cls.data_prefix,
'proposal_normalized_list.txt')
cls.frame_pipeline = [
dict(
......@@ -49,6 +54,62 @@ class TestDataset(object):
dict(type='OpenCVDecode')
]
cls.action_pipeline = []
cls.proposal_pipeline = [
dict(
type='SampleProposalFrames',
clip_len=1,
body_segments=5,
aug_segments=(2, 2),
aug_ratio=0.5),
dict(type='FrameSelector', io_backend='disk')
]
cls.proposal_test_pipeline = [
dict(
type='SampleProposalFrames',
clip_len=1,
body_segments=5,
aug_segments=(2, 2),
aug_ratio=0.5,
mode='test'),
dict(type='FrameSelector', io_backend='disk')
]
cls.proposal_train_cfg = ConfigDict(
dict(
ssn=dict(
assigner=dict(
positive_iou_threshold=0.7,
background_iou_threshold=0.01,
incomplete_iou_threshold=0.5,
background_coverage_threshold=0.02,
incomplete_overlap_threshold=0.01),
sampler=dict(
num_per_video=8,
positive_ratio=1,
background_ratio=1,
incomplete_ratio=6,
add_gt_as_proposals=True),
loss_weight=dict(
comp_loss_weight=0.1, reg_loss_weight=0.1),
debug=False)))
cls.proposal_test_cfg = ConfigDict(
dict(
ssn=dict(
sampler=dict(test_interval=6, batch_size=16),
evaluater=dict(
top_k=2000,
nms=0.2,
softmax_before_filter=True,
cls_top_k=2))))
cls.proposal_test_cfg_topall = ConfigDict(
dict(
ssn=dict(
sampler=dict(test_interval=6, batch_size=16),
evaluater=dict(
top_k=-1,
nms=0.2,
softmax_before_filter=True,
cls_top_k=2))))
def test_rawframe_dataset(self):
rawframe_dataset = RawframeDataset(self.frame_ann_file,
......@@ -222,6 +283,53 @@ class TestDataset(object):
result = action_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
def test_proposal_pipeline(self):
target_keys = [
'frame_dir', 'video_id', 'total_frames', 'gts', 'proposals',
'filename_tmpl', 'modality', 'out_proposals', 'reg_targets',
'proposal_scale_factor', 'proposal_labels', 'proposal_type',
'start_index'
]
# SSN Dataset not in test mode
proposal_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix)
result = proposal_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
# SSN Dataset with random sampling proposals
proposal_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix,
video_centric=False)
result = proposal_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
target_keys = [
'frame_dir', 'video_id', 'total_frames', 'gts', 'proposals',
'filename_tmpl', 'modality', 'relative_proposal_list',
'scale_factor_list', 'proposal_tick_list', 'reg_norm_consts',
'start_index'
]
# SSN Dataset in test mode
proposal_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_test_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix,
test_mode=True)
result = proposal_dataset[0]
assert self.check_keys_contain(result.keys(), target_keys)
def test_rawframe_evaluate(self):
rawframe_dataset = RawframeDataset(self.frame_ann_file,
self.frame_pipeline,
......@@ -446,3 +554,116 @@ class TestDataset(object):
load_obj,
np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]],
dtype=np.float32))
def test_ssn_dataset(self):
# test ssn dataset
ssn_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix)
ssn_infos = ssn_dataset.video_infos
assert ssn_infos[0]['video_id'] == 'test_imgs'
assert ssn_infos[0]['total_frames'] == 5
# test ssn dataset with verbose
ssn_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix,
verbose=True)
ssn_infos = ssn_dataset.video_infos
assert ssn_infos[0]['video_id'] == 'test_imgs'
assert ssn_infos[0]['total_frames'] == 5
# test ssn datatset with normalized proposal file
with pytest.raises(Exception):
ssn_dataset = SSNDataset(
self.proposal_norm_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix)
ssn_infos = ssn_dataset.video_infos
# test ssn dataset with reg_normalize_constants
ssn_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix,
reg_normalize_constants=[[[-0.0603, 0.0325], [0.0752, 0.1596]]])
ssn_infos = ssn_dataset.video_infos
assert ssn_infos[0]['video_id'] == 'test_imgs'
assert ssn_infos[0]['total_frames'] == 5
# test error case
with pytest.raises(TypeError):
ssn_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix,
aug_ratio=('error', 'error'))
ssn_infos = ssn_dataset.video_infos
def test_ssn_evaluate(self):
ssn_dataset = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg,
data_prefix=self.data_prefix)
ssn_dataset_topall = SSNDataset(
self.proposal_ann_file,
self.proposal_pipeline,
self.proposal_train_cfg,
self.proposal_test_cfg_topall,
data_prefix=self.data_prefix)
with pytest.raises(TypeError):
# results must be a list
ssn_dataset.evaluate('0.5')
with pytest.raises(AssertionError):
# The length of results must be equal to the dataset len
ssn_dataset.evaluate([0] * 5)
with pytest.raises(KeyError):
# unsupported metric
ssn_dataset.evaluate([0] * len(ssn_dataset), metrics='iou')
# evaluate mAP metric
results_relative_proposal_list = np.random.randn(16, 2)
results_activity_scores = np.random.randn(16, 21)
results_completeness_scores = np.random.randn(16, 20)
results_bbox_preds = np.random.randn(16, 20, 2)
results = [[
results_relative_proposal_list, results_activity_scores,
results_completeness_scores, results_bbox_preds
]]
eval_result = ssn_dataset.evaluate(results, metrics=['mAP'])
assert set(eval_result) == set([
'mAP@0.10', 'mAP@0.20', 'mAP@0.30', 'mAP@0.40', 'mAP@0.50',
'mAP@0.50', 'mAP@0.60', 'mAP@0.70', 'mAP@0.80', 'mAP@0.90'
])
# evaluate mAP metric without filtering topk
results_relative_proposal_list = np.random.randn(16, 2)
results_activity_scores = np.random.randn(16, 21)
results_completeness_scores = np.random.randn(16, 20)
results_bbox_preds = np.random.randn(16, 20, 2)
results = [[
results_relative_proposal_list, results_activity_scores,
results_completeness_scores, results_bbox_preds
]]
eval_result = ssn_dataset_topall.evaluate(results, metrics=['mAP'])
assert set(eval_result) == set([
'mAP@0.10', 'mAP@0.20', 'mAP@0.30', 'mAP@0.40', 'mAP@0.50',
'mAP@0.50', 'mAP@0.60', 'mAP@0.70', 'mAP@0.80', 'mAP@0.90'
])
import argparse
import os.path as osp
from tools.data.parse_file_list import parse_directory
from mmaction.localization import load_localize_proposal_file
def process_norm_proposal_file(norm_proposal_file, frame_dict):
"""Process the normalized proposal file and denormalize it.
Args:
norm_proposal_file (str): Name of normalized proposal file.
frame_dict (dict): Information of frame folders.
"""
proposal_file = norm_proposal_file.replace('normalized_', '')
norm_proposals = load_localize_proposal_file(norm_proposal_file)
processed_proposal_list = []
for idx, norm_proposal in enumerate(norm_proposals):
video_id = norm_proposal[0]
frame_info = frame_dict[video_id]
num_frames = frame_info[1]
frame_path = osp.basename(frame_info[0])
gt = [[
int(x[0]),
int(float(x[1]) * num_frames),
int(float(x[2]) * num_frames)
] for x in norm_proposal[2]]
proposal = [[
int(x[0]),
float(x[1]),
float(x[2]),
int(float(x[3]) * num_frames),
int(float(x[4]) * num_frames)
] for x in norm_proposal[3]]
gt_dump = '\n'.join(['{} {} {}'.format(*x) for x in gt])
gt_dump += '\n' if len(gt) else ''
proposal_dump = '\n'.join(
['{} {:.04f} {:.04f} {} {}'.format(*x) for x in proposal])
proposal_dump += '\n' if len(proposal) else ''
processed_proposal_list.append(
f'# {idx}\n{frame_path}\n{num_frames}\n1'
f'\n{len(gt)}\n{gt_dump}{len(proposal)}\n{proposal_dump}')
with open(proposal_file, 'w') as f:
f.writelines(processed_proposal_list)
def parse_args():
parser = argparse.ArgumentParser(description='Denormalize proposal file')
parser.add_argument(
'dataset',
type=str,
choices=['thumos14'],
help='dataset to be denormalize proposal file')
parser.add_argument(
'--norm-proposal-file',
type=str,
help='normalized proposal file to be denormalize')
parser.add_argument(
'--data-prefix',
type=str,
help='path to a directory where rawframes are held')
args = parser.parse_args()
return args
def main():
args = parse_args()
print(f'Converting from {args.norm_proposal_file}.')
frame_dict = parse_directory(args.data_prefix)
process_norm_proposal_file(args.norm_proposal_file, frame_dict)
if __name__ == '__main__':
main()
#!/usr/bin/env bash
cd ../../../
PYTHONPATH=. python tools/data/denormalize_proposal_file.py thumos14 --norm-proposal-file data/thumos14/proposals/thumos14_tag_val_normalized_proposal_list.txt --data-prefix data/thumos14/rawframes/validation/
echo "Proposal file denormalized for val set"
PYTHONPATH=. python tools/data/denormalize_proposal_file.py thumos14 --norm-proposal-file data/thumos14/proposals/thumos14_tag_test_normalized_proposal_list.txt --data-prefix data/thumos14/rawframes/test/
echo "Proposal file denormalized for test set"
cd tools/data/thumos14/
......@@ -58,8 +58,11 @@ cd $MMACTION2/tools/data/thumos14/
bash extract_frames.sh tvl1
```
.
## Step 4. Fetch File List
This part is **optional** if you do not use SSN model.
You can run the follow script to fetch pre-computed tag proposals.
```shell
......@@ -67,7 +70,19 @@ cd $MMACTION2/tools/data/thumos14/
bash fetch_tag_proposals.sh
```
## Step 5. Check Directory Structure
## Step 5. Denormalize Proposal File
This part is **optional** if you do not use SSN model.
You can run the follow script to denormalize pre-computed tag proposals according to
actual number of local rawframes.
```shell
cd $MMACTION2/tools/data/thumos14/
bash denormalize_proposal_file.sh
```
## Step 6. Check Directory Structure
After the whole data process for THUMOS'14 preparation,
you will get the rawframes (RGB + Flow), videos and annotation files for THUMOS'14.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册