diff --git a/bmn/BMN.png b/bmn/BMN.png new file mode 100644 index 0000000000000000000000000000000000000000..46b9743cf5b5e79ada93ebb6b61da96967f350a4 Binary files /dev/null and b/bmn/BMN.png differ diff --git a/bmn/README.md b/bmn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dbf593512c07da9645418b92b2a6819c9be13fa5 --- /dev/null +++ b/bmn/README.md @@ -0,0 +1,120 @@ +# BMN 视频动作定位模型高层API实现 + +--- +## 内容 + +- [模型简介](#模型简介) +- [代码结构](#代码结构) +- [数据准备](#数据准备) +- [模型训练](#模型训练) +- [模型评估](#模型评估) +- [模型推断](#模型推断) +- [参考论文](#参考论文) + + +## 模型简介 + +BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。 + +

+
+BMN Overview +

+ + +## 代码结构 +``` +├── bmn.yaml # 网络配置文件,快速配置参数 +├── run.sh # 快速运行脚本,可直接开始多卡训练 +├── train.py # 训练代码,训练网络 +├── eval.py # 评估代码,评估网络性能 +├── predict.py # 预测代码,针对任意输入预测结果 +├── bmn_model.py # 网络结构与损失函数定义 +├── bmn_metric.py # 精度评估方法定义 +├── reader.py # 数据reader,构造Dataset和Dataloader +├── bmn_utils.py # 模型细节相关代码 +├── config_utils.py # 配置细节相关代码 +├── eval_anet_prop.py # 计算精度评估指标 +└── infer.list # 推断文件列表 +``` + + +## 数据准备 + +BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理好的视频特征,请下载[bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz)数据后解压,同时相应的修改bmn.yaml中的特征路径feat\_path。对应的标签文件请下载[label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json)并修改bmn.yaml中的标签文件路径anno\_file。 + + +## 模型训练 + +数据准备完成后,可通过如下两种方式启动训练: + +默认使用4卡训练,启动方式如下: + + bash run.sh + +若使用单卡训练,启动方式如下: + + export CUDA_VISIBLE_DEVICES=0 + python train.py + +- 代码运行需要先安装pandas + +- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型 + +- 单卡训练时,请将配置文件中的batch_size调整为16 + +**训练策略:** + +* 采用Adam优化器,初始learning\_rate=0.001 +* 权重衰减系数为1e-4 +* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1 + + +## 模型评估 + +训练完成后,可通过如下方式进行模型评估: + + python eval.py --weights=$PATH_TO_WEIGHTS + +- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如果不设置,将使用默认参数文件checkpoint/final.pdparams。 + +- 上述程序会将运行结果保存在output/EVAL/BMN\_results文件夹下,测试结果保存在evaluate\_results/bmn\_results\_validation.json文件中。 + +- 注:评估时可能会出现loss为nan的情况。这是由于评估时用的是单个样本,可能存在没有iou>0.6的样本,所以为nan,对最终的评估结果没有影响。 + + +使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程如下: + +- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org) + +- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至models/dygraph/bmn目录下。(注:由于第三方评估代码不支持python3,此处建议使用python2进行评估;若使用python3,print函数需要添加括号,请对Evaluation目录下的.py文件做相应修改。) + +- 请下载[activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json)文件,并将其放置在models/dygraph/bmn/Evaluation/data目录下,相较于原始的activity\_net.v1-3.min.json文件,我们过滤了其中一些失效的视频条目。 + +- 计算精度指标 + + ```python eval_anet_prop.py``` + + +在ActivityNet1.3数据集下评估精度如下: + +| AR@1 | AR@5 | AR@10 | AR@100 | AUC | +| :---: | :---: | :---: | :---: | :---: | +| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% | + + +## 模型推断 + +可通过如下方式启动模型推断: + + python predict.py --weights=$PATH_TO_WEIGHTS \ + --filelist=$FILELIST + +- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,如果不设置,将使用默认参数文件checkpoint/final.pdparams。 + +- 上述程序会将运行结果保存在output/INFER/BMN\_results文件夹下,测试结果保存在predict\_results/bmn\_results\_test.json文件中。 + + +## 参考论文 + +- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen. diff --git a/bmn/bmn.yaml b/bmn/bmn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da50ea4f7c654d40fbf2498863cb7e87664fe55a --- /dev/null +++ b/bmn/bmn.yaml @@ -0,0 +1,50 @@ +MODEL: + name: "BMN" + tscale: 100 + dscale: 100 + feat_dim: 400 + prop_boundary_ratio: 0.5 + num_sample: 32 + num_sample_perbin: 3 + anno_file: "./activitynet_1.3_annotations.json" + feat_path: './fix_feat_100' + +TRAIN: + subset: "train" + epoch: 9 + batch_size: 4 + num_workers: 4 + use_shuffle: True + device: "gpu" + num_gpus: 4 + learning_rate: 0.001 + learning_rate_decay: 0.1 + lr_decay_iter: 4200 + l2_weight_decay: 1e-4 + +VALID: + subset: "validation" + +TEST: + subset: "validation" + batch_size: 1 + num_workers: 1 + use_buffer: False + snms_alpha: 0.001 + snms_t1: 0.5 + snms_t2: 0.9 + output_path: "output/EVAL/BMN_results" + result_path: "evaluate_results" + +INFER: + subset: "test" + batch_size: 1 + num_workers: 1 + use_buffer: False + snms_alpha: 0.4 + snms_t1: 0.5 + snms_t2: 0.9 + filelist: './infer.list' + output_path: "output/INFER/BMN_results" + result_path: "predict_results" + diff --git a/bmn/bmn_metric.py b/bmn/bmn_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a19f87c6b42b8737ddeb52c3a330f59dcc932004 --- /dev/null +++ b/bmn/bmn_metric.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import numpy as np +import pandas as pd +import os +import sys +import json + +sys.path.append('../') + +from metrics import Metric +from bmn_utils import boundary_choose, bmn_post_processing + + +class BmnMetric(Metric): + """ + only support update with batch_size=1 + """ + + def __init__(self, cfg, mode): + super(BmnMetric, self).__init__() + self.cfg = cfg + self.mode = mode + #get video_dict and video_list + if self.mode == 'test': + self.get_test_dataset_dict() + elif self.mode == 'infer': + self.get_infer_dataset_dict() + + def add_metric_op(self, preds, label): + pred_bm, pred_start, pred_en = preds + video_index = label[-1] + return [pred_bm, pred_start, pred_en, video_index] #return list + + def update(self, pred_bm, pred_start, pred_end, fid): + # generate proposals + pred_start = pred_start[0] + pred_end = pred_end[0] + fid = fid[0] + + if self.mode == 'infer': + output_path = self.cfg.INFER.output_path + else: + output_path = self.cfg.TEST.output_path + tscale = self.cfg.MODEL.tscale + dscale = self.cfg.MODEL.dscale + snippet_xmins = [1.0 / tscale * i for i in range(tscale)] + snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)] + cols = ["xmin", "xmax", "score"] + + video_name = self.video_list[fid] + pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :] + start_mask = boundary_choose(pred_start) + start_mask[0] = 1. + end_mask = boundary_choose(pred_end) + end_mask[-1] = 1. + score_vector_list = [] + for idx in range(dscale): + for jdx in range(tscale): + start_index = jdx + end_index = start_index + idx + if end_index < tscale and start_mask[ + start_index] == 1 and end_mask[end_index] == 1: + xmin = snippet_xmins[start_index] + xmax = snippet_xmaxs[end_index] + xmin_score = pred_start[start_index] + xmax_score = pred_end[end_index] + bm_score = pred_bm[idx, jdx] + conf_score = xmin_score * xmax_score * bm_score + score_vector_list.append([xmin, xmax, conf_score]) + + score_vector_list = np.stack(score_vector_list) + video_df = pd.DataFrame(score_vector_list, columns=cols) + video_df.to_csv( + os.path.join(output_path, "%s.csv" % video_name), index=False) + + return 0 # result has saved in output path + + def accumulate(self): + return 'post_processing is required...' # required method + + def reset(self): + print("Post_processing....This may take a while") + if self.mode == 'test': + bmn_post_processing(self.video_dict, self.cfg.TEST.subset, + self.cfg.TEST.output_path, + self.cfg.TEST.result_path) + elif self.mode == 'infer': + bmn_post_processing(self.video_dict, self.cfg.INFER.subset, + self.cfg.INFER.output_path, + self.cfg.INFER.result_path) + + def name(self): + return 'bmn_metric' + + def get_test_dataset_dict(self): + anno_file = self.cfg.MODEL.anno_file + annos = json.load(open(anno_file)) + subset = self.cfg.TEST.subset + self.video_dict = {} + for video_name in annos.keys(): + video_subset = annos[video_name]["subset"] + if subset in video_subset: + self.video_dict[video_name] = annos[video_name] + self.video_list = list(self.video_dict.keys()) + self.video_list.sort() + + def get_infer_dataset_dict(self): + file_list = self.cfg.INFER.filelist + annos = json.load(open(file_list)) + self.video_dict = {} + for video_name in annos.keys(): + self.video_dict[video_name] = annos[video_name] + self.video_list = list(self.video_dict.keys()) + self.video_list.sort() diff --git a/bmn/bmn_model.py b/bmn/bmn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dfde7bcc5cdaac8aa5ea3c069f580308b49ec01f --- /dev/null +++ b/bmn/bmn_model.py @@ -0,0 +1,364 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np +import math + +from bmn_utils import get_interp1d_mask +from model import Model, Loss + +DATATYPE = 'float32' + + +# Net +class Conv1D(fluid.dygraph.Layer): + def __init__(self, + prefix, + num_channels=256, + num_filters=256, + size_k=3, + padding=1, + groups=1, + act="relu"): + super(Conv1D, self).__init__() + fan_in = num_channels * size_k * 1 + k = 1. / math.sqrt(fan_in) + param_attr = ParamAttr( + name=prefix + "_w", + initializer=fluid.initializer.Uniform( + low=-k, high=k)) + bias_attr = ParamAttr( + name=prefix + "_b", + initializer=fluid.initializer.Uniform( + low=-k, high=k)) + + self._conv2d = fluid.dygraph.Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=(1, size_k), + stride=1, + padding=(0, padding), + groups=groups, + act=act, + param_attr=param_attr, + bias_attr=bias_attr) + + def forward(self, x): + x = fluid.layers.unsqueeze(input=x, axes=[2]) + x = self._conv2d(x) + x = fluid.layers.squeeze(input=x, axes=[2]) + return x + + +class BMN(Model): + def __init__(self, cfg, is_dygraph=True): + super(BMN, self).__init__() + + #init config + self.tscale = cfg.MODEL.tscale + self.dscale = cfg.MODEL.dscale + self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio + self.num_sample = cfg.MODEL.num_sample + self.num_sample_perbin = cfg.MODEL.num_sample_perbin + self.is_dygraph = is_dygraph + + self.hidden_dim_1d = 256 + self.hidden_dim_2d = 128 + self.hidden_dim_3d = 512 + + # Base Module + self.b_conv1 = Conv1D( + prefix="Base_1", + num_channels=400, + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.b_conv2 = Conv1D( + prefix="Base_2", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + + # Temporal Evaluation Module + self.ts_conv1 = Conv1D( + prefix="TEM_s1", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.ts_conv2 = Conv1D( + prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid") + self.te_conv1 = Conv1D( + prefix="TEM_e1", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.te_conv2 = Conv1D( + prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid") + + #Proposal Evaluation Module + self.p_conv1 = Conv1D( + prefix="PEM_1d", + num_filters=self.hidden_dim_2d, + size_k=3, + padding=1, + act="relu") + + # init to speed up + sample_mask_array = get_interp1d_mask( + self.tscale, self.dscale, self.prop_boundary_ratio, + self.num_sample, self.num_sample_perbin) + if self.is_dygraph: + self.sample_mask = fluid.dygraph.base.to_variable( + sample_mask_array) + else: # static + self.sample_mask = fluid.layers.create_parameter( + shape=[ + self.tscale, self.num_sample * self.dscale * self.tscale + ], + dtype=DATATYPE, + attr=fluid.ParamAttr( + name="sample_mask", trainable=False), + default_initializer=fluid.initializer.NumpyArrayInitializer( + sample_mask_array)) + + self.sample_mask.stop_gradient = True + + self.p_conv3d1 = fluid.dygraph.Conv3D( + num_channels=128, + num_filters=self.hidden_dim_3d, + filter_size=(self.num_sample, 1, 1), + stride=(self.num_sample, 1, 1), + padding=0, + act="relu", + param_attr=ParamAttr(name="PEM_3d1_w"), + bias_attr=ParamAttr(name="PEM_3d1_b")) + + self.p_conv2d1 = fluid.dygraph.Conv2D( + num_channels=512, + num_filters=self.hidden_dim_2d, + filter_size=1, + stride=1, + padding=0, + act="relu", + param_attr=ParamAttr(name="PEM_2d1_w"), + bias_attr=ParamAttr(name="PEM_2d1_b")) + self.p_conv2d2 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=self.hidden_dim_2d, + filter_size=3, + stride=1, + padding=1, + act="relu", + param_attr=ParamAttr(name="PEM_2d2_w"), + bias_attr=ParamAttr(name="PEM_2d2_b")) + self.p_conv2d3 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=self.hidden_dim_2d, + filter_size=3, + stride=1, + padding=1, + act="relu", + param_attr=ParamAttr(name="PEM_2d3_w"), + bias_attr=ParamAttr(name="PEM_2d3_b")) + self.p_conv2d4 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=2, + filter_size=1, + stride=1, + padding=0, + act="sigmoid", + param_attr=ParamAttr(name="PEM_2d4_w"), + bias_attr=ParamAttr(name="PEM_2d4_b")) + + def forward(self, x): + #Base Module + x = self.b_conv1(x) + x = self.b_conv2(x) + + #TEM + xs = self.ts_conv1(x) + xs = self.ts_conv2(xs) + xs = fluid.layers.squeeze(xs, axes=[1]) + xe = self.te_conv1(x) + xe = self.te_conv2(xe) + xe = fluid.layers.squeeze(xe, axes=[1]) + + #PEM + xp = self.p_conv1(x) + #BM layer + xp = fluid.layers.matmul(xp, self.sample_mask) + xp = fluid.layers.reshape( + xp, shape=[0, 0, -1, self.dscale, self.tscale]) + + xp = self.p_conv3d1(xp) + xp = fluid.layers.squeeze(xp, axes=[2]) + xp = self.p_conv2d1(xp) + xp = self.p_conv2d2(xp) + xp = self.p_conv2d3(xp) + xp = self.p_conv2d4(xp) + return xp, xs, xe + + +class BmnLoss(Loss): + def __init__(self, cfg): + super(BmnLoss, self).__init__() + self.cfg = cfg + + def _get_mask(self): + dscale = self.cfg.MODEL.dscale + tscale = self.cfg.MODEL.tscale + bm_mask = [] + for idx in range(dscale): + mask_vector = [1 for i in range(tscale - idx) + ] + [0 for i in range(idx)] + bm_mask.append(mask_vector) + bm_mask = np.array(bm_mask, dtype=np.float32) + self_bm_mask = fluid.layers.create_global_var( + shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True) + fluid.layers.assign(bm_mask, self_bm_mask) + self_bm_mask.stop_gradient = True + return self_bm_mask + + def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end): + def bi_loss(pred_score, gt_label): + pred_score = fluid.layers.reshape( + x=pred_score, shape=[-1], inplace=False) + gt_label = fluid.layers.reshape( + x=gt_label, shape=[-1], inplace=False) + gt_label.stop_gradient = True + pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE) + num_entries = fluid.layers.cast( + fluid.layers.shape(pmask), dtype=DATATYPE) + num_positive = fluid.layers.cast( + fluid.layers.reduce_sum(pmask), dtype=DATATYPE) + ratio = num_entries / num_positive + coef_0 = 0.5 * ratio / (ratio - 1) + coef_1 = 0.5 * ratio + epsilon = 0.000001 + temp = fluid.layers.log(pred_score + epsilon) + loss_pos = fluid.layers.elementwise_mul( + fluid.layers.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos) + loss_neg = fluid.layers.elementwise_mul( + fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask)) + loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg) + loss = -1 * (loss_pos + loss_neg) + return loss + + loss_start = bi_loss(pred_start, gt_start) + loss_end = bi_loss(pred_end, gt_end) + loss = loss_start + loss_end + return loss + + def pem_reg_loss_func(self, pred_score, gt_iou_map, mask): + + gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + + u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE) + u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3) + u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE) + u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.) + u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE) + u_lmask = fluid.layers.elementwise_mul(u_lmask, mask) + + num_h = fluid.layers.cast( + fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE) + num_m = fluid.layers.cast( + fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE) + num_l = fluid.layers.cast( + fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE) + + r_m = num_h / num_m + u_smmask = fluid.layers.uniform_random( + shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], + dtype=DATATYPE, + min=0.0, + max=1.0) + u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask) + u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE) + + r_l = num_h / num_l + u_slmask = fluid.layers.uniform_random( + shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], + dtype=DATATYPE, + min=0.0, + max=1.0) + u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask) + u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE) + + weights = u_hmask + u_smmask + u_slmask + weights.stop_gradient = True + loss = fluid.layers.square_error_cost(pred_score, gt_iou_map) + loss = fluid.layers.elementwise_mul(loss, weights) + loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum( + weights) + + return loss + + def pem_cls_loss_func(self, pred_score, gt_iou_map, mask): + gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + gt_iou_map.stop_gradient = True + pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE) + nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE) + nmask = fluid.layers.elementwise_mul(nmask, mask) + + num_positive = fluid.layers.reduce_sum(pmask) + num_entries = num_positive + fluid.layers.reduce_sum(nmask) + ratio = num_entries / num_positive + coef_0 = 0.5 * ratio / (ratio - 1) + coef_1 = 0.5 * ratio + epsilon = 0.000001 + loss_pos = fluid.layers.elementwise_mul( + fluid.layers.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos) + loss_neg = fluid.layers.elementwise_mul( + fluid.layers.log(1.0 - pred_score + epsilon), nmask) + loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg) + loss = -1 * (loss_pos + loss_neg) / num_entries + return loss + + def forward(self, outputs, labels): + pred_bm, pred_start, pred_end = outputs + if len(labels) == 3: + gt_iou_map, gt_start, gt_end = labels + elif len(labels) == 4: # video_index used in eval mode + gt_iou_map, gt_start, gt_end, video_index = labels + pred_bm_reg = fluid.layers.squeeze( + fluid.layers.slice( + pred_bm, axes=[1], starts=[0], ends=[1]), + axes=[1]) + pred_bm_cls = fluid.layers.squeeze( + fluid.layers.slice( + pred_bm, axes=[1], starts=[1], ends=[2]), + axes=[1]) + + bm_mask = self._get_mask() + + pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask) + pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask) + + tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end) + + loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss + return loss diff --git a/bmn/bmn_utils.py b/bmn/bmn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06812e636fdaf6ccc419ca58151402ab50082112 --- /dev/null +++ b/bmn/bmn_utils.py @@ -0,0 +1,217 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import numpy as np +import pandas as pd +import multiprocessing as mp +import json +import os +import math + + +def iou_with_anchors(anchors_min, anchors_max, box_min, box_max): + """Compute jaccard score between a box and the anchors. + """ + len_anchors = anchors_max - anchors_min + int_xmin = np.maximum(anchors_min, box_min) + int_xmax = np.minimum(anchors_max, box_max) + inter_len = np.maximum(int_xmax - int_xmin, 0.) + union_len = len_anchors - inter_len + box_max - box_min + jaccard = np.divide(inter_len, union_len) + return jaccard + + +def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max): + """Compute intersection between score a box and the anchors. + """ + len_anchors = anchors_max - anchors_min + int_xmin = np.maximum(anchors_min, box_min) + int_xmax = np.minimum(anchors_max, box_max) + inter_len = np.maximum(int_xmax - int_xmin, 0.) + scores = np.divide(inter_len, len_anchors) + return scores + + +def boundary_choose(score_list): + max_score = max(score_list) + mask_high = (score_list > max_score * 0.5) + score_list = list(score_list) + score_middle = np.array([0.0] + score_list + [0.0]) + score_front = np.array([0.0, 0.0] + score_list) + score_back = np.array(score_list + [0.0, 0.0]) + mask_peak = ((score_middle > score_front) & (score_middle > score_back)) + mask_peak = mask_peak[1:-1] + mask = (mask_high | mask_peak).astype('float32') + return mask + + +def soft_nms(df, alpha, t1, t2): + ''' + df: proposals generated by network; + alpha: alpha value of Gaussian decaying function; + t1, t2: threshold for soft nms. + ''' + df = df.sort_values(by="score", ascending=False) + tstart = list(df.xmin.values[:]) + tend = list(df.xmax.values[:]) + tscore = list(df.score.values[:]) + + rstart = [] + rend = [] + rscore = [] + + while len(tscore) > 1 and len(rscore) < 101: + max_index = tscore.index(max(tscore)) + tmp_iou_list = iou_with_anchors( + np.array(tstart), + np.array(tend), tstart[max_index], tend[max_index]) + for idx in range(0, len(tscore)): + if idx != max_index: + tmp_iou = tmp_iou_list[idx] + tmp_width = tend[max_index] - tstart[max_index] + if tmp_iou > t1 + (t2 - t1) * tmp_width: + tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) / + alpha) + + rstart.append(tstart[max_index]) + rend.append(tend[max_index]) + rscore.append(tscore[max_index]) + tstart.pop(max_index) + tend.pop(max_index) + tscore.pop(max_index) + + newDf = pd.DataFrame() + newDf['score'] = rscore + newDf['xmin'] = rstart + newDf['xmax'] = rend + return newDf + + +def video_process(video_list, + video_dict, + output_path, + result_dict, + snms_alpha=0.4, + snms_t1=0.55, + snms_t2=0.9): + + for video_name in video_list: + print("Processing video........" + video_name) + df = pd.read_csv(os.path.join(output_path, video_name + ".csv")) + if len(df) > 1: + df = soft_nms(df, snms_alpha, snms_t1, snms_t2) + + video_duration = video_dict[video_name]["duration_second"] + proposal_list = [] + for idx in range(min(100, len(df))): + tmp_prop={"score":df.score.values[idx], \ + "segment":[max(0,df.xmin.values[idx])*video_duration, \ + min(1,df.xmax.values[idx])*video_duration]} + proposal_list.append(tmp_prop) + result_dict[video_name[2:]] = proposal_list + + +def bmn_post_processing(video_dict, subset, output_path, result_path): + video_list = video_dict.keys() + video_list = list(video_dict.keys()) + global result_dict + result_dict = mp.Manager().dict() + pp_num = 12 + + num_videos = len(video_list) + num_videos_per_thread = int(num_videos / pp_num) + processes = [] + for tid in range(pp_num - 1): + tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) * + num_videos_per_thread] + p = mp.Process( + target=video_process, + args=(tmp_video_list, video_dict, output_path, result_dict)) + p.start() + processes.append(p) + tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:] + p = mp.Process( + target=video_process, + args=(tmp_video_list, video_dict, output_path, result_dict)) + p.start() + processes.append(p) + for p in processes: + p.join() + + result_dict = dict(result_dict) + output_dict = { + "version": "VERSION 1.3", + "results": result_dict, + "external_data": {} + } + outfile = open( + os.path.join(result_path, "bmn_results_%s.json" % subset), "w") + + json.dump(output_dict, outfile) + outfile.close() + + +def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample, + num_sample_perbin): + """ generate sample mask for a boundary-matching pair """ + plen = float(seg_xmax - seg_xmin) + plen_sample = plen / (num_sample * num_sample_perbin - 1.0) + total_samples = [ + seg_xmin + plen_sample * ii + for ii in range(num_sample * num_sample_perbin) + ] + p_mask = [] + for idx in range(num_sample): + bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) * + num_sample_perbin] + bin_vector = np.zeros([tscale]) + for sample in bin_samples: + sample_upper = math.ceil(sample) + sample_decimal, sample_down = math.modf(sample) + if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0: + bin_vector[int(sample_down)] += 1 - sample_decimal + if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0: + bin_vector[int(sample_upper)] += sample_decimal + bin_vector = 1.0 / num_sample_perbin * bin_vector + p_mask.append(bin_vector) + p_mask = np.stack(p_mask, axis=1) + return p_mask + + +def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample, + num_sample_perbin): + """ generate sample mask for each point in Boundary-Matching Map """ + mask_mat = [] + for start_index in range(tscale): + mask_mat_vector = [] + for duration_index in range(dscale): + if start_index + duration_index < tscale: + p_xmin = start_index + p_xmax = start_index + duration_index + center_len = float(p_xmax - p_xmin) + 1 + sample_xmin = p_xmin - center_len * prop_boundary_ratio + sample_xmax = p_xmax + center_len * prop_boundary_ratio + p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax, + tscale, num_sample, + num_sample_perbin) + else: + p_mask = np.zeros([tscale, num_sample]) + mask_mat_vector.append(p_mask) + mask_mat_vector = np.stack(mask_mat_vector, axis=2) + mask_mat.append(mask_mat_vector) + mask_mat = np.stack(mask_mat, axis=3) + mask_mat = mask_mat.astype(np.float32) + + sample_mask = np.reshape(mask_mat, [tscale, -1]) + return sample_mask diff --git a/bmn/config_utils.py b/bmn/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fbef865781061e3f716de01e76d93f782f55b3a6 --- /dev/null +++ b/bmn/config_utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import yaml +import logging + +logger = logging.getLogger(__name__) + +CONFIG_SECS = [ + 'train', + 'valid', + 'test', + 'infer', +] + + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value + + +def parse_config(cfg_file): + """Load a config file into AttrDict""" + with open(cfg_file, 'r') as fopen: + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader)) + create_attr_dict(yaml_config) + return yaml_config + + +def create_attr_dict(yaml_config): + from ast import literal_eval + for key, value in yaml_config.items(): + if type(value) is dict: + yaml_config[key] = value = AttrDict(value) + if isinstance(value, str): + try: + value = literal_eval(value) + except BaseException: + pass + if isinstance(value, AttrDict): + create_attr_dict(yaml_config[key]) + else: + yaml_config[key] = value + return + + +def merge_configs(cfg, sec, args_dict): + assert sec in CONFIG_SECS, "invalid config section {}".format(sec) + sec_dict = getattr(cfg, sec.upper()) + for k, v in args_dict.items(): + if v is None: + continue + try: + if hasattr(sec_dict, k): + setattr(sec_dict, k, v) + except: + pass + return cfg + + +def print_configs(cfg, mode): + logger.info("---------------- {:>5} Arguments ----------------".format( + mode)) + for sec, sec_items in cfg.items(): + logger.info("{}:".format(sec)) + for k, v in sec_items.items(): + logger.info(" {}:{}".format(k, v)) + logger.info("-------------------------------------------------") diff --git a/bmn/eval.py b/bmn/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..d25fc5c79d21fd55743def09445db5821e3e93af --- /dev/null +++ b/bmn/eval.py @@ -0,0 +1,129 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import argparse +import os +import sys +import logging +import paddle.fluid as fluid + +sys.path.append('../') + +from model import set_device, Input +from bmn_metric import BmnMetric +from bmn_model import BMN, BmnLoss +from reader import BmnDataset +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("BMN test for performance evaluation.") + parser.add_argument( + "-d", + "--dynamic", + default=True, + action='store_true', + help="enable dygraph mode, only support dynamic mode at present time") + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--device', + type=str, + default='gpu', + help='gpu or cpu, default use gpu.') + parser.add_argument( + '--weights', + type=str, + default="checkpoint/final", + help='weight path, None to automatically download weights provided by Paddle.' + ) + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +# Performance Evaluation +def test_bmn(args): + # only support dynamic mode at present time + device = set_device(args.device) + fluid.enable_dygraph(device) if args.dynamic else None + + config = parse_config(args.config_file) + eval_cfg = merge_configs(config, 'test', vars(args)) + if not os.path.isdir(config.TEST.output_path): + os.makedirs(config.TEST.output_path) + if not os.path.isdir(config.TEST.result_path): + os.makedirs(config.TEST.result_path) + + inputs = [ + Input( + [None, config.MODEL.feat_dim, config.MODEL.tscale], + 'float32', + name='feat_input') + ] + gt_iou_map = Input( + [None, config.MODEL.dscale, config.MODEL.tscale], + 'float32', + name='gt_iou_map') + gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start') + gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end') + video_idx = Input([None, 1], 'int64', name='video_idx') + labels = [gt_iou_map, gt_start, gt_end, video_idx] + + #data + eval_dataset = BmnDataset(eval_cfg, 'test') + + #model + model = BMN(config, args.dynamic) + model.prepare( + loss_function=BmnLoss(config), + metrics=BmnMetric( + config, mode='test'), + inputs=inputs, + labels=labels, + device=device) + + #load checkpoint + if args.weights: + assert os.path.exists(args.weights + '.pdparams'), \ + "Given weight dir {} not exist.".format(args.weights) + logger.info('load test weights from {}'.format(args.weights)) + model.load(args.weights) + + model.evaluate( + eval_data=eval_dataset, + batch_size=eval_cfg.TEST.batch_size, + num_workers=eval_cfg.TEST.num_workers, + log_freq=args.log_interval) + + logger.info("[EVAL] eval finished") + + +if __name__ == '__main__': + args = parse_args() + test_bmn(args) diff --git a/bmn/eval_anet_prop.py b/bmn/eval_anet_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..2a91c6effff8083ad54180c0931585731f2fa6d8 --- /dev/null +++ b/bmn/eval_anet_prop.py @@ -0,0 +1,110 @@ +''' +Calculate AR@N and AUC; +Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git) +''' + +import sys +sys.path.append('./Evaluation') + +from eval_proposal import ANETproposal +import numpy as np +import argparse +import os + +parser = argparse.ArgumentParser("Eval AR vs AN of proposal") +parser.add_argument( + '--eval_file', + type=str, + default='bmn_results_validation.json', + help='name of results file to eval') + + +def run_evaluation(ground_truth_filename, + proposal_filename, + max_avg_nr_proposals=100, + tiou_thresholds=np.linspace(0.5, 0.95, 10), + subset='validation'): + + anet_proposal = ANETproposal( + ground_truth_filename, + proposal_filename, + tiou_thresholds=tiou_thresholds, + max_avg_nr_proposals=max_avg_nr_proposals, + subset=subset, + verbose=True, + check_status=False) + anet_proposal.evaluate() + recall = anet_proposal.recall + average_recall = anet_proposal.avg_recall + average_nr_proposals = anet_proposal.proposals_per_video + + return (average_nr_proposals, average_recall, recall) + + +def plot_metric(average_nr_proposals, + average_recall, + recall, + tiou_thresholds=np.linspace(0.5, 0.95, 10)): + fn_size = 14 + plt.figure(num=None, figsize=(12, 8)) + ax = plt.subplot(1, 1, 1) + + colors = [ + 'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo' + ] + area_under_curve = np.zeros_like(tiou_thresholds) + for i in range(recall.shape[0]): + area_under_curve[i] = np.trapz(recall[i], average_nr_proposals) + + for idx, tiou in enumerate(tiou_thresholds[::2]): + ax.plot( + average_nr_proposals, + recall[2 * idx, :], + color=colors[idx + 1], + label="tiou=[" + str(tiou) + "], area=" + str( + int(area_under_curve[2 * idx] * 100) / 100.), + linewidth=4, + linestyle='--', + marker=None) + + # Plots Average Recall vs Average number of proposals. + ax.plot( + average_nr_proposals, + average_recall, + color=colors[0], + label="tiou = 0.5:0.05:0.95," + " area=" + str( + int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.), + linewidth=4, + linestyle='-', + marker=None) + + handles, labels = ax.get_legend_handles_labels() + ax.legend( + [handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best') + + plt.ylabel('Average Recall', fontsize=fn_size) + plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size) + plt.grid(b=True, which="both") + plt.ylim([0, 1.0]) + plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size) + plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size) + plt.show() + + +if __name__ == "__main__": + args = parser.parse_args() + eval_file = args.eval_file + eval_file_path = os.path.join("evaluate_results", eval_file) + uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation( + "./Evaluation/data/activity_net_1_3_new.json", + eval_file_path, + max_avg_nr_proposals=100, + tiou_thresholds=np.linspace(0.5, 0.95, 10), + subset='validation') + + print("AR@1; AR@5; AR@10; AR@100") + print("%.02f %.02f %.02f %.02f" % + (100 * np.mean(uniform_recall_valid[:, 0]), + 100 * np.mean(uniform_recall_valid[:, 4]), + 100 * np.mean(uniform_recall_valid[:, 9]), + 100 * np.mean(uniform_recall_valid[:, -1]))) diff --git a/bmn/infer.list b/bmn/infer.list new file mode 100644 index 0000000000000000000000000000000000000000..44768f089e70e40913d9787571ae0a7151232558 --- /dev/null +++ b/bmn/infer.list @@ -0,0 +1 @@ +{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}} \ No newline at end of file diff --git a/bmn/predict.py b/bmn/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..e52927b60562425a1f03cfea12ab6cb21e76b3ef --- /dev/null +++ b/bmn/predict.py @@ -0,0 +1,125 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import argparse +import sys +import os +import logging +import paddle.fluid as fluid + +sys.path.append('../') + +from model import set_device, Input +from bmn_metric import BmnMetric +from bmn_model import BMN, BmnLoss +from reader import BmnDataset +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("BMN inference.") + parser.add_argument( + "-d", + "--dynamic", + default=True, + action='store_true', + help="enable dygraph mode, only support dynamic mode at present time") + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--device', type=str, default='GPU', help='default use gpu.') + parser.add_argument( + '--weights', + type=str, + default="checkpoint/final", + help='weight path, None to automatically download weights provided by Paddle.' + ) + parser.add_argument( + '--save_dir', + type=str, + default="predict_results/", + help='output dir path, default to use ./predict_results/') + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +# Prediction +def infer_bmn(args): + # only support dynamic mode at present time + device = set_device(args.device) + fluid.enable_dygraph(device) if args.dynamic else None + + config = parse_config(args.config_file) + infer_cfg = merge_configs(config, 'infer', vars(args)) + + if not os.path.isdir(config.INFER.output_path): + os.makedirs(config.INFER.output_path) + if not os.path.isdir(config.INFER.result_path): + os.makedirs(config.INFER.result_path) + + inputs = [ + Input( + [None, config.MODEL.feat_dim, config.MODEL.tscale], + 'float32', + name='feat_input') + ] + labels = [Input([None, 1], 'int64', name='video_idx')] + + #data + infer_dataset = BmnDataset(infer_cfg, 'infer') + + model = BMN(config, args.dynamic) + model.prepare( + metrics=BmnMetric( + config, mode='infer'), + inputs=inputs, + labels=labels, + device=device) + + # load checkpoint + if args.weights: + assert os.path.exists( + args.weights + + ".pdparams"), "Given weight dir {} not exist.".format(args.weights) + logger.info('load test weights from {}'.format(args.weights)) + model.load(args.weights) + + # here use model.eval instead of model.test, as post process is required in our case + model.evaluate( + eval_data=infer_dataset, + batch_size=infer_cfg.TEST.batch_size, + num_workers=infer_cfg.TEST.num_workers, + log_freq=args.log_interval) + + logger.info("[INFER] infer finished") + + +if __name__ == '__main__': + args = parse_args() + infer_bmn(args) diff --git a/bmn/reader.py b/bmn/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c1da592e932b6ddbd19476e994f4267ae2f927 --- /dev/null +++ b/bmn/reader.py @@ -0,0 +1,157 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import numpy as np +import json +import logging +import os +import sys + +sys.path.append('../') + +from distributed import DistributedBatchSampler +from paddle.fluid.io import Dataset, DataLoader + +logger = logging.getLogger(__name__) + +from config_utils import * +from bmn_utils import iou_with_anchors, ioa_with_anchors + +DATATYPE = "float32" + + +class BmnDataset(Dataset): + def __init__(self, cfg, mode): + self.mode = mode + self.tscale = cfg.MODEL.tscale # 100 + self.dscale = cfg.MODEL.dscale # 100 + self.anno_file = cfg.MODEL.anno_file + self.feat_path = cfg.MODEL.feat_path + self.file_list = cfg.INFER.filelist + self.subset = cfg[mode.upper()]['subset'] + self.tgap = 1. / self.tscale + + self.get_dataset_dict() + self.get_match_map() + + def __getitem__(self, index): + video_name = self.video_list[index] + video_idx = self.video_list.index(video_name) + video_feat = self.load_file(video_name) + if self.mode == 'infer': + return video_feat, video_idx + else: + gt_iou_map, gt_start, gt_end = self.get_video_label(video_name) + if self.mode == 'train' or self.mode == 'valid': + return video_feat, gt_iou_map, gt_start, gt_end + elif self.mode == 'test': + return video_feat, gt_iou_map, gt_start, gt_end, video_idx + + def __len__(self): + return len(self.video_list) + + def get_dataset_dict(self): + assert ( + os.path.exists(self.feat_path)), "Input feature path not exists" + assert (os.listdir(self.feat_path)), "No feature file in feature path" + self.video_dict = {} + if self.mode == "infer": + annos = json.load(open(self.file_list)) + for video_name in annos.keys(): + self.video_dict[video_name] = annos[video_name] + else: + annos = json.load(open(self.anno_file)) + for video_name in annos.keys(): + video_subset = annos[video_name]["subset"] + if self.subset in video_subset: + self.video_dict[video_name] = annos[video_name] + self.video_list = list(self.video_dict.keys()) + self.video_list.sort() + print("%s subset video numbers: %d" % + (self.subset, len(self.video_list))) + video_name_set = set( + [video_name + '.npy' for video_name in self.video_list]) + assert (video_name_set.intersection(set(os.listdir(self.feat_path))) == + video_name_set), "Input feature not exists in feature path" + + def get_match_map(self): + match_map = [] + for idx in range(self.tscale): + tmp_match_window = [] + xmin = self.tgap * idx + for jdx in range(1, self.tscale + 1): + xmax = xmin + self.tgap * jdx + tmp_match_window.append([xmin, xmax]) + match_map.append(tmp_match_window) + match_map = np.array(match_map) + match_map = np.transpose(match_map, [1, 0, 2]) + match_map = np.reshape(match_map, [-1, 2]) + self.match_map = match_map + self.anchor_xmin = [self.tgap * i for i in range(self.tscale)] + self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)] + + def get_video_label(self, video_name): + video_info = self.video_dict[video_name] + video_second = video_info['duration_second'] + video_labels = video_info['annotations'] + + gt_bbox = [] + gt_iou_map = [] + for gt in video_labels: + tmp_start = max(min(1, gt["segment"][0] / video_second), 0) + tmp_end = max(min(1, gt["segment"][1] / video_second), 0) + gt_bbox.append([tmp_start, tmp_end]) + tmp_gt_iou_map = iou_with_anchors( + self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end) + tmp_gt_iou_map = np.reshape(tmp_gt_iou_map, + [self.dscale, self.tscale]) + gt_iou_map.append(tmp_gt_iou_map) + gt_iou_map = np.array(gt_iou_map) + gt_iou_map = np.max(gt_iou_map, axis=0) + + gt_bbox = np.array(gt_bbox) + gt_xmins = gt_bbox[:, 0] + gt_xmaxs = gt_bbox[:, 1] + gt_len_small = 3 * self.tgap + gt_start_bboxs = np.stack( + (gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1) + gt_end_bboxs = np.stack( + (gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1) + + match_score_start = [] + for jdx in range(len(self.anchor_xmin)): + match_score_start.append( + np.max( + ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[ + jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1]))) + match_score_end = [] + for jdx in range(len(self.anchor_xmin)): + match_score_end.append( + np.max( + ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[ + jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1]))) + + gt_start = np.array(match_score_start) + gt_end = np.array(match_score_end) + return gt_iou_map.astype(DATATYPE), gt_start.astype( + DATATYPE), gt_end.astype(DATATYPE) + + def load_file(self, video_name): + file_name = video_name + ".npy" + file_path = os.path.join(self.feat_path, file_name) + video_feat = np.load(file_path) + video_feat = video_feat.T + video_feat = video_feat.astype("float32") + return video_feat diff --git a/bmn/run.sh b/bmn/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..24fd8e3da991c74628f6d345badf7bfe2e67c35d --- /dev/null +++ b/bmn/run.sh @@ -0,0 +1,3 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +python -m paddle.distributed.launch train.py diff --git a/bmn/train.py b/bmn/train.py new file mode 100644 index 0000000000000000000000000000000000000000..fe46f6a607c6ab8f93be45ffeee11478ef862eb6 --- /dev/null +++ b/bmn/train.py @@ -0,0 +1,168 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +import argparse +import logging +import sys +import os + +sys.path.append('../') + +from model import set_device, Input +from bmn_model import BMN, BmnLoss +from reader import BmnDataset +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("Paddle high level api of BMN.") + parser.add_argument( + "-d", + "--dynamic", + default=True, + action='store_true', + help="enable dygraph mode") + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='training batch size. None to use config file setting.') + parser.add_argument( + '--learning_rate', + type=float, + default=0.001, + help='learning rate use for training. None to use config file setting.') + parser.add_argument( + '--resume', + type=str, + default=None, + help='filename to resume training based on previous checkpoints. ' + 'None for not resuming any checkpoints.') + parser.add_argument( + '--device', + type=str, + default='gpu', + help='gpu or cpu, default use gpu.') + parser.add_argument( + '--epoch', + type=int, + default=9, + help='epoch number, 0 for read from config file') + parser.add_argument( + '--valid_interval', + type=int, + default=1, + help='validation epoch interval, 0 for no validation.') + parser.add_argument( + '--save_dir', + type=str, + default="checkpoint", + help='path to save train snapshoot') + parser.add_argument( + '--log_interval', + type=int, + default=10, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +# Optimizer +def optimizer(cfg, parameter_list): + bd = [cfg.TRAIN.lr_decay_iter] + base_lr = cfg.TRAIN.learning_rate + lr_decay = cfg.TRAIN.learning_rate_decay + l2_weight_decay = cfg.TRAIN.l2_weight_decay + lr = [base_lr, base_lr * lr_decay] + optimizer = fluid.optimizer.Adam( + fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + parameter_list=parameter_list, + regularization=fluid.regularizer.L2DecayRegularizer( + regularization_coeff=l2_weight_decay)) + return optimizer + + +# TRAIN +def train_bmn(args): + device = set_device(args.device) + fluid.enable_dygraph(device) if args.dynamic else None + + if not os.path.isdir(args.save_dir): + os.makedirs(args.save_dir) + + config = parse_config(args.config_file) + train_cfg = merge_configs(config, 'train', vars(args)) + val_cfg = merge_configs(config, 'valid', vars(args)) + + inputs = [ + Input( + [None, config.MODEL.feat_dim, config.MODEL.tscale], + 'float32', + name='feat_input') + ] + gt_iou_map = Input( + [None, config.MODEL.dscale, config.MODEL.tscale], + 'float32', + name='gt_iou_map') + gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start') + gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end') + labels = [gt_iou_map, gt_start, gt_end] + + # data + train_dataset = BmnDataset(train_cfg, 'train') + val_dataset = BmnDataset(val_cfg, 'valid') + + # model + model = BMN(config, args.dynamic) + optim = optimizer(config, parameter_list=model.parameters()) + model.prepare( + optimizer=optim, + loss_function=BmnLoss(config), + inputs=inputs, + labels=labels, + device=device) + + # if resume weights is given, load resume weights directly + if args.resume is not None: + model.load(args.resume) + + model.fit(train_data=train_dataset, + eval_data=val_dataset, + batch_size=train_cfg.TRAIN.batch_size, + epochs=args.epoch, + eval_freq=args.valid_interval, + log_freq=args.log_interval, + save_dir=args.save_dir, + shuffle=train_cfg.TRAIN.use_shuffle, + num_workers=train_cfg.TRAIN.num_workers, + drop_last=True) + + +if __name__ == "__main__": + args = parse_args() + train_bmn(args)