提交 e7ed651c 编写于 作者: H huangjun12 提交者: SunGaofeng

add BSN and BMN models to PaddleVideo (#3290)

* add BSN and BMN models to PaddleVideo

* modify codes to fix metrics evaluation and readme

* add copyright for gen_infer_list.py and modify the content of README

* add BSN and BMN in models/README.md
上级 bf13b12c
......@@ -14,10 +14,12 @@
| [TSN](./models/tsn/README.md) | 视频分类| ECCV'16提出的基于2D-CNN经典解决方案 |
| [Non-local](./models/nonlocal_model/README.md) | 视频分类| 视频非局部关联建模模型 |
| [C-TCN](./models/ctcn/README.md) | 视频动作定位| 2018年ActivityNet夺冠方案 |
| [BSN](./models/bsn/README.md) | 视频动作定位| 为视频动作定位问题提供高效的proposal生成方法 |
| [BMN](./models/bmn/README.md) | 视频动作定位| 2019年ActivityNet夺冠方案 |
### 主要特点
- 包含视频分类和动作定位方向的多个主流领先模型,其中Attention LSTM,Attention Cluster和NeXtVLAD是比较流行的特征序列模型,Non-local, TSN, TSM和StNet是End-to-End的视频分类模型。Attention LSTM模型速度快精度高,NeXtVLAD是2nd-Youtube-8M比赛中最好的单模型, TSN是基于2D-CNN的经典解决方案,TSM是基于时序移位的简单高效视频时空建模方法,Non-local模型提出了视频非局部关联建模方法。Attention Cluster和StNet是百度自研模型,分别发表于CVPR2018和AAAI2019,是Kinetics600比赛第一名中使用到的模型。C-TCN动作定位模型也是百度自研,2018年ActivityNet比赛的夺冠方案。
- 包含视频分类和动作定位方向的多个主流领先模型,其中Attention LSTM,Attention Cluster和NeXtVLAD是比较流行的特征序列模型,Non-local, TSN, TSM和StNet是End-to-End的视频分类模型。Attention LSTM模型速度快精度高,NeXtVLAD是2nd-Youtube-8M比赛中最好的单模型, TSN是基于2D-CNN的经典解决方案,TSM是基于时序移位的简单高效视频时空建模方法,Non-local模型提出了视频非局部关联建模方法。Attention Cluster和StNet是百度自研模型,分别发表于CVPR2018和AAAI2019,是Kinetics600比赛第一名中使用到的模型。C-TCN动作定位模型也是百度自研,2018年ActivityNet比赛的夺冠方案。BSN模型采用自底向上的方法生成proposal,为视频动作定位问题中proposal的生成提供高效的解决方案。BMN模型是百度自研模型,2019年ActivityNet夺冠方案。
- 提供了适合视频分类和动作定位任务的通用骨架代码,用户可一键式高效配置模型完成训练和评测。
......@@ -170,9 +172,11 @@ run.sh
- 基于ActivityNet的动作定位模型:
| 模型 | Batch Size | 环境配置 | cuDNN版本 | MAP | 下载链接 |
| 模型 | Batch Size | 环境配置 | cuDNN版本 | 精度 | 下载链接 |
| :-------: | :---: | :---------: | :----: | :----: | :----------: |
| C-TCN | 16 | 8卡P40 | 7.1 | 0.31| [model](https://paddlemodels.bj.bcebos.com/video_detection/CTCN_final.pdparams) |
| C-TCN | 16 | 8卡P40 | 7.1 | 0.31 (MAP) | [model](https://paddlemodels.bj.bcebos.com/video_detection/CTCN_final.pdparams) |
| BSN | 16 | 1卡K40 | 7.0 | 66.64% (AUC) | [model-tem](https://paddlemodels.bj.bcebos.com/video_detection/BsnTem_final.pdparams), [model-pem](https://paddlemodels.bj.bcebos.com/video_detection/BsnPem_final.pdparams) |
| BMN | 16 | 4卡K40 | 7.0 | 67.19% (AUC) | [model](https://paddlemodels.bj.bcebos.com/video_detection/BMN_final.pdparams) |
## 参考文献
......@@ -184,6 +188,9 @@ run.sh
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
- [Non-local Neural Networks](https://arxiv.org/abs/1711.07971v1), Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He
- [Bsn: Boundary sensitive network for temporal action proposal generation](http://arxiv.org/abs/1806.02964), Tianwei Lin, Xu Zhao, Haisheng Su, Chongjing Wang, Ming Yang.
- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
## 版本更新
......
MODEL:
name: "BMN"
tscale: 100
dscale: 100
feat_dim: 400
prop_boundary_ratio: 0.5
num_sample: 32
num_sample_perbin: 3
anno_file: "data/dataset/bmn/activitynet_1.3_annotations.json"
feat_path: '/paddle/PaddleProject/data/fix_feat_100'
TRAIN:
subset: "train"
epoch: 9
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 4
learning_rate: 0.001
learning_rate_decay: 0.1
lr_decay_iter: 4200
l2_weight_decay: 1e-4
VALID:
subset: "validation"
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 4
TEST:
subset: "validation"
batch_size: 1
num_threads: 1
snms_alpha: 0.001
snms_t1: 0.5
snms_t2: 0.9
output_path: "data/output/EVAL/BMN_results"
result_path: "data/evaluate_results"
INFER:
subset: "test"
batch_size: 1
num_threads: 1
snms_alpha: 0.4
snms_t1: 0.5
snms_t2: 0.9
filelist: 'data/dataset/bmn/infer.list'
output_path: "data/output/INFER/BMN_results"
result_path: "data/predict_results"
MODEL:
name: "BSNPEM"
tscale: 100
top_K: 500
feat_dim: 32
hidden_dim: 128
anno_file: "data/dataset/bmn/activitynet_1.3_annotations.json"
feat_path: "data/output/EVAL/PGM_feature/"
prop_path: "data/output/EVAL/PGM_proposals/"
TRAIN:
subset: "train"
epoch: 15
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 1
learning_rate: 0.01
learning_rate_decay: 0.1
lr_decay_iter: 6000
l2_weight_decay: 1e-5
top_K: 500
VALID:
subset: "validation"
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 1
top_K: 500
TEST:
subset: "validation"
batch_size: 1
num_threads: 1
snms_alpha: 0.9
snms_t1: 0.004
snms_t2: 0.01
top_K: 1000
num_gpus: 1
output_path_pem: "data/output/EVAL/PEM_results"
result_path_pem: "data/evaluate_results"
INFER:
subset: "test"
filelist: 'data/dataset/bmn/infer.list'
batch_size: 1
num_threads: 1
top_K: 1000
num_gpus: 1
feat_path: "data/output/INFER/PGM_feature/"
prop_path: "data/output/INFER/PGM_proposals/"
output_path_pem: "data/output/INFER/PEM_results"
result_path_pem: "data/predict_results"
MODEL:
name: "BSNTEM"
tscale: 100
feat_dim: 400
hidden_dim: 256
gt_boundary_ratio: 0.1
prop_boundary_ratio: 0.5
num_sample: 32
num_sample_perbin: 3
anno_file: "data/dataset/bmn/activitynet_1.3_annotations.json"
feat_path: '/paddle/PaddleProject/data/fix_feat_100'
pgm_top_K_train: 500
pgm_top_K: 1000
pgm_threshold: 0.5
bsp_boundary_ratio: 0.25
num_sample_start: 8
num_sample_end: 8
num_sample_action: 16
num_sample_perbin: 3
pgm_thread: 12
TRAIN:
subset: "train"
epoch: 9
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 1
learning_rate: 0.001
learning_rate_decay: 0.1
lr_decay_iter: 4200
l2_weight_decay: 1e-4
VALID:
subset: "validation"
batch_size: 16
num_threads: 8
use_gpu: True
num_gpus: 1
TEST:
subset: "train_val"
batch_size: 1
num_threads: 1
score_thresh: 0.001
output_path_tem: "data/output/EVAL/TEM_results"
output_path_pgm_proposal: "data/output/EVAL/PGM_proposals"
output_path_pgm_feature: "data/output/EVAL/PGM_feature"
INFER:
subset: "test"
filelist: 'data/dataset/bmn/infer.list'
batch_size: 1
num_threads: 1
output_path_tem: "data/output/INFER/TEM_results"
output_path_pgm_proposal: "data/output/INFER/PGM_proposals"
output_path_pgm_feature: "data/output/INFER/PGM_feature"
# BMN模型数据使用说明
BMN模型使用ActivityNet 1.3数据集,使用方法有如下两种方式:
方式一:
首先参考[下载说明](https://github.com/activitynet/ActivityNet/tree/master/Crawler)下载原始数据集。在训练此模型时,需要先使用TSN对源文件抽取特征。可以[自行抽取](https://github.com/yjxiong/temporal-segment-networks)视频帧及光流信息,预训练好的TSN模型可从[此处](https://github.com/yjxiong/anet2016-cuhk)下载。
方式二:
我们也在[百度网盘](https://pan.baidu.com/s/19GI3_-uZbd_XynUO6g-8YQ)[谷歌云盘](https://drive.google.com/file/d/1ISemndlSDS2FtqQOKL0t3Cjj9yk2yznF/view?usp=sharing)提供了处理好的视频特征。若使用百度网盘下载,在解压前请使用如下命令:
cat zip_csv_mean_100.z* > csv_mean_100.zip
解压完成后,请相应修改configs/bmn.yaml文件中的特征路径feat\_path。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
'''
Generate samples from validation dataset for inference.
'''
import json
annos = json.load(open("activitynet_1.3_annotations.json"))
infer_data = {}
count = 0
for video_name in annos.keys():
if annos[video_name]["subset"] == 'validation':
infer_data[video_name] = annos[video_name]
count += 1
if count == 5:
break
with open('infer.list.json', 'w') as f:
json.dump(infer_data, f)
{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
## ActivityNet 指标计算
- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org)
- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至PaddleVideo目录下。
- 计算精度指标
```cd metrics/bmn_metrics```
```python eval_anet_prop.py```
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
import datetime
import logging
import json
import pandas as pd
from models.bmn.bmn_utils import boundary_choose, soft_nms, bmn_post_processing
import time
import os
logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(self, cfg, name='BMN', mode='train'):
self.name = name
self.mode = mode # 'train', 'valid', 'test', 'infer'
self.tscale = cfg["MODEL"]["tscale"]
self.dscale = cfg["MODEL"]["tscale"]
self.subset = cfg[self.mode.upper()]["subset"]
self.anno_file = cfg["MODEL"]["anno_file"]
self.file_list = cfg["INFER"]["filelist"]
self.get_dataset_dict()
self.cols = ["xmin", "xmax", "score"]
self.snippet_xmins = [1.0 / self.tscale * i for i in range(self.tscale)]
self.snippet_xmaxs = [
1.0 / self.tscale * i for i in range(1, self.tscale + 1)
]
if self.mode == "test" or self.mode == "infer":
self.output_path = cfg[self.mode.upper()]["output_path"]
self.result_path = cfg[self.mode.upper()]["result_path"]
self.reset()
def get_dataset_dict(self):
if self.mode == "infer":
annos = json.load(open(self.file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_loss = 0.0
self.aggr_tem_loss = 0.0
self.aggr_pem_reg_loss = 0.0
self.aggr_pem_cls_loss = 0.0
self.aggr_batch_size = 0
if self.mode == 'test' or self.mode == "infer":
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
def gen_props(self, pred_bm, pred_start, pred_end, fid):
video_name = self.video_list[fid]
pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
start_mask = boundary_choose(pred_start)
start_mask[0] = 1.
end_mask = boundary_choose(pred_end)
end_mask[-1] = 1.
score_vector_list = []
for idx in range(self.dscale):
for jdx in range(self.tscale):
start_index = jdx
end_index = start_index + idx
if end_index < self.tscale and start_mask[
start_index] == 1 and end_mask[end_index] == 1:
xmin = self.snippet_xmins[start_index]
xmax = self.snippet_xmaxs[end_index]
xmin_score = pred_start[start_index]
xmax_score = pred_end[end_index]
bm_score = pred_bm[idx, jdx]
conf_score = xmin_score * xmax_score * bm_score
score_vector_list.append([xmin, xmax, conf_score])
score_vector_list = np.stack(score_vector_list)
video_df = pd.DataFrame(score_vector_list, columns=self.cols)
video_df.to_csv(
os.path.join(self.output_path, "%s.csv" % video_name), index=False)
def accumulate(self, fetch_list):
cur_batch_size = 1 # iteration counter
total_loss = fetch_list[0]
tem_loss = fetch_list[1]
pem_reg_loss = fetch_list[2]
pem_cls_loss = fetch_list[3]
self.aggr_loss += np.mean(np.array(total_loss))
self.aggr_tem_loss += np.mean(np.array(tem_loss))
self.aggr_pem_reg_loss += np.mean(np.array(pem_reg_loss))
self.aggr_pem_cls_loss += np.mean(np.array(pem_cls_loss))
self.aggr_batch_size += cur_batch_size
if self.mode == 'test':
pred_bm = np.array(fetch_list[4])
pred_start = np.array(fetch_list[5])
pred_end = np.array(fetch_list[6])
fid = fetch_list[7][0][0]
self.gen_props(pred_bm, pred_start, pred_end, fid)
def accumulate_infer_results(self, fetch_list):
pred_bm = np.array(fetch_list[0])
pred_start = np.array(fetch_list[1][0])
pred_end = np.array(fetch_list[2][0])
fid = fetch_list[3][0]
self.gen_props(pred_bm, pred_start, pred_end, fid)
def finalize_metrics(self):
self.avg_loss = self.aggr_loss / self.aggr_batch_size
self.avg_tem_loss = self.aggr_tem_loss / self.aggr_batch_size
self.avg_pem_reg_loss = self.aggr_pem_reg_loss / self.aggr_batch_size
self.avg_pem_cls_loss = self.aggr_pem_cls_loss / self.aggr_batch_size
if self.mode == 'test':
bmn_post_processing(self.video_dict, self.subset, self.output_path,
self.result_path)
def finalize_infer_metrics(self):
bmn_post_processing(self.video_dict, self.subset, self.output_path,
self.result_path)
def get_computed_metrics(self):
json_stats = {}
json_stats['avg_loss'] = self.avg_loss
json_stats['avg_tem_loss'] = self.avg_tem_loss
json_stats['avg_pem_reg_loss'] = self.avg_pem_reg_loss
json_stats['avg_pem_cls_loss'] = self.avg_pem_cls_loss
return json_stats
'''
Calculate AR@N and AUC;
Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)
'''
import sys
sys.path.append('../../Evaluation')
from eval_proposal import ANETproposal
import numpy as np
import argparse
parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
parser.add_argument(
'--eval_file',
type=str,
default='bmn_results_validation.json',
help='name of results file to eval')
def run_evaluation(ground_truth_filename,
proposal_filename,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation'):
anet_proposal = ANETproposal(
ground_truth_filename,
proposal_filename,
tiou_thresholds=tiou_thresholds,
max_avg_nr_proposals=max_avg_nr_proposals,
subset=subset,
verbose=True,
check_status=False)
anet_proposal.evaluate()
recall = anet_proposal.recall
average_recall = anet_proposal.avg_recall
average_nr_proposals = anet_proposal.proposals_per_video
return (average_nr_proposals, average_recall, recall)
def plot_metric(average_nr_proposals,
average_recall,
recall,
tiou_thresholds=np.linspace(0.5, 0.95, 10)):
fn_size = 14
plt.figure(num=None, figsize=(12, 8))
ax = plt.subplot(1, 1, 1)
colors = [
'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
]
area_under_curve = np.zeros_like(tiou_thresholds)
for i in range(recall.shape[0]):
area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
for idx, tiou in enumerate(tiou_thresholds[::2]):
ax.plot(
average_nr_proposals,
recall[2 * idx, :],
color=colors[idx + 1],
label="tiou=[" + str(tiou) + "], area=" + str(
int(area_under_curve[2 * idx] * 100) / 100.),
linewidth=4,
linestyle='--',
marker=None)
# Plots Average Recall vs Average number of proposals.
ax.plot(
average_nr_proposals,
average_recall,
color=colors[0],
label="tiou = 0.5:0.05:0.95," + " area=" + str(
int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
linewidth=4,
linestyle='-',
marker=None)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
[handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
plt.ylabel('Average Recall', fontsize=fn_size)
plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
plt.grid(b=True, which="both")
plt.ylim([0, 1.0])
plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
plt.show()
if __name__ == "__main__":
args = parser.parse_args()
eval_file = args.eval_file
eval_file_path = "../../data/evaluate_results/" + eval_file
uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
"../../Evaluation/data/activity_net_1_3_new.json",
eval_file_path,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation')
print("AR@1; AR@5; AR@10; AR@100")
print("%.02f %.02f %.02f %.02f" %
(100 * np.mean(uniform_recall_valid[:, 0]),
100 * np.mean(uniform_recall_valid[:, 4]),
100 * np.mean(uniform_recall_valid[:, 9]),
100 * np.mean(uniform_recall_valid[:, -1])))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
import datetime
import logging
import json
import pandas as pd
from models.bsn.bsn_utils import soft_nms, bsn_post_processing
import time
logger = logging.getLogger(__name__)
import os
class MetricsCalculator():
def __init__(self, cfg, name='BsnPem', mode='train'):
self.name = name
self.mode = mode # 'train', 'valid', 'test', 'infer'
self.subset = cfg[self.mode.upper()][
"subset"] # 'train', 'validation', 'test'
self.anno_file = cfg["MODEL"]["anno_file"]
self.file_list = cfg["INFER"]["filelist"]
self.get_dataset_dict()
if self.mode == "test" or self.mode == "infer":
self.output_path_pem = cfg[self.mode.upper()]["output_path_pem"]
self.result_path_pem = cfg[self.mode.upper()]["result_path_pem"]
self.reset()
def get_dataset_dict(self):
if self.mode == "infer":
annos = json.load(open(self.file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_loss = 0.0
self.aggr_batch_size = 0
if self.mode == 'test' or self.mode == 'infer':
if not os.path.exists(self.output_path_pem):
os.makedirs(self.output_path_pem)
def save_results(self, pred_iou, props_info, fid):
video_name = self.video_list[fid]
df = pd.DataFrame()
df["xmin"] = props_info[0, 0, :]
df["xmax"] = props_info[0, 1, :]
df["xmin_score"] = props_info[0, 2, :]
df["xmax_score"] = props_info[0, 3, :]
df["iou_score"] = pred_iou.squeeze()
df.to_csv(
os.path.join(self.output_path_pem, video_name + ".csv"),
index=False)
def accumulate(self, fetch_list):
cur_batch_size = 1 # iteration counter
total_loss = fetch_list[0]
self.aggr_loss += np.mean(np.array(total_loss))
self.aggr_batch_size += cur_batch_size
if self.mode == 'test':
pred_iou = np.array(fetch_list[1])
props_info = np.array(fetch_list[2])
fid = fetch_list[3][0][0]
self.save_results(pred_iou, props_info, fid)
def accumulate_infer_results(self, fetch_list):
pred_iou = np.array(fetch_list[0])
props_info = np.array(fetch_list[1])
fid = fetch_list[2][0]
self.save_results(pred_iou, props_info, fid)
def finalize_metrics(self):
self.avg_loss = self.aggr_loss / self.aggr_batch_size
if self.mode == 'test':
bsn_post_processing(self.video_dict, self.subset,
self.output_path_pem, self.result_path_pem)
def finalize_infer_metrics(self):
bsn_post_processing(self.video_dict, self.subset, self.output_path_pem,
self.result_path_pem)
def get_computed_metrics(self):
json_stats = {}
json_stats['avg_loss'] = self.avg_loss
return json_stats
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
import datetime
import logging
import json
import pandas as pd
from models.bsn.bsn_utils import pgm_gen_proposal, pgm_gen_feature
import time
logger = logging.getLogger(__name__)
import os
class MetricsCalculator():
def __init__(self, cfg, name='BsnTem', mode='train'):
self.name = name
self.mode = mode # 'train', 'valid', 'test', 'infer'
self.tscale = cfg["MODEL"]["tscale"]
self.subset = cfg[self.mode.upper()][
"subset"] # 'train', 'validation', 'train_val'
self.anno_file = cfg["MODEL"]["anno_file"]
self.file_list = cfg["INFER"]["filelist"]
self.get_pgm_cfg(cfg)
self.get_dataset_dict()
self.cols = ["xmin", "xmax", "score"]
self.snippet_xmins = [1.0 / self.tscale * i for i in range(self.tscale)]
self.snippet_xmaxs = [
1.0 / self.tscale * i for i in range(1, self.tscale + 1)
]
if self.mode == "test" or self.mode == "infer":
print('1212')
self.output_path_tem = cfg[self.mode.upper()]["output_path_tem"]
self.output_path_pgm_feature = cfg[self.mode.upper()][
"output_path_pgm_feature"]
self.output_path_pgm_proposal = cfg[self.mode.upper()][
"output_path_pgm_proposal"]
self.reset()
def get_dataset_dict(self):
if self.mode == "infer":
annos = json.load(open(self.file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset == "train_val":
if "train" in video_subset or "validation" in video_subset:
self.video_dict[video_name] = annos[video_name]
else:
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def get_pgm_cfg(self, cfg):
self.pgm_config = {}
if self.mode == "test" or self.mode == "infer":
self.pgm_config["tscale"] = self.tscale
self.pgm_config["pgm_threshold"] = cfg["MODEL"]["pgm_threshold"]
self.pgm_config["pgm_top_K_train"] = cfg["MODEL"]["pgm_top_K_train"]
self.pgm_config["pgm_top_K"] = cfg["MODEL"]["pgm_top_K"]
self.pgm_config["bsp_boundary_ratio"] = cfg["MODEL"][
"bsp_boundary_ratio"]
self.pgm_config["num_sample_start"] = cfg["MODEL"][
"num_sample_start"]
self.pgm_config["num_sample_end"] = cfg["MODEL"]["num_sample_end"]
self.pgm_config["num_sample_action"] = cfg["MODEL"][
"num_sample_action"]
self.pgm_config["num_sample_perbin"] = cfg["MODEL"][
"num_sample_perbin"]
self.pgm_config["pgm_thread"] = cfg["MODEL"]["pgm_thread"]
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
self.aggr_loss = 0.0
self.aggr_start_loss = 0.0
self.aggr_end_loss = 0.0
self.aggr_action_loss = 0.0
self.aggr_batch_size = 0
if self.mode == 'test' or self.mode == 'infer':
if not os.path.exists(self.output_path_tem):
os.makedirs(self.output_path_tem)
if not os.path.exists(self.output_path_pgm_feature):
os.makedirs(self.output_path_pgm_feature)
if not os.path.exists(self.output_path_pgm_proposal):
os.makedirs(self.output_path_pgm_proposal)
def save_results(self, pred_tem, fid):
video_name = self.video_list[fid]
pred_start = pred_tem[0, 0, :]
pred_end = pred_tem[0, 1, :]
pred_action = pred_tem[0, 2, :]
output_tem = np.stack([pred_start, pred_end, pred_action], axis=1)
video_df = pd.DataFrame(output_tem, columns=["start", "end", "action"])
video_df.to_csv(
os.path.join(self.output_path_tem, video_name + ".csv"),
index=False)
def accumulate(self, fetch_list):
cur_batch_size = 1 # iteration counter
total_loss = fetch_list[0]
start_loss = fetch_list[1]
end_loss = fetch_list[2]
action_loss = fetch_list[3]
self.aggr_loss += np.mean(np.array(total_loss))
self.aggr_start_loss += np.mean(np.array(start_loss))
self.aggr_end_loss += np.mean(np.array(end_loss))
self.aggr_action_loss += np.mean(np.array(action_loss))
self.aggr_batch_size += cur_batch_size
if self.mode == 'test':
pred_tem = np.array(fetch_list[4])
fid = fetch_list[5][0][0]
self.save_results(pred_tem, fid)
def accumulate_infer_results(self, fetch_list):
pred_tem = np.array(fetch_list[0])
fid = fetch_list[1][0]
self.save_results(pred_tem, fid)
def finalize_metrics(self):
self.avg_loss = self.aggr_loss / self.aggr_batch_size
self.avg_start_loss = self.aggr_start_loss / self.aggr_batch_size
self.avg_end_loss = self.aggr_end_loss / self.aggr_batch_size
self.avg_action_loss = self.aggr_action_loss / self.aggr_batch_size
if self.mode == 'test':
print("start generate proposals of %s subset" % self.subset)
pgm_gen_proposal(self.video_dict, self.pgm_config,
self.output_path_tem,
self.output_path_pgm_proposal)
print("finish generate proposals of %s subset" % self.subset)
print("start generate proposals feature of %s subset" % self.subset)
pgm_gen_feature(self.video_dict, self.pgm_config,
self.output_path_tem, self.output_path_pgm_proposal,
self.output_path_pgm_feature)
print("finish generate proposals feature of %s subset" %
self.subset)
def finalize_infer_metrics(self):
print("start generate proposals of %s subset" % self.subset)
pgm_gen_proposal(self.video_dict, self.pgm_config, self.output_path_tem,
self.output_path_pgm_proposal)
print("finish generate proposals of %s subset" % self.subset)
print("start generate proposals feature of %s subset" % self.subset)
pgm_gen_feature(self.video_dict, self.pgm_config, self.output_path_tem,
self.output_path_pgm_proposal,
self.output_path_pgm_feature)
print("finish generate proposals feature of %s subset" % self.subset)
def get_computed_metrics(self):
json_stats = {}
json_stats['avg_loss'] = self.avg_loss
json_stats['avg_start_loss'] = self.avg_start_loss
json_stats['avg_end_loss'] = self.avg_end_loss
json_stats['avg_action_loss'] = self.avg_action_loss
return json_stats
......@@ -25,6 +25,9 @@ from metrics.youtube8m import eval_util as youtube8m_metrics
from metrics.kinetics import accuracy_metrics as kinetics_metrics
from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics
from metrics.detections import detection_metrics as detection_metrics
from metrics.bmn_metrics import bmn_proposal_metrics as bmn_proposal_metrics
from metrics.bsn_metrics import bsn_tem_metrics as bsn_tem_metrics
from metrics.bsn_metrics import bsn_pem_metrics as bsn_pem_metrics
logger = logging.getLogger(__name__)
......@@ -301,6 +304,123 @@ class DetectionMetrics(Metrics):
self.calculator.reset()
class BmnMetrics(Metrics):
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.calculator = bmn_proposal_metrics.MetricsCalculator(
cfg=cfg, name=self.name, mode=self.mode)
def calculate_and_log_out(self, fetch_list, info=''):
total_loss = np.array(fetch_list[0])
tem_loss = np.array(fetch_list[1])
pem_reg_loss = np.array(fetch_list[2])
pem_cls_loss = np.array(fetch_list[3])
logger.info(
info + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format(
'%.04f' % np.mean(total_loss), '%.04f' % np.mean(tem_loss), \
'%.04f' % np.mean(pem_reg_loss), '%.04f' % np.mean(pem_cls_loss)))
def accumulate(self, fetch_list):
if self.mode == 'infer':
self.calculator.accumulate_infer_results(fetch_list)
else:
self.calculator.accumulate(fetch_list)
def finalize_and_log_out(self, info='', savedir='./'):
if self.mode == 'infer':
self.calculator.finalize_infer_metrics()
else:
self.calculator.finalize_metrics()
metrics_dict = self.calculator.get_computed_metrics()
loss = metrics_dict['avg_loss']
tem_loss = metrics_dict['avg_tem_loss']
pem_reg_loss = metrics_dict['avg_pem_reg_loss']
pem_cls_loss = metrics_dict['avg_pem_cls_loss']
logger.info(
info +
'\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.
format('%.04f' % loss, '%.04f' % tem_loss, '%.04f' %
pem_reg_loss, '%.04f' % pem_cls_loss))
def reset(self):
self.calculator.reset()
class BsnTemMetrics(Metrics):
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.calculator = bsn_tem_metrics.MetricsCalculator(
cfg=cfg, name=self.name, mode=self.mode)
def calculate_and_log_out(self, fetch_list, info=''):
total_loss = np.array(fetch_list[0])
start_loss = np.array(fetch_list[1])
end_loss = np.array(fetch_list[2])
action_loss = np.array(fetch_list[3])
logger.info(
info +
'\tLoss = {}, \tstart_loss = {}, \tend_loss = {}, \taction_loss = {}'.
format('%.04f' % np.mean(total_loss), '%.04f' % np.mean(start_loss),
'%.04f' % np.mean(end_loss), '%.04f' % np.mean(action_loss)))
def accumulate(self, fetch_list):
if self.mode == 'infer':
self.calculator.accumulate_infer_results(fetch_list)
else:
self.calculator.accumulate(fetch_list)
def finalize_and_log_out(self, info='', savedir='./'):
if self.mode == 'infer':
self.calculator.finalize_infer_metrics()
else:
self.calculator.finalize_metrics()
metrics_dict = self.calculator.get_computed_metrics()
loss = metrics_dict['avg_loss']
start_loss = metrics_dict['avg_start_loss']
end_loss = metrics_dict['avg_end_loss']
action_loss = metrics_dict['avg_action_loss']
logger.info(
info +
'\tLoss = {}, \tstart_loss = {}, \tend_loss = {}, \taction_loss = {}'.
format('%.04f' % loss, '%.04f' % start_loss, '%.04f' % end_loss,
'%.04f' % action_loss))
def reset(self):
self.calculator.reset()
class BsnPemMetrics(Metrics):
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.calculator = bsn_pem_metrics.MetricsCalculator(
cfg=cfg, name=self.name, mode=self.mode)
def calculate_and_log_out(self, fetch_list, info=''):
total_loss = np.array(fetch_list[0])
logger.info(info + '\tLoss = {}'.format('%.04f' % np.mean(total_loss)))
def accumulate(self, fetch_list):
if self.mode == 'infer':
self.calculator.accumulate_infer_results(fetch_list)
else:
self.calculator.accumulate(fetch_list)
def finalize_and_log_out(self, info='', savedir='./'):
if self.mode == 'infer':
self.calculator.finalize_infer_metrics()
else:
self.calculator.finalize_metrics()
metrics_dict = self.calculator.get_computed_metrics()
loss = metrics_dict['avg_loss']
logger.info(info + '\tLoss = {}'.format('%.04f' % loss))
def reset(self):
self.calculator.reset()
class MetricsZoo(object):
def __init__(self):
self.metrics_zoo = {}
......@@ -338,3 +458,6 @@ regist_metrics("TSM", Kinetics400Metrics)
regist_metrics("TSN", Kinetics400Metrics)
regist_metrics("STNET", Kinetics400Metrics)
regist_metrics("CTCN", DetectionMetrics)
regist_metrics("BMN", BmnMetrics)
regist_metrics("BSNTEM", BsnTemMetrics)
regist_metrics("BSNPEM", BsnPemMetrics)
......@@ -7,6 +7,9 @@ from .tsm import TSM
from .tsn import TSN
from .stnet import STNET
from .ctcn import CTCN
from .bmn import BMN
from .bsn import BsnTem
from .bsn import BsnPem
# regist models, sort by alphabet
regist_model("AttentionCluster", AttentionCluster)
......@@ -17,3 +20,6 @@ regist_model("TSM", TSM)
regist_model("TSN", TSN)
regist_model("STNET", STNET)
regist_model("CTCN", CTCN)
regist_model("BMN", BMN)
regist_model("BsnTem", BsnTem)
regist_model("BsnPem", BsnPem)
# BMN 视频动作定位模型
---
## 内容
- [模型简介](#模型简介)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)
- [参考论文](#参考论文)
## 模型简介
BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。
<p align="center">
<img src="../../images/BMN.png" height=300 width=500 hspace='10'/> <br />
BMN Overview
</p>
## 数据准备
BMN的训练数据采用ActivityNet1.3提供的数据集,数据下载及准备请参考[数据说明](../../data/dataset/bmn/README.md)
## 模型训练
数据准备完毕后,可以通过如下两种方式启动训练:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_fast_eager_deletion_mode=1
python train.py --model_name=BMN \
--config=./configs/bmn.yaml \
--log_interval=10 \
--valid_interval=1 \
--use_gpu=True \
--save_dir=./data/checkpoints
bash run.sh train BMN ./configs/bmn.yaml
- 代码运行需要先安装pandas
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 若使用第二种方式,请在run.sh脚本文件中设置4卡训练:
export CUDA_VISIBLE_DEVICES=0,1,2,3
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.001
* 权重衰减系数为1e-4
* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
## 模型评估
可通过如下两种方式进行模型评估:
python eval.py --model_name=BMN \
--config=./configs/bmn.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--use_gpu=True
bash run.sh eval BMN ./configs/bmn.yaml
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要评估的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BMN_final.pdparams)进行评估。
- 上述程序会将运行结果保存在data/output/EVAL\BMN\_results文件夹下,测试结果保存在data/evaluate\_results/bmn\_results\_validation.json文件中。使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程请参考[指标计算](../../metrics/bmn_metrics/README.md)
- 使用CPU进行评估时,请将上面的命令行或者run.sh脚本中的`use_gpu`设置为False。
在ActivityNet1.3数据集下评估精度如下:
| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
| :---: | :---: | :---: | :---: | :---: |
| 33.06 | 49.13 | 56.27 | 75.32 | 67.19% |
## 模型推断
可通过如下两种方式启动模型推断:
python predict.py --model_name=BMN \
--config=./configs/bmn.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST \
--use_gpu=True
bash run.sh predict BMN ./configs/bmn.yaml
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为data/dataset/bmn/infer.list。`--weights`参数为训练好的权重参数,如果不设置,程序会自动下载已训练好的权重。
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要用到的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BMN_final.pdparams)进行推断。
- 上述程序会将运行结果保存在data/output/INFER/BMN\_results文件夹下,测试结果保存在data/predict\_results/bmn\_results\_test.json文件中。
- 使用CPU进行推断时,请将命令行或者run.sh脚本中的`use_gpu`设置为False
## 参考论文
- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
from ..model import ModelBase
from . import bmn_net
import logging
logger = logging.getLogger(__name__)
__all__ = ["BMN"]
class BMN(ModelBase):
"""BMN model"""
def __init__(self, name, cfg, mode='train'):
super(BMN, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
self.tscale = self.get_config_from_sec('MODEL', 'tscale')
self.dscale = self.get_config_from_sec('MODEL', 'dscale')
self.feat_dim = self.get_config_from_sec('MODEL', 'feat_dim')
self.prop_boundary_ratio = self.get_config_from_sec(
'MODEL', 'prop_boundary_ratio')
self.num_sample = self.get_config_from_sec('MODEL', 'num_sample')
self.num_sample_perbin = self.get_config_from_sec('MODEL',
'num_sample_perbin')
self.num_epochs = self.get_config_from_sec('train', 'epoch')
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_weight_decay = self.get_config_from_sec('train',
'l2_weight_decay')
self.lr_decay_iter = self.get_config_from_sec('train', 'lr_decay_iter')
def build_input(self, use_pyreader=True):
feat_shape = [self.feat_dim, self.tscale]
gt_iou_map_shape = [self.dscale, self.tscale]
gt_start_shape = [self.tscale]
gt_end_shape = [self.tscale]
fileid_shape = [1]
self.use_pyreader = use_pyreader
# set init data to None
py_reader = None
feat = None
gt_iou_map = None
gt_start = None
gt_end = None
fileid = None
feat = fluid.layers.data(name='feat', shape=feat_shape, dtype='float32')
feed_list = []
feed_list.append(feat)
if (self.mode == 'train') or (self.mode == 'valid'):
gt_start = fluid.layers.data(
name='gt_start', shape=gt_start_shape, dtype='float32')
gt_end = fluid.layers.data(
name='gt_end', shape=gt_end_shape, dtype='float32')
gt_iou_map = fluid.layers.data(
name='gt_iou_map', shape=gt_iou_map_shape, dtype='float32')
feed_list.append(gt_iou_map)
feed_list.append(gt_start)
feed_list.append(gt_end)
elif self.mode == 'test':
gt_start = fluid.layers.data(
name='gt_start', shape=gt_start_shape, dtype='float32')
gt_end = fluid.layers.data(
name='gt_end', shape=gt_end_shape, dtype='float32')
gt_iou_map = fluid.layers.data(
name='gt_iou_map', shape=gt_iou_map_shape, dtype='float32')
feed_list.append(gt_iou_map)
feed_list.append(gt_start)
feed_list.append(gt_end)
fileid = fluid.layers.data(
name='fileid', shape=fileid_shape, dtype='int64')
feed_list.append(fileid)
elif self.mode == 'infer':
# only image feature input when inference
pass
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
self.py_reader = fluid.io.PyReader(
feed_list=feed_list, capacity=8, iterable=True)
self.feat_input = [feat]
self.gt_iou_map = gt_iou_map
self.gt_start = gt_start
self.gt_end = gt_end
self.fileid = fileid
def create_model_args(self):
cfg = {}
cfg['tscale'] = self.tscale
cfg['dscale'] = self.dscale
cfg['prop_boundary_ratio'] = self.prop_boundary_ratio
cfg['num_sample'] = self.num_sample
cfg['num_sample_perbin'] = self.num_sample_perbin
return cfg
def build_model(self):
cfg = self.create_model_args()
self.videomodel = bmn_net.BMN_NET(mode=self.mode, cfg=cfg)
pred_bm, pred_start, pred_end = self.videomodel.net(
input=self.feat_input[0])
self.network_outputs = [pred_bm, pred_start, pred_end]
self.bm_mask = self.videomodel.bm_mask
def optimizer(self):
bd = [self.lr_decay_iter]
base_lr = self.base_learning_rate
lr_decay = self.learning_rate_decay
lr = [base_lr, base_lr * lr_decay]
l2_weight_decay = self.l2_weight_decay
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
self.loss_ = self.videomodel.bmn_loss_func(
self.network_outputs[0], self.network_outputs[1],
self.network_outputs[2], self.gt_iou_map, self.gt_start,
self.gt_end, self.bm_mask)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
if (self.mode == 'train') or (self.mode == 'valid'):
return self.feat_input + [
self.gt_iou_map, self.gt_start, self.gt_end
]
elif self.mode == 'test':
return self.feat_input + [
self.gt_iou_map, self.gt_start, self.gt_end, self.fileid
]
elif self.mode == 'infer':
return self.feat_input
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
def fetches(self):
if (self.mode == 'train') or (self.mode == 'valid'):
losses = self.loss()
fetch_list = [item for item in losses]
elif self.mode == 'test':
losses = self.loss()
preds = self.outputs()
fetch_list = [item for item in losses] + \
[item for item in preds] + \
[self.fileid]
elif self.mode == 'infer':
preds = self.outputs()
fetch_list = [item for item in preds]
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
return fetch_list
def pretrain_info(self):
return (None, None)
def weights_info(self):
return (
'BMN_final.pdparams',
'https://paddlemodels.bj.bcebos.com/video_detection/BMN_final.pdparams'
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
import math
DATATYPE = 'float32'
class BMN_NET(object):
def __init__(self, mode, cfg):
self.tscale = cfg["tscale"]
self.dscale = cfg["dscale"]
self.prop_boundary_ratio = cfg["prop_boundary_ratio"]
self.num_sample = cfg["num_sample"]
self.num_sample_perbin = cfg["num_sample_perbin"]
self.is_training = (mode == 'train')
self.hidden_dim_1d = 256
self.hidden_dim_2d = 128
self.hidden_dim_3d = 512
self._get_interp1d_mask()
self._get_mask()
def conv1d(self,
input,
num_k=256,
input_size=256,
size_k=3,
padding=1,
groups=1,
act='relu',
name="conv1d"):
fan_in = input_size * size_k * 1
k = 1. / math.sqrt(fan_in)
param_attr = fluid.initializer.Uniform(low=-k, high=k)
bias_attr = fluid.initializer.Uniform(low=-k, high=k)
input = fluid.layers.unsqueeze(input=input, axes=[2])
conv = fluid.layers.conv2d(
input=input,
num_filters=num_k,
filter_size=(1, size_k),
stride=1,
padding=(0, padding),
groups=groups,
act=act,
name=name,
param_attr=param_attr,
bias_attr=bias_attr)
conv = fluid.layers.squeeze(input=conv, axes=[2])
return conv
def conv2d(self,
input,
num_k=256,
size_k=3,
padding=1,
act='relu',
name='conv2d'):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_k,
filter_size=size_k,
stride=1,
padding=padding,
act=act,
name=name)
return conv
def conv3d(self, input, num_k=512, name="PEM_3d"):
conv = fluid.layers.conv3d(
input=input,
num_filters=num_k,
filter_size=(self.num_sample, 1, 1),
stride=(self.num_sample, 1, 1),
padding=0,
act='relu',
name=name)
return conv
def net(self, input):
# Base Module of BMN
x_1d = self.conv1d(
input,
input_size=400,
num_k=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu",
name="Base_1")
x_1d = self.conv1d(
x_1d,
num_k=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu",
name="Base_2")
# Temporal Evaluation Module of BMN
x_1d_s = self.conv1d(
x_1d,
num_k=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu",
name="TEM_s1")
x_1d_s = self.conv1d(
x_1d_s, num_k=1, size_k=1, padding=0, act="sigmoid", name="TEM_s2")
x_1d_e = self.conv1d(
x_1d,
num_k=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu",
name="TEM_e1")
x_1d_e = self.conv1d(
x_1d_e, num_k=1, size_k=1, padding=0, act="sigmoid", name="TEM_e2")
x_1d_s = fluid.layers.squeeze(input=x_1d_s, axes=[1])
x_1d_e = fluid.layers.squeeze(input=x_1d_e, axes=[1])
# Proposal Evaluation Module of BMN
x_1d = self.conv1d(
x_1d,
num_k=self.hidden_dim_2d,
size_k=3,
padding=1,
act="relu",
name="PEM_1d")
x_3d = self._boundary_matching_layer(x_1d)
x_3d = self.conv3d(x_3d, self.hidden_dim_3d, name="PEM_3d1")
x_2d = fluid.layers.squeeze(input=x_3d, axes=[2])
x_2d = self.conv2d(
x_2d,
self.hidden_dim_2d,
size_k=1,
padding=0,
act='relu',
name="PEM_2d1")
x_2d = self.conv2d(
x_2d,
self.hidden_dim_2d,
size_k=3,
padding=1,
act='relu',
name="PEM_2d2")
x_2d = self.conv2d(
x_2d,
self.hidden_dim_2d,
size_k=3,
padding=1,
act='relu',
name="PEM_2d3")
x_2d = self.conv2d(
x_2d, 2, size_k=1, padding=0, act='sigmoid', name="PEM_2d4")
return x_2d, x_1d_s, x_1d_e
def _get_mask(self):
bm_mask = []
for idx in range(self.dscale):
mask_vector = [1 for i in range(self.tscale - idx)
] + [0 for i in range(idx)]
bm_mask.append(mask_vector)
bm_mask = np.array(bm_mask, dtype=np.float32)
self.bm_mask = fluid.layers.create_global_var(
shape=[self.dscale, self.tscale],
value=0,
dtype=DATATYPE,
persistable=True)
fluid.layers.assign(bm_mask, self.bm_mask)
self.bm_mask.stop_gradient = True
def _boundary_matching_layer(self, x):
out = fluid.layers.matmul(x, self.sample_mask)
out = fluid.layers.reshape(
x=out, shape=[0, 0, -1, self.dscale, self.tscale])
return out
def _get_interp1d_bin_mask(self, seg_xmin, seg_xmax, tscale, num_sample,
num_sample_perbin):
# generate sample mask for a boundary-matching pair
plen = float(seg_xmax - seg_xmin)
plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
total_samples = [
seg_xmin + plen_sample * ii
for ii in range(num_sample * num_sample_perbin)
]
p_mask = []
for idx in range(num_sample):
bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
num_sample_perbin]
bin_vector = np.zeros([tscale])
for sample in bin_samples:
sample_upper = math.ceil(sample)
sample_decimal, sample_down = math.modf(sample)
if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
bin_vector[int(sample_down)] += 1 - sample_decimal
if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
bin_vector[int(sample_upper)] += sample_decimal
bin_vector = 1.0 / num_sample_perbin * bin_vector
p_mask.append(bin_vector)
p_mask = np.stack(p_mask, axis=1)
return p_mask
def _get_interp1d_mask(self):
# generate sample mask for each point in Boundary-Matching Map
mask_mat = []
for start_index in range(self.tscale):
mask_mat_vector = []
for duration_index in range(self.dscale):
if start_index + duration_index < self.tscale:
p_xmin = start_index
p_xmax = start_index + duration_index
center_len = float(p_xmax - p_xmin) + 1
sample_xmin = p_xmin - center_len * self.prop_boundary_ratio
sample_xmax = p_xmax + center_len * self.prop_boundary_ratio
p_mask = self._get_interp1d_bin_mask(
sample_xmin, sample_xmax, self.tscale, self.num_sample,
self.num_sample_perbin)
else:
p_mask = np.zeros([self.tscale, self.num_sample])
mask_mat_vector.append(p_mask)
mask_mat_vector = np.stack(mask_mat_vector, axis=2)
mask_mat.append(mask_mat_vector)
mask_mat = np.stack(mask_mat, axis=3)
mask_mat = mask_mat.astype(np.float32)
self.sample_mask = fluid.layers.create_parameter(
shape=[self.tscale, self.num_sample, self.dscale, self.tscale],
dtype=DATATYPE,
attr=fluid.ParamAttr(
name="sample_mask", trainable=False),
default_initializer=fluid.initializer.NumpyArrayInitializer(
mask_mat))
self.sample_mask = fluid.layers.reshape(
x=self.sample_mask, shape=[self.tscale, -1], inplace=True)
self.sample_mask.stop_gradient = True
def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
def bi_loss(pred_score, gt_label):
pred_score = fluid.layers.reshape(
x=pred_score, shape=[-1], inplace=True)
gt_label = fluid.layers.reshape(
x=gt_label, shape=[-1], inplace=False)
gt_label.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
num_entries = fluid.layers.cast(
fluid.layers.shape(pmask), dtype=DATATYPE)
num_positive = fluid.layers.cast(
fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
loss = -1 * (loss_pos + loss_neg)
return loss
loss_start = bi_loss(pred_start, gt_start)
loss_end = bi_loss(pred_end, gt_end)
loss = loss_start + loss_end
return loss
def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
num_h = fluid.layers.cast(
fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
num_m = fluid.layers.cast(
fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
num_l = fluid.layers.cast(
fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
r_m = num_h / num_m
u_smmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
r_l = num_h / num_l
u_slmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
weights = u_hmask + u_smmask + u_slmask
weights.stop_gradient = True
loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
loss = fluid.layers.elementwise_mul(loss, weights)
loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
weights)
return loss
def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
gt_iou_map.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
nmask = fluid.layers.elementwise_mul(nmask, mask)
num_positive = fluid.layers.reduce_sum(pmask)
num_entries = num_positive + fluid.layers.reduce_sum(nmask)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), nmask)
loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
loss = -1 * (loss_pos + loss_neg) / num_entries
return loss
def bmn_loss_func(self, pred_bm, pred_start, pred_end, gt_iou_map, gt_start,
gt_end, bm_mask):
pred_bm_reg = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[0], ends=[1]),
axes=[1])
pred_bm_cls = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[1], ends=[2]),
axes=[1])
pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
loss.persistable = True
return loss, tem_loss, pem_reg_loss, pem_cls_loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
from paddle.fluid.initializer import Uniform
import pandas as pd
import multiprocessing as mp
import json
import os
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute jaccard score between a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
union_len = len_anchors - inter_len + box_max - box_min
#print inter_len,union_len
jaccard = np.divide(inter_len, union_len)
return jaccard
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute intersection between score a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
scores = np.divide(inter_len, len_anchors)
return scores
def boundary_choose(score_list):
max_score = max(score_list)
mask_high = (score_list > max_score * 0.5)
score_list = list(score_list)
score_middle = np.array([0.0] + score_list + [0.0])
score_front = np.array([0.0, 0.0] + score_list)
score_back = np.array(score_list + [0.0, 0.0])
mask_peak = ((score_middle > score_front) & (score_middle > score_back))
mask_peak = mask_peak[1:-1]
mask = (mask_high | mask_peak).astype('float32')
return mask
def soft_nms(df, alpha, t1, t2):
'''
df: proposals generated by network;
alpha: alpha value of Gaussian decaying function;
t1, t2: threshold for soft nms.
'''
df = df.sort_values(by="score", ascending=False)
tstart = list(df.xmin.values[:])
tend = list(df.xmax.values[:])
tscore = list(df.score.values[:])
rstart = []
rend = []
rscore = []
while len(tscore) > 1 and len(rscore) < 101:
max_index = tscore.index(max(tscore))
tmp_iou_list = iou_with_anchors(
np.array(tstart),
np.array(tend), tstart[max_index], tend[max_index])
for idx in range(0, len(tscore)):
if idx != max_index:
tmp_iou = tmp_iou_list[idx]
tmp_width = tend[max_index] - tstart[max_index]
if tmp_iou > t1 + (t2 - t1) * tmp_width:
tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
alpha)
rstart.append(tstart[max_index])
rend.append(tend[max_index])
rscore.append(tscore[max_index])
tstart.pop(max_index)
tend.pop(max_index)
tscore.pop(max_index)
newDf = pd.DataFrame()
newDf['score'] = rscore
newDf['xmin'] = rstart
newDf['xmax'] = rend
return newDf
def video_process(video_list,
video_dict,
output_path,
snms_alpha=0.4,
snms_t1=0.55,
snms_t2=0.9):
for video_name in video_list:
df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
if len(df) > 1:
df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
video_duration = video_dict[video_name]["duration_second"]
proposal_list = []
for idx in range(min(100, len(df))):
tmp_prop={"score":df.score.values[idx],\
"segment":[max(0,df.xmin.values[idx])*video_duration,\
min(1,df.xmax.values[idx])*video_duration]}
proposal_list.append(tmp_prop)
result_dict[video_name[2:]] = proposal_list
def bmn_post_processing(video_dict, subset, output_path, result_path):
video_list = video_dict.keys()
video_list = list(video_dict.keys())
global result_dict
result_dict = mp.Manager().dict()
pp_num = 12
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pp_num)
processes = []
for tid in range(pp_num - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=video_process,
args=(
tmp_video_list,
video_dict,
output_path, ))
p.start()
processes.append(p)
tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
p = mp.Process(
target=video_process, args=(tmp_video_list, video_dict, output_path))
p.start()
processes.append(p)
for p in processes:
p.join()
result_dict = dict(result_dict)
output_dict = {
"version": "VERSION 1.3",
"results": result_dict,
"external_data": {}
}
outfile = open(
os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
json.dump(output_dict, outfile)
outfile.close()
# BSN 视频动作定位模型
---
## 内容
- [模型简介](#模型简介)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)
- [参考论文](#参考论文)
## 模型简介
BSN模型为视频动作定位问题中proposal的生成提供高效的解决方案。此模型采用自底向上的方法生成proposal,首先对每一个时序位置分别生成动作开始、动作结束及动作类别的概率,然后组合其中概率高的位置生成候选proposal,最后利用Boundary-Sensitive Proposal特征判别该proposal是否包含的动作。网络由TEM、PGM和PEM三个模块组成,分别用于时序概率预测、BSP特征生成及proposal置信度评估。
<p align="center">
<img src="../../images/BSN.png" height=300 width=500 hspace='10'/> <br />
BSN Overview
</p>
## 数据准备
BSN的训练数据采用ActivityNet1.3提供的数据集,数据下载及准备请参考[数据说明](../../data/dataset/bmn/README.md)
## 模型训练
TEM模块以snippet-level的特征序列作为输入,预测每一个时序位置属于动作开始、动作结束及动作行为的概率。
数据准备完毕后,可以通过如下两种方式启动训练:
export CUDA_VISIBLE_DEVICES=0
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_fast_eager_deletion_mode=1
python train.py --model_name=BsnTem \
--config=./configs/bsn_tem.yaml \
--log_interval=10 \
--valid_interval=1 \
--use_gpu=True \
--save_dir=./data/checkpoints \
--fix_random_seed=False
bash run.sh train BsnTem ./configs/bsn_tem.yaml
- 代码运行需要先安装pandas
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 若使用第二种方式,请在run.sh脚本文件中设置单卡训练:
export CUDA_VISIBLE_DEVICES=0
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.001
* 权重衰减系数为1e-4
* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
PEM模块以PGM模块输出的BSP特征作为输入,输出proposal包含动作类别的置信度。
数据准备完毕后,可以通过如下两种方式启动训练:
export CUDA_VISIBLE_DEVICES=0
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_fast_eager_deletion_mode=1
python train.py --model_name=BsnPem \
--config=./configs/bsn_pem.yaml \
--log_interval=10 \
--valid_interval=1 \
--use_gpu=True \
--save_dir=./data/checkpoints \
--fix_random_seed=False
bash run.sh train BsnPem ./configs/bsn_pem.yaml
- 请先运行[TEM模块评估代码](#模型评估),该代码会自动调用PGM模块生成PEM模块运行所需的BSP特征,特征的默认存储路径为data/output/EVAL/PGM\_feature。
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 若使用第二种方式,请在run.sh脚本文件中设置单卡训练:
export CUDA_VISIBLE_DEVICES=0
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.01
* 权重衰减系数为1e-5
* 学习率在迭代次数达到6000的时候做一次衰减,衰减系数为0.1
## 模型评估
TEM模块可通过如下两种方式进行模型评估:
python eval.py --model_name=BsnTem \
--config=./configs/bsn_tem.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--use_gpu=True
bash run.sh eval BsnTem ./configs/bsn_tem.yaml
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要评估的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BsnTem_final.pdparams)进行评估
- 上述程序会将运行结果保存在data/output/EVAL/TEM\_results文件夹下,同时调用PGM模块生成proposal和BSP特征,分别保存在data/output/EVAL/PGM\_proposals和data/output/EVAL/PGM\_feature路径下。
- 使用CPU进行评估时,请将上面的命令行或者run.sh脚本中的`use_gpu`设置为False
TEM评估模块完成后,PEM模块可通过如下两种方式进行模型评估:
python eval.py --model_name=BsnPem \
--config=./configs/bsn_pem.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--use_gpu=True
bash run.sh eval BsnPem ./configs/bsn_pem.yaml
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要评估的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BsnPem_final.pdparams)进行评估。
- 上述程序会将运行结果保存在data/output/EVAL/PEM\_results文件夹下,测试结果保存在data/evaluate\_results/bsn\_results\_validation.json文件中。使用ActivityNet官方提供的测试脚本,即可计算AR@AN、AUC。具体计算过程请参考[指标计算](../../metrics/bmn_metrics/README.md)
- 使用CPU进行评估时,请将上面的命令行或者run.sh脚本中的`use_gpu`设置为False
在ActivityNet1.3数据集下评估精度如下:
| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
| :---: | :---: | :---: | :---: | :---: |
| 32.54 | 47.97 | 55.17 | 75.01 | 66.64% |
## 模型推断
TEM模块可通过如下两种方式启动模型推断:
python predict.py --model_name=BsnTem \
--config=./configs/bsn_tem.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST \
--use_gpu=True
bash run.sh predict BsnTem ./configs/bsn_tem.yaml
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为data/dataset/bmn/infer.list。`--weights`参数为训练好的权重参数,如果不设置,程序会自动下载已训练好的权重。这两个参数如果不设置,请不要写在命令行,将会自动使用默认值。
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要用到的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BsnTem_final.pdparams)进行推断。
- 上述程序会将运行结果保存在data/output/INFER的子目录TEM\_results、PGM\_proposals、PGM\_feature中。
- 使用CPU进行推断时,请将命令行或者run.sh脚本中的`use_gpu`设置为False
PEM模块可通过如下两种方式启动模型推断:
python predict.py --model_name=BsnPem \
--config=./configs/bsn_pem.yaml \
--log_interval=1 \
--weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST \
--use_gpu=True
bash run.sh predict BsnPem ./configs/bsn_pem.yaml
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为data/dataset/bmn/infer.list。`--weights`参数为训练好的权重参数,如果不设置,程序会自动下载已训练好的权重。这两个参数如果不设置,请不要写在命令行,将会自动使用默认值。
- 使用`run.sh`进行评估时,需要修改脚本中的`weights`参数指定需要用到的权重。
- 若未指定`--weights`参数,脚本会下载已发布模型[model](https://paddlemodels.bj.bcebos.com/video_detection/BsnPem_final.pdparams)进行推断。
- 上述程序会将运行结果保存在data/output/INFER/PEM\_results文件夹下,测试结果保存在data/predict\_results/bsn\_results\_test.json文件中。
- 使用CPU进行推断时,请将命令行或者run.sh脚本中的`use_gpu`设置为False
## 参考论文
- [Bsn: Boundary sensitive network for temporal action proposal generation](http://arxiv.org/abs/1806.02964), Tianwei Lin, Xu Zhao, Haisheng Su, Chongjing Wang, Ming Yang.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
from ..model import ModelBase
from . import bsn_net
import logging
logger = logging.getLogger(__name__)
__all__ = ["BsnTem", "BsnPem"]
class BsnTem(ModelBase):
"""BsnTem model"""
def __init__(self, name, cfg, mode='train'):
super(BsnTem, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
self.tscale = self.get_config_from_sec('MODEL', 'tscale')
self.feat_dim = self.get_config_from_sec('MODEL', 'feat_dim')
self.hidden_dim = self.get_config_from_sec('MODEL', 'hidden_dim')
self.num_epochs = self.get_config_from_sec('train', 'epoch')
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_weight_decay = self.get_config_from_sec('train',
'l2_weight_decay')
self.lr_decay_iter = self.get_config_from_sec('train', 'lr_decay_iter')
def build_input(self, use_pyreader=True):
feat_shape = [self.feat_dim, self.tscale]
gt_start_shape = [self.tscale]
gt_end_shape = [self.tscale]
gt_action_shape = [self.tscale]
fileid_shape = [1]
self.use_pyreader = use_pyreader
# set init data to None
py_reader = None
feat = None
gt_start = None
gt_end = None
gt_action = None
fileid = None
feat = fluid.layers.data(name='feat', shape=feat_shape, dtype='float32')
feed_list = []
feed_list.append(feat)
if (self.mode == 'train') or (self.mode == 'valid'):
gt_start = fluid.layers.data(
name='gt_start', shape=gt_start_shape, dtype='float32')
gt_end = fluid.layers.data(
name='gt_end', shape=gt_end_shape, dtype='float32')
gt_action = fluid.layers.data(
name='gt_action', shape=gt_action_shape, dtype='float32')
feed_list.append(gt_start)
feed_list.append(gt_end)
feed_list.append(gt_action)
elif self.mode == 'test':
gt_start = fluid.layers.data(
name='gt_start', shape=gt_start_shape, dtype='float32')
gt_end = fluid.layers.data(
name='gt_end', shape=gt_end_shape, dtype='float32')
gt_action = fluid.layers.data(
name='gt_action', shape=gt_action_shape, dtype='float32')
feed_list.append(gt_start)
feed_list.append(gt_end)
feed_list.append(gt_action)
fileid = fluid.layers.data(
name='fileid', shape=fileid_shape, dtype='int64')
feed_list.append(fileid)
elif self.mode == 'infer':
# only image feature input when inference
pass
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
self.py_reader = fluid.io.PyReader(
feed_list=feed_list, capacity=8, iterable=True)
self.feat_input = [feat]
self.gt_start = gt_start
self.gt_end = gt_end
self.gt_action = gt_action
self.fileid = fileid
def create_model_args(self):
cfg = {}
cfg['tscale'] = self.tscale
cfg['feat_dim'] = self.feat_dim
cfg['hidden_dim'] = self.hidden_dim
return cfg
def build_model(self):
cfg = self.create_model_args()
self.videomodel = bsn_net.BsnTemNet(cfg=cfg)
preds = self.videomodel.net(input=self.feat_input[0])
self.network_outputs = [preds]
def optimizer(self):
bd = [self.lr_decay_iter]
base_lr = self.base_learning_rate
lr_decay = self.learning_rate_decay
lr = [base_lr, base_lr * lr_decay]
l2_weight_decay = self.l2_weight_decay
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
self.loss_ = self.videomodel.loss_func(
self.network_outputs[0], self.gt_start, self.gt_end, self.gt_action)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
if (self.mode == 'train') or (self.mode == 'valid'):
return self.feat_input + [
self.gt_start, self.gt_end, self.gt_action
]
elif self.mode == 'test':
return self.feat_input + [
self.gt_start, self.gt_end, self.gt_action, self.fileid
]
elif self.mode == 'infer':
return self.feat_input
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
def fetches(self):
if (self.mode == 'train') or (self.mode == 'valid'):
losses = self.loss()
fetch_list = [item for item in losses]
elif self.mode == 'test':
losses = self.loss()
preds = self.outputs()
fetch_list = [item for item in losses] + \
[item for item in preds] + \
[self.fileid]
elif self.mode == 'infer':
preds = self.outputs()
fetch_list = [item for item in preds]
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
return fetch_list
def pretrain_info(self):
return (None, None)
def weights_info(self):
return (
'BsnTem_final.pdparams',
'https://paddlemodels.bj.bcebos.com/video_detection/BsnTem_final.pdparams'
)
class BsnPem(ModelBase):
"""BsnPem model"""
def __init__(self, name, cfg, mode='train'):
super(BsnPem, self).__init__(name, cfg, mode=mode)
self.mode = mode
self.get_config()
def get_config(self):
self.feat_dim = self.get_config_from_sec('MODEL', 'feat_dim')
self.hidden_dim = self.get_config_from_sec('MODEL', 'hidden_dim')
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size')
self.num_gpus = self.get_config_from_sec(self.mode, 'num_gpus')
self.top_K = self.get_config_from_sec(self.mode, 'top_K')
self.num_epochs = self.get_config_from_sec('train', 'epoch')
self.base_learning_rate = self.get_config_from_sec('train',
'learning_rate')
self.learning_rate_decay = self.get_config_from_sec(
'train', 'learning_rate_decay')
self.l2_weight_decay = self.get_config_from_sec('train',
'l2_weight_decay')
self.lr_decay_iter = self.get_config_from_sec('train', 'lr_decay_iter')
def build_input(self, use_pyreader=True):
feat_shape = [self.top_K, self.feat_dim]
gt_iou_shape = [self.top_K, 1]
props_info_shape = [self.top_K, 4]
fileid_shape = [1]
self.use_pyreader = use_pyreader
# set init data to None
py_reader = None
feat = None
gt_iou = None
props_info = None
fileid = None
feat = fluid.layers.data(name='feat', shape=feat_shape, dtype='float32')
feed_list = []
feed_list.append(feat)
if (self.mode == 'train') or (self.mode == 'valid'):
gt_iou = fluid.layers.data(
name='gt_iou', shape=gt_iou_shape, dtype='float32')
feed_list.append(gt_iou)
elif self.mode == 'test':
gt_iou = fluid.layers.data(
name='gt_iou', shape=gt_iou_shape, dtype='float32')
props_info = fluid.layers.data(
name='props_info', shape=props_info_shape, dtype='float32')
feed_list.append(gt_iou)
feed_list.append(props_info)
fileid = fluid.layers.data(
name='fileid', shape=fileid_shape, dtype='int64')
feed_list.append(fileid)
elif self.mode == 'infer':
props_info = fluid.layers.data(
name='props_info', shape=props_info_shape, dtype='float32')
feed_list.append(props_info)
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if use_pyreader:
assert self.mode != 'infer', \
'pyreader is not recommendated when infer, please set use_pyreader to be false.'
self.py_reader = fluid.io.PyReader(
feed_list=feed_list, capacity=4, iterable=True)
self.feat_input = [feat]
self.gt_iou = gt_iou
self.props_info = props_info
self.fileid = fileid
def create_model_args(self):
cfg = {}
cfg['feat_dim'] = self.feat_dim
cfg['hidden_dim'] = self.hidden_dim
cfg['batch_size'] = self.batch_size
cfg['top_K'] = self.top_K
cfg["num_gpus"] = self.num_gpus
return cfg
def build_model(self):
cfg = self.create_model_args()
self.videomodel = bsn_net.BsnPemNet(cfg=cfg)
preds = self.videomodel.net(input=self.feat_input[0])
self.network_outputs = [preds]
def optimizer(self):
bd = [self.lr_decay_iter]
base_lr = self.base_learning_rate
lr_decay = self.learning_rate_decay
lr = [base_lr, base_lr * lr_decay]
l2_weight_decay = self.l2_weight_decay
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
self.loss_ = self.videomodel.loss_func(self.network_outputs[0],
self.gt_iou)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
if (self.mode == 'train') or (self.mode == 'valid'):
return self.feat_input + [self.gt_iou]
elif self.mode == 'test':
return self.feat_input + [self.gt_iou, self.props_info, self.fileid]
elif self.mode == 'infer':
return self.feat_input + [self.props_info]
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
def fetches(self):
if (self.mode == 'train') or (self.mode == 'valid'):
losses = self.loss()
fetch_list = [item for item in losses]
elif self.mode == 'test':
losses = self.loss()
preds = self.outputs()
fetch_list = [item for item in losses] + \
[item for item in preds] + \
[self.props_info, self.fileid]
elif self.mode == 'infer':
preds = self.outputs()
fetch_list = [item for item in preds] + \
[self.props_info]
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
return fetch_list
def pretrain_info(self):
return (None, None)
def weights_info(self):
return (
'BsnPem_final.pdparams',
'https://paddlemodels.bj.bcebos.com/video_detection/BsnPem_final.pdparams'
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
#from .ctcn_utils import get_ctcn_conv_initializer as get_init
import math
DATATYPE = 'float32'
class BsnTemNet(object):
def __init__(self, cfg):
self.tscale = cfg["tscale"]
self.feat_dim = cfg["feat_dim"]
self.hidden_dim = cfg["hidden_dim"]
def conv1d(self,
input,
num_k=256,
input_size=256,
size_k=3,
padding=1,
act='relu',
name="conv1d"):
fan_in = input_size * size_k * 1
k = 1. / math.sqrt(fan_in)
param_attr = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-k, high=k))
bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-k, high=k))
input = fluid.layers.unsqueeze(input=input, axes=[2])
conv = fluid.layers.conv2d(
input=input,
num_filters=num_k,
filter_size=(1, size_k),
stride=1,
padding=(0, padding),
act=act,
name=name,
param_attr=param_attr,
bias_attr=bias_attr)
conv = fluid.layers.squeeze(input=conv, axes=[2])
return conv
def net(self, input):
x_1d = self.conv1d(
input,
input_size=self.feat_dim,
num_k=self.hidden_dim,
size_k=3,
padding=1,
act="relu",
name="Base_1")
x_1d = self.conv1d(
x_1d,
input_size=self.hidden_dim,
num_k=self.hidden_dim,
size_k=3,
padding=1,
act="relu",
name="Base_2")
x_1d = self.conv1d(
x_1d,
input_size=self.hidden_dim,
num_k=3,
size_k=1,
padding=0,
act="sigmoid",
name="Pred")
return x_1d
def loss_func(self, preds, gt_start, gt_end, gt_action):
pred_start = fluid.layers.squeeze(
fluid.layers.slice(
preds, axes=[1], starts=[0], ends=[1]), axes=[1])
pred_end = fluid.layers.squeeze(
fluid.layers.slice(
preds, axes=[1], starts=[1], ends=[2]), axes=[1])
pred_action = fluid.layers.squeeze(
fluid.layers.slice(
preds, axes=[1], starts=[2], ends=[3]), axes=[1])
def bi_loss(pred_score, gt_label):
pred_score = fluid.layers.reshape(
x=pred_score, shape=[-1], inplace=True)
gt_label = fluid.layers.reshape(
x=gt_label, shape=[-1], inplace=False)
gt_label.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype='float32')
num_entries = fluid.layers.cast(
fluid.layers.shape(pmask), dtype='float32')
num_positive = fluid.layers.cast(
fluid.layers.reduce_sum(pmask), dtype='float32')
ratio = num_entries / num_positive
coef_0 = 0.5 * num_entries / (num_entries - num_positive + 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
loss = -1 * (loss_pos + loss_neg)
return loss
loss_start = bi_loss(pred_start, gt_start)
loss_end = bi_loss(pred_end, gt_end)
loss_action = bi_loss(pred_action, gt_action)
loss = loss_start + loss_end + loss_action
return loss, loss_start, loss_end, loss_action
class BsnPemNet(object):
def __init__(self, cfg):
self.feat_dim = cfg["feat_dim"]
self.hidden_dim = cfg["hidden_dim"]
self.batch_size = cfg["batch_size"]
self.top_K = cfg["top_K"]
self.num_gpus = cfg["num_gpus"]
self.mini_batch = self.batch_size // self.num_gpus
def net(self, input):
input = fluid.layers.reshape(input, shape=[-1, self.feat_dim])
x = fluid.layers.fc(input=input, size=self.hidden_dim)
x = fluid.layers.relu(0.1 * x)
x = fluid.layers.fc(input=input, size=1)
x = fluid.layers.sigmoid(0.1 * x)
return x
def loss_func(self, pred_score, gt_iou):
gt_iou = fluid.layers.reshape(gt_iou, shape=[-1, 1])
u_hmask = fluid.layers.cast(x=gt_iou > 0.6, dtype=DATATYPE)
u_mmask = fluid.layers.logical_and(gt_iou <= 0.6, gt_iou > 0.2)
u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
u_lmask = fluid.layers.logical_and(gt_iou <= 0.2, gt_iou >= 0.)
u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
num_h = fluid.layers.cast(
fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
num_m = fluid.layers.cast(
fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
num_l = fluid.layers.cast(
fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
r_m = num_h / num_m
u_smmask = fluid.layers.uniform_random(
shape=[self.mini_batch * self.top_K, 1],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
r_l = 2 * num_h / num_l
u_slmask = fluid.layers.uniform_random(
shape=[self.mini_batch * self.top_K, 1],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
weights = u_hmask + u_smmask + u_slmask
weights.stop_gradient = True
loss = fluid.layers.square_error_cost(pred_score, gt_iou)
loss = fluid.layers.elementwise_mul(loss, weights)
loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
weights)
return [loss]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
from paddle.fluid.initializer import Uniform
from scipy.interpolate import interp1d
import pandas as pd
import multiprocessing as mp
import json
import pandas
import numpy
import random
import os
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute jaccard score between a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
union_len = len_anchors - inter_len + box_max - box_min
#print inter_len,union_len
jaccard = np.divide(inter_len, union_len)
return jaccard
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute intersection between score a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
scores = np.divide(inter_len, len_anchors)
return scores
def boundary_choose(score_list, peak_thres):
max_score = max(score_list)
mask_high = (score_list > max_score * peak_thres)
score_list = list(score_list)
score_middle = np.array([0.0] + score_list + [0.0])
score_front = np.array([0.0, 0.0] + score_list)
score_back = np.array(score_list + [0.0, 0.0])
mask_peak = ((score_middle > score_front) & (score_middle > score_back))
mask_peak = mask_peak[1:-1]
mask = (mask_high | mask_peak).astype('float32')
return mask
def soft_nms(df, alpha, t1, t2):
'''
df: proposals generated by network;
alpha: alpha value of Gaussian decaying function;
t1, t2: threshold for soft nms.
'''
df = df.sort_values(by="score", ascending=False)
tstart = list(df.xmin.values[:])
tend = list(df.xmax.values[:])
tscore = list(df.score.values[:])
rstart = []
rend = []
rscore = []
while len(tscore) > 1 and len(rscore) < 101:
max_index = tscore.index(max(tscore))
tmp_iou_list = iou_with_anchors(
np.array(tstart),
np.array(tend), tstart[max_index], tend[max_index])
for idx in range(0, len(tscore)):
if idx != max_index:
tmp_iou = tmp_iou_list[idx]
tmp_width = tend[max_index] - tstart[max_index]
if tmp_iou > t1 + (t2 - t1) * tmp_width:
tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
alpha)
rstart.append(tstart[max_index])
rend.append(tend[max_index])
rscore.append(tscore[max_index])
tstart.pop(max_index)
tend.pop(max_index)
tscore.pop(max_index)
newDf = pd.DataFrame()
newDf['score'] = rscore
newDf['xmin'] = rstart
newDf['xmax'] = rend
return newDf
def video_process(video_list,
video_dict,
output_path_pem,
snms_alpha=0.75,
snms_t1=0.65,
snms_t2=0.9):
for video_name in video_list:
df = pd.read_csv(os.path.join(output_path_pem, video_name + ".csv"))
df["score"] = df.xmin_score.values[:] * df.xmax_score.values[:] * df.iou_score.values[:]
if len(df) > 1:
df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
video_duration = video_dict[video_name]["duration_second"]
proposal_list = []
for idx in range(min(100, len(df))):
tmp_prop={"score":df.score.values[idx],\
"segment":[max(0,df.xmin.values[idx])*video_duration,\
min(1,df.xmax.values[idx])*video_duration]}
proposal_list.append(tmp_prop)
result_dict[video_name[2:]] = proposal_list
def bsn_post_processing(video_dict, subset, output_path_pem, result_path_pem):
video_list = video_dict.keys()
video_list = list(video_dict.keys())
global result_dict
result_dict = mp.Manager().dict()
pp_num = 12
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pp_num)
processes = []
for tid in range(pp_num - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=video_process,
args=(
tmp_video_list,
video_dict,
output_path_pem, ))
p.start()
processes.append(p)
tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
p = mp.Process(
target=video_process,
args=(
tmp_video_list,
video_dict,
output_path_pem, ))
p.start()
processes.append(p)
for p in processes:
p.join()
result_dict = dict(result_dict)
output_dict = {
"version": "VERSION 1.3",
"results": result_dict,
"external_data": {}
}
outfile = open(
os.path.join(result_path_pem, "bsn_results_%s.json" % subset), "w")
json.dump(output_dict, outfile)
outfile.close()
def generate_props(pgm_config, video_list, video_dict, output_path_tem,
output_path_pgm_proposal):
tscale = pgm_config["tscale"]
peak_thres = pgm_config["pgm_threshold"]
anchor_xmin = [1.0 / tscale * i for i in range(tscale)]
anchor_xmax = [1.0 / tscale * i for i in range(1, tscale + 1)]
for video_name in video_list:
video_info = video_dict[video_name]
if video_info["subset"] == "training":
top_K = pgm_config["pgm_top_K_train"]
else:
top_K = pgm_config["pgm_top_K"]
tdf = pandas.read_csv(
os.path.join(output_path_tem, video_name + ".csv"))
start_scores = tdf.start.values[:]
end_scores = tdf.end.values[:]
start_mask = boundary_choose(start_scores, peak_thres)
start_mask[0] = 1.
end_mask = boundary_choose(end_scores, peak_thres)
end_mask[-1] = 1.
score_vector_list = []
for idx in range(tscale):
for jdx in range(tscale):
start_index = jdx
end_index = start_index + idx
if end_index < tscale and start_mask[
start_index] == 1 and end_mask[end_index] == 1:
xmin = anchor_xmin[start_index]
xmax = anchor_xmax[end_index]
xmin_score = start_scores[start_index]
xmax_score = end_scores[end_index]
score_vector_list.append(
[xmin, xmax, xmin_score, xmax_score])
num_data = len(score_vector_list)
if num_data < top_K:
for idx in range(top_K - num_data):
start_index = random.randint(0, tscale - 1)
end_index = random.randint(start_index, tscale - 1)
xmin = anchor_xmin[start_index]
xmax = anchor_xmax[end_index]
xmin_score = start_scores[start_index]
xmax_score = end_scores[end_index]
score_vector_list.append([xmin, xmax, xmin_score, xmax_score])
score_vector_list = np.stack(score_vector_list)
col_name = ["xmin", "xmax", "xmin_score", "xmax_score"]
new_df = pandas.DataFrame(score_vector_list, columns=col_name)
new_df["score"] = new_df.xmin_score * new_df.xmax_score
new_df = new_df.sort_values(by="score", ascending=False)
new_df = new_df[:top_K]
video_second = video_info['duration_second']
try:
gt_xmins = []
gt_xmaxs = []
for idx in range(len(video_info["annotations"])):
gt_xmins.append(video_info["annotations"][idx]["segment"][0] /
video_second)
gt_xmaxs.append(video_info["annotations"][idx]["segment"][1] /
video_second)
new_iou_list = []
for j in range(len(gt_xmins)):
tmp_new_iou = iou_with_anchors(new_df.xmin.values[:],
new_df.xmax.values[:],
gt_xmins[j], gt_xmaxs[j])
new_iou_list.append(tmp_new_iou)
new_iou_list = numpy.stack(new_iou_list)
new_iou_list = numpy.max(new_iou_list, axis=0)
new_ioa_list = []
for j in range(len(gt_xmins)):
tmp_new_ioa = ioa_with_anchors(new_df.xmin.values[:],
new_df.xmax.values[:],
gt_xmins[j], gt_xmaxs[j])
new_ioa_list.append(tmp_new_ioa)
new_ioa_list = numpy.stack(new_ioa_list)
new_ioa_list = numpy.max(new_ioa_list, axis=0)
new_df["match_iou"] = new_iou_list
new_df["match_ioa"] = new_ioa_list
except:
pass
new_df.to_csv(
os.path.join(output_path_pgm_proposal, video_name + ".csv"),
index=False)
def generate_feats(pgm_config, video_list, video_dict, output_path_tem,
output_path_pgm_proposal, output_path_pgm_feature):
num_sample_start = pgm_config["num_sample_start"]
num_sample_end = pgm_config["num_sample_end"]
num_sample_action = pgm_config["num_sample_action"]
num_sample_perbin = pgm_config["num_sample_perbin"]
tscale = pgm_config["tscale"]
seg_xmins = [1.0 / tscale * i for i in range(tscale)]
seg_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
for video_name in video_list:
adf = pandas.read_csv(
os.path.join(output_path_tem, video_name + ".csv"))
score_action = adf.action.values[:]
video_scale = len(adf)
video_gap = seg_xmaxs[0] - seg_xmins[0]
video_extend = int(video_scale / 4 + 10)
pdf = pandas.read_csv(
os.path.join(output_path_pgm_proposal, video_name + ".csv"))
tmp_zeros = numpy.zeros([video_extend])
score_action = numpy.concatenate((tmp_zeros, score_action, tmp_zeros))
tmp_cell = video_gap
tmp_x = [-tmp_cell / 2 - (video_extend - 1 - ii) * tmp_cell for ii in range(video_extend)] + \
[tmp_cell / 2 + ii * tmp_cell for ii in range(video_scale)] + \
[tmp_cell / 2 + seg_xmaxs[-1] + ii * tmp_cell for ii in range(video_extend)]
f_action = interp1d(tmp_x, score_action, axis=0)
feature_bsp = []
for idx in range(len(pdf)):
xmin = pdf.xmin.values[idx]
xmax = pdf.xmax.values[idx]
xlen = xmax - xmin
xmin_0 = xmin - xlen * pgm_config["bsp_boundary_ratio"]
xmin_1 = xmin + xlen * pgm_config["bsp_boundary_ratio"]
xmax_0 = xmax - xlen * pgm_config["bsp_boundary_ratio"]
xmax_1 = xmax + xlen * pgm_config["bsp_boundary_ratio"]
# start
plen_start = (xmin_1 - xmin_0) / (num_sample_start - 1)
plen_sample = plen_start / num_sample_perbin
tmp_x_new = [
xmin_0 - plen_start / 2 + plen_sample * ii
for ii in range(num_sample_start * num_sample_perbin + 1)
]
tmp_y_new_start_action = f_action(tmp_x_new)
tmp_y_new_start = [
numpy.mean(tmp_y_new_start_action[ii * num_sample_perbin:(ii + 1) * num_sample_perbin + 1]) \
for ii in range(num_sample_start)]
# end
plen_end = (xmax_1 - xmax_0) / (num_sample_end - 1)
plen_sample = plen_end / num_sample_perbin
tmp_x_new = [
xmax_0 - plen_end / 2 + plen_sample * ii
for ii in range(num_sample_end * num_sample_perbin + 1)
]
tmp_y_new_end_action = f_action(tmp_x_new)
tmp_y_new_end = [
numpy.mean(tmp_y_new_end_action[ii * num_sample_perbin:(ii + 1) * num_sample_perbin + 1]) \
for ii in range(num_sample_end)]
# action
plen_action = (xmax - xmin) / (num_sample_action - 1)
plen_sample = plen_action / num_sample_perbin
tmp_x_new = [
xmin - plen_action / 2 + plen_sample * ii
for ii in range(num_sample_action * num_sample_perbin + 1)
]
tmp_y_new_action = f_action(tmp_x_new)
tmp_y_new_action = [
numpy.mean(tmp_y_new_action[ii * num_sample_perbin:(ii + 1) * num_sample_perbin + 1]) \
for ii in range(num_sample_action)]
tmp_feature = numpy.concatenate(
[tmp_y_new_action, tmp_y_new_start, tmp_y_new_end])
feature_bsp.append(tmp_feature)
feature_bsp = numpy.array(feature_bsp)
numpy.save(
os.path.join(output_path_pgm_feature, video_name), feature_bsp)
def pgm_gen_proposal(video_dict, pgm_config, output_path_tem,
output_path_pgm_proposal):
video_list = list(video_dict.keys())
video_list.sort()
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pgm_config["pgm_thread"])
processes = []
for tid in range(pgm_config["pgm_thread"] - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=generate_props,
args=(
pgm_config,
tmp_video_list,
video_dict,
output_path_tem,
output_path_pgm_proposal, ))
p.start()
processes.append(p)
tmp_video_list = video_list[(pgm_config["pgm_thread"] - 1) *
num_videos_per_thread:]
p = mp.Process(
target=generate_props,
args=(
pgm_config,
tmp_video_list,
video_dict,
output_path_tem,
output_path_pgm_proposal, ))
p.start()
processes.append(p)
for p in processes:
p.join()
def pgm_gen_feature(video_dict, pgm_config, output_path_tem,
output_path_pgm_proposal, output_path_pgm_feature):
video_list = list(video_dict.keys())
video_list.sort()
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pgm_config["pgm_thread"])
processes = []
for tid in range(pgm_config["pgm_thread"] - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=generate_feats,
args=(
pgm_config,
tmp_video_list,
video_dict,
output_path_tem,
output_path_pgm_proposal,
output_path_pgm_feature, ))
p.start()
processes.append(p)
tmp_video_list = video_list[(pgm_config["pgm_thread"] - 1) *
num_videos_per_thread:]
p = mp.Process(
target=generate_feats,
args=(
pgm_config,
tmp_video_list,
video_dict,
output_path_tem,
output_path_pgm_proposal,
output_path_pgm_feature, ))
p.start()
processes.append(p)
for p in processes:
p.join()
......@@ -3,6 +3,9 @@ from .feature_reader import FeatureReader
from .kinetics_reader import KineticsReader
from .nonlocal_reader import NonlocalReader
from .ctcn_reader import CTCNReader
from .bmn_reader import BMNReader
from .bsn_reader import BSNVideoReader
from .bsn_reader import BSNProposalReader
# regist reader, sort by alphabet
regist_reader("ATTENTIONCLUSTER", FeatureReader)
......@@ -13,3 +16,6 @@ regist_reader("TSM", KineticsReader)
regist_reader("TSN", KineticsReader)
regist_reader("STNET", KineticsReader)
regist_reader("CTCN", CTCNReader)
regist_reader("BMN", BMNReader)
regist_reader("BSNTEM", BSNVideoReader)
regist_reader("BSNPEM", BSNProposalReader)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import numpy as np
import multiprocessing
import json
import logging
logger = logging.getLogger(__name__)
from .reader_utils import DataReader
from models.bmn.bmn_utils import iou_with_anchors, ioa_with_anchors
class BMNReader(DataReader):
"""
Data reader for BMN model, which was stored as features extracted by prior networks
dataset cfg: anno_file, annotation file path,
feat_path, feature path,
tscale, temporal length of BM map,
dscale, duration scale of BM map,
anchor_xmin, anchor_xmax, the range of each point in the feature sequence,
batch_size, batch size of input data,
num_threads, number of threads of data processing
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.tscale = cfg.MODEL.tscale # 100
self.dscale = cfg.MODEL.dscale # 100
self.anno_file = cfg.MODEL.anno_file
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
self.tgap = 1. / self.tscale
self.feat_path = cfg.MODEL.feat_path
self.get_dataset_dict()
self.get_match_map()
self.batch_size = cfg[mode.upper()]['batch_size']
self.num_threads = cfg[mode.upper()]['num_threads']
if (mode == 'test') or (mode == 'infer'):
self.num_threads = 1 # set num_threads as 1 for test and infer
def get_dataset_dict(self):
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
def get_match_map(self):
match_map = []
for idx in range(self.tscale):
tmp_match_window = []
xmin = self.tgap * idx
for jdx in range(1, self.tscale + 1):
xmax = xmin + self.tgap * jdx
tmp_match_window.append([xmin, xmax])
match_map.append(tmp_match_window)
match_map = np.array(match_map)
match_map = np.transpose(match_map, [1, 0, 2])
match_map = np.reshape(match_map, [-1, 2])
self.match_map = match_map
self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
def get_video_label(self, video_name):
video_info = self.video_dict[video_name]
video_second = video_info['duration_second']
video_labels = video_info['annotations']
gt_bbox = []
gt_iou_map = []
for gt in video_labels:
tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
gt_bbox.append([tmp_start, tmp_end])
tmp_gt_iou_map = iou_with_anchors(
self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
[self.dscale, self.tscale])
gt_iou_map.append(tmp_gt_iou_map)
gt_iou_map = np.array(gt_iou_map)
gt_iou_map = np.max(gt_iou_map, axis=0)
gt_bbox = np.array(gt_bbox)
gt_xmins = gt_bbox[:, 0]
gt_xmaxs = gt_bbox[:, 1]
gt_lens = gt_xmaxs - gt_xmins
gt_len_small = 3 * self.tgap
# gt_len_small=np.maximum(temporal_gap,boundary_ratio*gt_lens)
gt_start_bboxs = np.stack(
(gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
gt_end_bboxs = np.stack(
(gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
match_score_start = []
for jdx in range(len(self.anchor_xmin)):
match_score_start.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
match_score_end = []
for jdx in range(len(self.anchor_xmin)):
match_score_end.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
gt_start = np.array(match_score_start)
gt_end = np.array(match_score_end)
return gt_iou_map, gt_start, gt_end
def load_file(self, video_name):
file_name = video_name + ".npy"
file_path = os.path.join(self.feat_path, file_name)
video_feat = np.load(file_path)
video_feat = video_feat.T
video_feat = video_feat.astype("float32")
return video_feat
def create_reader(self):
"""reader creator for ctcn model"""
if self.mode == 'infer':
return self.make_infer_reader()
if self.num_threads == 1:
return self.make_reader()
else:
return self.make_multiprocess_reader()
def make_infer_reader(self):
"""reader for inference"""
def reader():
batch_out = []
for video_name in self.video_list:
video_idx = self.video_list.index(video_name)
video_feat = self.load_file(video_name)
batch_out.append((video_feat, video_idx))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_reader(self):
"""single process reader"""
def reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
video_feat = self.load_file(video_name)
gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((video_feat, gt_iou_map, gt_start, gt_end))
elif self.mode == 'test':
batch_out.append(
(video_feat, gt_iou_map, gt_start, gt_end, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_multiprocess_reader(self):
"""multiprocess reader"""
def read_into_queue(video_list, queue):
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
video_feat = self.load_file(video_name)
gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((video_feat, gt_iou_map, gt_start, gt_end))
elif self.mode == 'test':
batch_out.append(
(video_feat, gt_iou_map, gt_start, gt_end, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
n = self.num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(video_list) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = video_list[i * file_num:(i + 1) * file_num]
else:
tmp_list = video_list[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import numpy as np
import pandas as pd
import multiprocessing
import json
import logging
logger = logging.getLogger(__name__)
from .reader_utils import DataReader
from models.bsn.bsn_utils import iou_with_anchors, ioa_with_anchors
class BSNVideoReader(DataReader):
"""
Data reader for BsnTem model, which was stored as features extracted by prior networks
dataset cfg: anno_file, annotation file path,
feat_path, feature path,
file_list, file list for infer,
tscale, temporal length of input,
anchor_xmin, anchor_xmax, the range of each point in the feature sequence,
batch_size, batch size of input data,
num_threads, number of threads of data processing
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.tscale = cfg.MODEL.tscale # 100
self.gt_boundary_ratio = cfg.MODEL.gt_boundary_ratio
self.anno_file = cfg.MODEL.anno_file
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
self.tgap = 1. / self.tscale
self.feat_path = cfg.MODEL.feat_path
self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
self.get_dataset_dict()
self.batch_size = cfg[mode.upper()]['batch_size']
self.num_threads = cfg[mode.upper()]['num_threads']
if (mode == 'test') or (mode == 'infer'):
self.num_threads = 1 # set num_threads as 1 for test and infer
def get_dataset_dict(self):
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset == "train_val":
if "train" in video_subset or "validation" in video_subset:
self.video_dict[video_name] = annos[video_name]
else:
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
def get_video_label(self, video_name):
video_info = self.video_dict[video_name]
video_second = video_info['duration_second']
video_labels = video_info['annotations']
gt_bbox = []
for gt in video_labels:
tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
gt_bbox.append([tmp_start, tmp_end])
gt_bbox = np.array(gt_bbox)
gt_xmins = gt_bbox[:, 0]
gt_xmaxs = gt_bbox[:, 1]
gt_lens = gt_xmaxs - gt_xmins
gt_len_small = np.maximum(self.tgap, self.gt_boundary_ratio * gt_lens)
gt_start_bboxs = np.stack(
(gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
gt_end_bboxs = np.stack(
(gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
match_score_action = []
for jdx in range(len(self.anchor_xmin)):
match_score_action.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_xmins, gt_xmaxs)))
match_score_start = []
for jdx in range(len(self.anchor_xmin)):
match_score_start.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
match_score_end = []
for jdx in range(len(self.anchor_xmin)):
match_score_end.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
gt_start = np.array(match_score_start)
gt_end = np.array(match_score_end)
gt_action = np.array(match_score_action)
return gt_start, gt_end, gt_action
def load_file(self, video_name):
video_feat = np.load(self.feat_path + "/" + video_name + ".npy")
video_feat = video_feat.T
video_feat = video_feat.astype("float32")
return video_feat
def create_reader(self):
"""reader creator for ctcn model"""
if self.mode == 'infer':
return self.make_infer_reader()
if self.num_threads == 1:
return self.make_reader()
else:
return self.make_multiprocess_reader()
def make_infer_reader(self):
"""reader for inference"""
def reader():
batch_out = []
for video_name in self.video_list:
video_idx = self.video_list.index(video_name)
video_feat = self.load_file(video_name)
batch_out.append((video_feat, video_idx))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_reader(self):
"""single process reader"""
def reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
video_feat = self.load_file(video_name)
gt_start, gt_end, gt_action = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((video_feat, gt_start, gt_end, gt_action))
elif self.mode == 'test':
batch_out.append(
(video_feat, gt_start, gt_end, gt_action, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_multiprocess_reader(self):
"""multiprocess reader"""
def read_into_queue(video_list, queue):
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
video_feat = self.load_file(video_name)
gt_start, gt_end, gt_action = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((video_feat, gt_start, gt_end, gt_action))
elif self.mode == 'test':
batch_out.append(
(video_feat, gt_start, gt_end, gt_action, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
n = self.num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(video_list) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = video_list[i * file_num:(i + 1) * file_num]
else:
tmp_list = video_list[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
class BSNProposalReader(DataReader):
"""
Data reader for BsnPem model, which was stored as features extracted by prior networks
dataset cfg: anno_file, annotation file path,
file_list, file list for infer,
top_K, number of proposals during training/test,
feat_path, feature path generated by PGM,
prop_path, proposal path generated by PGM,
batch_size, batch size of input data,
num_threads, number of threads of data processing.
"""
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
self.top_K = cfg[mode.upper()]['top_K']
self.anno_file = cfg.MODEL.anno_file
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
if mode == 'infer':
self.feat_path = cfg[mode.upper()]['feat_path']
self.prop_path = cfg[mode.upper()]['prop_path']
else:
self.feat_path = cfg.MODEL.feat_path
self.prop_path = cfg.MODEL.prop_path
self.get_dataset_dict()
self.batch_size = cfg[mode.upper()]['batch_size']
self.num_threads = cfg[mode.upper()]['num_threads']
if (mode == 'test') or (mode == 'infer'):
self.num_threads = 1 # set num_threads as 1 for test and infer
def get_dataset_dict(self):
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
def get_props(self, video_name):
pdf = pd.read_csv(self.prop_path + video_name + ".csv")
pdf = pdf[:self.top_K]
props_start = pdf.xmin.values[:]
props_end = pdf.xmax.values[:]
props_start_score = pdf.xmin_score.values[:]
props_end_score = pdf.xmax_score.values[:]
props_info = np.stack(
[props_start, props_end, props_start_score, props_end_score])
if self.mode == "infer":
return props_info
else:
props_iou = pdf.match_iou.values[:]
return props_iou, props_info
def load_file(self, video_name):
video_feat = np.load(self.feat_path + video_name + ".npy")
video_feat = video_feat[:self.top_K, :]
video_feat = video_feat.astype("float32")
return video_feat
def create_reader(self):
"""reader creator for ctcn model"""
if self.mode == 'infer':
return self.make_infer_reader()
if self.num_threads == 1:
return self.make_reader()
else:
return self.make_multiprocess_reader()
def make_infer_reader(self):
"""reader for inference"""
def reader():
batch_out = []
for video_name in self.video_list:
video_idx = self.video_list.index(video_name)
props_feat = self.load_file(video_name)
props_info = self.get_props(video_name)
batch_out.append((props_feat, props_info, video_idx))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_reader(self):
"""single process reader"""
def reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
props_feat = self.load_file(video_name)
props_iou, props_info = self.get_props(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((props_feat, props_iou))
elif self.mode == 'test':
batch_out.append(
(props_feat, props_iou, props_info, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return reader
def make_multiprocess_reader(self):
"""multiprocess reader"""
def read_into_queue(video_list, queue):
batch_out = []
for video_name in video_list:
video_idx = video_list.index(video_name)
props_feat = self.load_file(video_name)
props_iou, props_info = self.get_props(video_name)
if self.mode == 'train' or self.mode == 'valid':
batch_out.append((props_feat, props_iou))
elif self.mode == 'test':
batch_out.append(
(props_feat, props_iou, props_info, video_idx))
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
if len(batch_out) == self.batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
video_list = self.video_list
if self.mode == 'train':
random.shuffle(video_list)
n = self.num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(video_list) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = video_list[i * file_num:(i + 1) * file_num]
else:
tmp_list = video_list[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
......@@ -24,6 +24,7 @@ weights="" #set the path of weights to enable eval and predicut, just ignore thi
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#export CUDA_VISIBLE_DEVICES=0,1,2,3
#export CUDA_VISIBLE_DEVICES=0
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
......
......@@ -124,6 +124,8 @@ PaddlePaddle 提供了丰富的计算单元,使得用户可以采用模块化
| [Attention Cluster](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo) | CVPR'18提出的视频多模态特征注意力聚簇融合方法 | Youtube-8M | GAP = 84% |
| [NeXtVlad](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo) | 2nd-Youtube-8M比赛第3名的模型 | Youtube-8M | GAP = 87% |
| [C-TCN](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo) | 2018年ActivityNet夺冠方案 | ActivityNet1.3 | MAP=31% |
| [BSN](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo) | 为视频动作定位问题提供高效的proposal生成方法 | ActivityNet1.3 | AUC=66.64% |
| [BMN](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo) | 2019年ActivityNet夺冠方案 | ActivityNet1.3 | AUC=67.19% |
## PaddleNLP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册