diff --git a/bmn/BMN.png b/bmn/BMN.png
new file mode 100644
index 0000000000000000000000000000000000000000..46b9743cf5b5e79ada93ebb6b61da96967f350a4
Binary files /dev/null and b/bmn/BMN.png differ
diff --git a/bmn/README.md b/bmn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dbf593512c07da9645418b92b2a6819c9be13fa5
--- /dev/null
+++ b/bmn/README.md
@@ -0,0 +1,120 @@
+# BMN 视频动作定位模型高层API实现
+
+---
+## 内容
+
+- [模型简介](#模型简介)
+- [代码结构](#代码结构)
+- [数据准备](#数据准备)
+- [模型训练](#模型训练)
+- [模型评估](#模型评估)
+- [模型推断](#模型推断)
+- [参考论文](#参考论文)
+
+
+## 模型简介
+
+BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。
+
+
+
+BMN Overview
+
+
+
+## 代码结构
+```
+├── bmn.yaml # 网络配置文件,快速配置参数
+├── run.sh # 快速运行脚本,可直接开始多卡训练
+├── train.py # 训练代码,训练网络
+├── eval.py # 评估代码,评估网络性能
+├── predict.py # 预测代码,针对任意输入预测结果
+├── bmn_model.py # 网络结构与损失函数定义
+├── bmn_metric.py # 精度评估方法定义
+├── reader.py # 数据reader,构造Dataset和Dataloader
+├── bmn_utils.py # 模型细节相关代码
+├── config_utils.py # 配置细节相关代码
+├── eval_anet_prop.py # 计算精度评估指标
+└── infer.list # 推断文件列表
+```
+
+
+## 数据准备
+
+BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理好的视频特征,请下载[bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz)数据后解压,同时相应的修改bmn.yaml中的特征路径feat\_path。对应的标签文件请下载[label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json)并修改bmn.yaml中的标签文件路径anno\_file。
+
+
+## 模型训练
+
+数据准备完成后,可通过如下两种方式启动训练:
+
+默认使用4卡训练,启动方式如下:
+
+ bash run.sh
+
+若使用单卡训练,启动方式如下:
+
+ export CUDA_VISIBLE_DEVICES=0
+ python train.py
+
+- 代码运行需要先安装pandas
+
+- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
+
+- 单卡训练时,请将配置文件中的batch_size调整为16
+
+**训练策略:**
+
+* 采用Adam优化器,初始learning\_rate=0.001
+* 权重衰减系数为1e-4
+* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
+
+
+## 模型评估
+
+训练完成后,可通过如下方式进行模型评估:
+
+ python eval.py --weights=$PATH_TO_WEIGHTS
+
+- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
+
+- 上述程序会将运行结果保存在output/EVAL/BMN\_results文件夹下,测试结果保存在evaluate\_results/bmn\_results\_validation.json文件中。
+
+- 注:评估时可能会出现loss为nan的情况。这是由于评估时用的是单个样本,可能存在没有iou>0.6的样本,所以为nan,对最终的评估结果没有影响。
+
+
+使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程如下:
+
+- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org)
+
+- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至models/dygraph/bmn目录下。(注:由于第三方评估代码不支持python3,此处建议使用python2进行评估;若使用python3,print函数需要添加括号,请对Evaluation目录下的.py文件做相应修改。)
+
+- 请下载[activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json)文件,并将其放置在models/dygraph/bmn/Evaluation/data目录下,相较于原始的activity\_net.v1-3.min.json文件,我们过滤了其中一些失效的视频条目。
+
+- 计算精度指标
+
+ ```python eval_anet_prop.py```
+
+
+在ActivityNet1.3数据集下评估精度如下:
+
+| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
+| :---: | :---: | :---: | :---: | :---: |
+| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
+
+
+## 模型推断
+
+可通过如下方式启动模型推断:
+
+ python predict.py --weights=$PATH_TO_WEIGHTS \
+ --filelist=$FILELIST
+
+- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
+
+- 上述程序会将运行结果保存在output/INFER/BMN\_results文件夹下,测试结果保存在predict\_results/bmn\_results\_test.json文件中。
+
+
+## 参考论文
+
+- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
diff --git a/bmn/bmn.yaml b/bmn/bmn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..da50ea4f7c654d40fbf2498863cb7e87664fe55a
--- /dev/null
+++ b/bmn/bmn.yaml
@@ -0,0 +1,50 @@
+MODEL:
+ name: "BMN"
+ tscale: 100
+ dscale: 100
+ feat_dim: 400
+ prop_boundary_ratio: 0.5
+ num_sample: 32
+ num_sample_perbin: 3
+ anno_file: "./activitynet_1.3_annotations.json"
+ feat_path: './fix_feat_100'
+
+TRAIN:
+ subset: "train"
+ epoch: 9
+ batch_size: 4
+ num_workers: 4
+ use_shuffle: True
+ device: "gpu"
+ num_gpus: 4
+ learning_rate: 0.001
+ learning_rate_decay: 0.1
+ lr_decay_iter: 4200
+ l2_weight_decay: 1e-4
+
+VALID:
+ subset: "validation"
+
+TEST:
+ subset: "validation"
+ batch_size: 1
+ num_workers: 1
+ use_buffer: False
+ snms_alpha: 0.001
+ snms_t1: 0.5
+ snms_t2: 0.9
+ output_path: "output/EVAL/BMN_results"
+ result_path: "evaluate_results"
+
+INFER:
+ subset: "test"
+ batch_size: 1
+ num_workers: 1
+ use_buffer: False
+ snms_alpha: 0.4
+ snms_t1: 0.5
+ snms_t2: 0.9
+ filelist: './infer.list'
+ output_path: "output/INFER/BMN_results"
+ result_path: "predict_results"
+
diff --git a/bmn/bmn_metric.py b/bmn/bmn_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..a19f87c6b42b8737ddeb52c3a330f59dcc932004
--- /dev/null
+++ b/bmn/bmn_metric.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import numpy as np
+import pandas as pd
+import os
+import sys
+import json
+
+sys.path.append('../')
+
+from metrics import Metric
+from bmn_utils import boundary_choose, bmn_post_processing
+
+
+class BmnMetric(Metric):
+ """
+ only support update with batch_size=1
+ """
+
+ def __init__(self, cfg, mode):
+ super(BmnMetric, self).__init__()
+ self.cfg = cfg
+ self.mode = mode
+ #get video_dict and video_list
+ if self.mode == 'test':
+ self.get_test_dataset_dict()
+ elif self.mode == 'infer':
+ self.get_infer_dataset_dict()
+
+ def add_metric_op(self, preds, label):
+ pred_bm, pred_start, pred_en = preds
+ video_index = label[-1]
+ return [pred_bm, pred_start, pred_en, video_index] #return list
+
+ def update(self, pred_bm, pred_start, pred_end, fid):
+ # generate proposals
+ pred_start = pred_start[0]
+ pred_end = pred_end[0]
+ fid = fid[0]
+
+ if self.mode == 'infer':
+ output_path = self.cfg.INFER.output_path
+ else:
+ output_path = self.cfg.TEST.output_path
+ tscale = self.cfg.MODEL.tscale
+ dscale = self.cfg.MODEL.dscale
+ snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
+ snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
+ cols = ["xmin", "xmax", "score"]
+
+ video_name = self.video_list[fid]
+ pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
+ start_mask = boundary_choose(pred_start)
+ start_mask[0] = 1.
+ end_mask = boundary_choose(pred_end)
+ end_mask[-1] = 1.
+ score_vector_list = []
+ for idx in range(dscale):
+ for jdx in range(tscale):
+ start_index = jdx
+ end_index = start_index + idx
+ if end_index < tscale and start_mask[
+ start_index] == 1 and end_mask[end_index] == 1:
+ xmin = snippet_xmins[start_index]
+ xmax = snippet_xmaxs[end_index]
+ xmin_score = pred_start[start_index]
+ xmax_score = pred_end[end_index]
+ bm_score = pred_bm[idx, jdx]
+ conf_score = xmin_score * xmax_score * bm_score
+ score_vector_list.append([xmin, xmax, conf_score])
+
+ score_vector_list = np.stack(score_vector_list)
+ video_df = pd.DataFrame(score_vector_list, columns=cols)
+ video_df.to_csv(
+ os.path.join(output_path, "%s.csv" % video_name), index=False)
+
+ return 0 # result has saved in output path
+
+ def accumulate(self):
+ return 'post_processing is required...' # required method
+
+ def reset(self):
+ print("Post_processing....This may take a while")
+ if self.mode == 'test':
+ bmn_post_processing(self.video_dict, self.cfg.TEST.subset,
+ self.cfg.TEST.output_path,
+ self.cfg.TEST.result_path)
+ elif self.mode == 'infer':
+ bmn_post_processing(self.video_dict, self.cfg.INFER.subset,
+ self.cfg.INFER.output_path,
+ self.cfg.INFER.result_path)
+
+ def name(self):
+ return 'bmn_metric'
+
+ def get_test_dataset_dict(self):
+ anno_file = self.cfg.MODEL.anno_file
+ annos = json.load(open(anno_file))
+ subset = self.cfg.TEST.subset
+ self.video_dict = {}
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if subset in video_subset:
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
+
+ def get_infer_dataset_dict(self):
+ file_list = self.cfg.INFER.filelist
+ annos = json.load(open(file_list))
+ self.video_dict = {}
+ for video_name in annos.keys():
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
diff --git a/bmn/bmn_model.py b/bmn/bmn_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfde7bcc5cdaac8aa5ea3c069f580308b49ec01f
--- /dev/null
+++ b/bmn/bmn_model.py
@@ -0,0 +1,364 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+import numpy as np
+import math
+
+from bmn_utils import get_interp1d_mask
+from model import Model, Loss
+
+DATATYPE = 'float32'
+
+
+# Net
+class Conv1D(fluid.dygraph.Layer):
+ def __init__(self,
+ prefix,
+ num_channels=256,
+ num_filters=256,
+ size_k=3,
+ padding=1,
+ groups=1,
+ act="relu"):
+ super(Conv1D, self).__init__()
+ fan_in = num_channels * size_k * 1
+ k = 1. / math.sqrt(fan_in)
+ param_attr = ParamAttr(
+ name=prefix + "_w",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+ bias_attr = ParamAttr(
+ name=prefix + "_b",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+
+ self._conv2d = fluid.dygraph.Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=(1, size_k),
+ stride=1,
+ padding=(0, padding),
+ groups=groups,
+ act=act,
+ param_attr=param_attr,
+ bias_attr=bias_attr)
+
+ def forward(self, x):
+ x = fluid.layers.unsqueeze(input=x, axes=[2])
+ x = self._conv2d(x)
+ x = fluid.layers.squeeze(input=x, axes=[2])
+ return x
+
+
+class BMN(Model):
+ def __init__(self, cfg, is_dygraph=True):
+ super(BMN, self).__init__()
+
+ #init config
+ self.tscale = cfg.MODEL.tscale
+ self.dscale = cfg.MODEL.dscale
+ self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
+ self.num_sample = cfg.MODEL.num_sample
+ self.num_sample_perbin = cfg.MODEL.num_sample_perbin
+ self.is_dygraph = is_dygraph
+
+ self.hidden_dim_1d = 256
+ self.hidden_dim_2d = 128
+ self.hidden_dim_3d = 512
+
+ # Base Module
+ self.b_conv1 = Conv1D(
+ prefix="Base_1",
+ num_channels=400,
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.b_conv2 = Conv1D(
+ prefix="Base_2",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+
+ # Temporal Evaluation Module
+ self.ts_conv1 = Conv1D(
+ prefix="TEM_s1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.ts_conv2 = Conv1D(
+ prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+ self.te_conv1 = Conv1D(
+ prefix="TEM_e1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.te_conv2 = Conv1D(
+ prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+
+ #Proposal Evaluation Module
+ self.p_conv1 = Conv1D(
+ prefix="PEM_1d",
+ num_filters=self.hidden_dim_2d,
+ size_k=3,
+ padding=1,
+ act="relu")
+
+ # init to speed up
+ sample_mask_array = get_interp1d_mask(
+ self.tscale, self.dscale, self.prop_boundary_ratio,
+ self.num_sample, self.num_sample_perbin)
+ if self.is_dygraph:
+ self.sample_mask = fluid.dygraph.base.to_variable(
+ sample_mask_array)
+ else: # static
+ self.sample_mask = fluid.layers.create_parameter(
+ shape=[
+ self.tscale, self.num_sample * self.dscale * self.tscale
+ ],
+ dtype=DATATYPE,
+ attr=fluid.ParamAttr(
+ name="sample_mask", trainable=False),
+ default_initializer=fluid.initializer.NumpyArrayInitializer(
+ sample_mask_array))
+
+ self.sample_mask.stop_gradient = True
+
+ self.p_conv3d1 = fluid.dygraph.Conv3D(
+ num_channels=128,
+ num_filters=self.hidden_dim_3d,
+ filter_size=(self.num_sample, 1, 1),
+ stride=(self.num_sample, 1, 1),
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_3d1_w"),
+ bias_attr=ParamAttr(name="PEM_3d1_b"))
+
+ self.p_conv2d1 = fluid.dygraph.Conv2D(
+ num_channels=512,
+ num_filters=self.hidden_dim_2d,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d1_w"),
+ bias_attr=ParamAttr(name="PEM_2d1_b"))
+ self.p_conv2d2 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d2_w"),
+ bias_attr=ParamAttr(name="PEM_2d2_b"))
+ self.p_conv2d3 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d3_w"),
+ bias_attr=ParamAttr(name="PEM_2d3_b"))
+ self.p_conv2d4 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=2,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="sigmoid",
+ param_attr=ParamAttr(name="PEM_2d4_w"),
+ bias_attr=ParamAttr(name="PEM_2d4_b"))
+
+ def forward(self, x):
+ #Base Module
+ x = self.b_conv1(x)
+ x = self.b_conv2(x)
+
+ #TEM
+ xs = self.ts_conv1(x)
+ xs = self.ts_conv2(xs)
+ xs = fluid.layers.squeeze(xs, axes=[1])
+ xe = self.te_conv1(x)
+ xe = self.te_conv2(xe)
+ xe = fluid.layers.squeeze(xe, axes=[1])
+
+ #PEM
+ xp = self.p_conv1(x)
+ #BM layer
+ xp = fluid.layers.matmul(xp, self.sample_mask)
+ xp = fluid.layers.reshape(
+ xp, shape=[0, 0, -1, self.dscale, self.tscale])
+
+ xp = self.p_conv3d1(xp)
+ xp = fluid.layers.squeeze(xp, axes=[2])
+ xp = self.p_conv2d1(xp)
+ xp = self.p_conv2d2(xp)
+ xp = self.p_conv2d3(xp)
+ xp = self.p_conv2d4(xp)
+ return xp, xs, xe
+
+
+class BmnLoss(Loss):
+ def __init__(self, cfg):
+ super(BmnLoss, self).__init__()
+ self.cfg = cfg
+
+ def _get_mask(self):
+ dscale = self.cfg.MODEL.dscale
+ tscale = self.cfg.MODEL.tscale
+ bm_mask = []
+ for idx in range(dscale):
+ mask_vector = [1 for i in range(tscale - idx)
+ ] + [0 for i in range(idx)]
+ bm_mask.append(mask_vector)
+ bm_mask = np.array(bm_mask, dtype=np.float32)
+ self_bm_mask = fluid.layers.create_global_var(
+ shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
+ fluid.layers.assign(bm_mask, self_bm_mask)
+ self_bm_mask.stop_gradient = True
+ return self_bm_mask
+
+ def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
+ def bi_loss(pred_score, gt_label):
+ pred_score = fluid.layers.reshape(
+ x=pred_score, shape=[-1], inplace=False)
+ gt_label = fluid.layers.reshape(
+ x=gt_label, shape=[-1], inplace=False)
+ gt_label.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
+ num_entries = fluid.layers.cast(
+ fluid.layers.shape(pmask), dtype=DATATYPE)
+ num_positive = fluid.layers.cast(
+ fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
+ ratio = num_entries / num_positive
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ temp = fluid.layers.log(pred_score + epsilon)
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
+ loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
+ loss = -1 * (loss_pos + loss_neg)
+ return loss
+
+ loss_start = bi_loss(pred_start, gt_start)
+ loss_end = bi_loss(pred_end, gt_end)
+ loss = loss_start + loss_end
+ return loss
+
+ def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
+
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+
+ u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
+ u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
+ u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
+ u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
+
+ num_h = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
+ num_m = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
+ num_l = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
+
+ r_m = num_h / num_m
+ u_smmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
+ u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
+
+ r_l = num_h / num_l
+ u_slmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
+ u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
+
+ weights = u_hmask + u_smmask + u_slmask
+ weights.stop_gradient = True
+ loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
+ loss = fluid.layers.elementwise_mul(loss, weights)
+ loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
+ weights)
+
+ return loss
+
+ def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+ gt_iou_map.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.elementwise_mul(nmask, mask)
+
+ num_positive = fluid.layers.reduce_sum(pmask)
+ num_entries = num_positive + fluid.layers.reduce_sum(nmask)
+ ratio = num_entries / num_positive
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), nmask)
+ loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
+ loss = -1 * (loss_pos + loss_neg) / num_entries
+ return loss
+
+ def forward(self, outputs, labels):
+ pred_bm, pred_start, pred_end = outputs
+ if len(labels) == 3:
+ gt_iou_map, gt_start, gt_end = labels
+ elif len(labels) == 4: # video_index used in eval mode
+ gt_iou_map, gt_start, gt_end, video_index = labels
+ pred_bm_reg = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[0], ends=[1]),
+ axes=[1])
+ pred_bm_cls = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[1], ends=[2]),
+ axes=[1])
+
+ bm_mask = self._get_mask()
+
+ pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
+ pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
+
+ tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
+
+ loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
+ return loss
diff --git a/bmn/bmn_utils.py b/bmn/bmn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06812e636fdaf6ccc419ca58151402ab50082112
--- /dev/null
+++ b/bmn/bmn_utils.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import numpy as np
+import pandas as pd
+import multiprocessing as mp
+import json
+import os
+import math
+
+
+def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute jaccard score between a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ union_len = len_anchors - inter_len + box_max - box_min
+ jaccard = np.divide(inter_len, union_len)
+ return jaccard
+
+
+def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute intersection between score a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ scores = np.divide(inter_len, len_anchors)
+ return scores
+
+
+def boundary_choose(score_list):
+ max_score = max(score_list)
+ mask_high = (score_list > max_score * 0.5)
+ score_list = list(score_list)
+ score_middle = np.array([0.0] + score_list + [0.0])
+ score_front = np.array([0.0, 0.0] + score_list)
+ score_back = np.array(score_list + [0.0, 0.0])
+ mask_peak = ((score_middle > score_front) & (score_middle > score_back))
+ mask_peak = mask_peak[1:-1]
+ mask = (mask_high | mask_peak).astype('float32')
+ return mask
+
+
+def soft_nms(df, alpha, t1, t2):
+ '''
+ df: proposals generated by network;
+ alpha: alpha value of Gaussian decaying function;
+ t1, t2: threshold for soft nms.
+ '''
+ df = df.sort_values(by="score", ascending=False)
+ tstart = list(df.xmin.values[:])
+ tend = list(df.xmax.values[:])
+ tscore = list(df.score.values[:])
+
+ rstart = []
+ rend = []
+ rscore = []
+
+ while len(tscore) > 1 and len(rscore) < 101:
+ max_index = tscore.index(max(tscore))
+ tmp_iou_list = iou_with_anchors(
+ np.array(tstart),
+ np.array(tend), tstart[max_index], tend[max_index])
+ for idx in range(0, len(tscore)):
+ if idx != max_index:
+ tmp_iou = tmp_iou_list[idx]
+ tmp_width = tend[max_index] - tstart[max_index]
+ if tmp_iou > t1 + (t2 - t1) * tmp_width:
+ tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
+ alpha)
+
+ rstart.append(tstart[max_index])
+ rend.append(tend[max_index])
+ rscore.append(tscore[max_index])
+ tstart.pop(max_index)
+ tend.pop(max_index)
+ tscore.pop(max_index)
+
+ newDf = pd.DataFrame()
+ newDf['score'] = rscore
+ newDf['xmin'] = rstart
+ newDf['xmax'] = rend
+ return newDf
+
+
+def video_process(video_list,
+ video_dict,
+ output_path,
+ result_dict,
+ snms_alpha=0.4,
+ snms_t1=0.55,
+ snms_t2=0.9):
+
+ for video_name in video_list:
+ print("Processing video........" + video_name)
+ df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
+ if len(df) > 1:
+ df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
+
+ video_duration = video_dict[video_name]["duration_second"]
+ proposal_list = []
+ for idx in range(min(100, len(df))):
+ tmp_prop={"score":df.score.values[idx], \
+ "segment":[max(0,df.xmin.values[idx])*video_duration, \
+ min(1,df.xmax.values[idx])*video_duration]}
+ proposal_list.append(tmp_prop)
+ result_dict[video_name[2:]] = proposal_list
+
+
+def bmn_post_processing(video_dict, subset, output_path, result_path):
+ video_list = video_dict.keys()
+ video_list = list(video_dict.keys())
+ global result_dict
+ result_dict = mp.Manager().dict()
+ pp_num = 12
+
+ num_videos = len(video_list)
+ num_videos_per_thread = int(num_videos / pp_num)
+ processes = []
+ for tid in range(pp_num - 1):
+ tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
+ num_videos_per_thread]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ for p in processes:
+ p.join()
+
+ result_dict = dict(result_dict)
+ output_dict = {
+ "version": "VERSION 1.3",
+ "results": result_dict,
+ "external_data": {}
+ }
+ outfile = open(
+ os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
+
+ json.dump(output_dict, outfile)
+ outfile.close()
+
+
+def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for a boundary-matching pair """
+ plen = float(seg_xmax - seg_xmin)
+ plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
+ total_samples = [
+ seg_xmin + plen_sample * ii
+ for ii in range(num_sample * num_sample_perbin)
+ ]
+ p_mask = []
+ for idx in range(num_sample):
+ bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
+ num_sample_perbin]
+ bin_vector = np.zeros([tscale])
+ for sample in bin_samples:
+ sample_upper = math.ceil(sample)
+ sample_decimal, sample_down = math.modf(sample)
+ if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
+ bin_vector[int(sample_down)] += 1 - sample_decimal
+ if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
+ bin_vector[int(sample_upper)] += sample_decimal
+ bin_vector = 1.0 / num_sample_perbin * bin_vector
+ p_mask.append(bin_vector)
+ p_mask = np.stack(p_mask, axis=1)
+ return p_mask
+
+
+def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for each point in Boundary-Matching Map """
+ mask_mat = []
+ for start_index in range(tscale):
+ mask_mat_vector = []
+ for duration_index in range(dscale):
+ if start_index + duration_index < tscale:
+ p_xmin = start_index
+ p_xmax = start_index + duration_index
+ center_len = float(p_xmax - p_xmin) + 1
+ sample_xmin = p_xmin - center_len * prop_boundary_ratio
+ sample_xmax = p_xmax + center_len * prop_boundary_ratio
+ p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
+ tscale, num_sample,
+ num_sample_perbin)
+ else:
+ p_mask = np.zeros([tscale, num_sample])
+ mask_mat_vector.append(p_mask)
+ mask_mat_vector = np.stack(mask_mat_vector, axis=2)
+ mask_mat.append(mask_mat_vector)
+ mask_mat = np.stack(mask_mat, axis=3)
+ mask_mat = mask_mat.astype(np.float32)
+
+ sample_mask = np.reshape(mask_mat, [tscale, -1])
+ return sample_mask
diff --git a/bmn/config_utils.py b/bmn/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbef865781061e3f716de01e76d93f782f55b3a6
--- /dev/null
+++ b/bmn/config_utils.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import yaml
+import logging
+
+logger = logging.getLogger(__name__)
+
+CONFIG_SECS = [
+ 'train',
+ 'valid',
+ 'test',
+ 'infer',
+]
+
+
+class AttrDict(dict):
+ def __getattr__(self, key):
+ return self[key]
+
+ def __setattr__(self, key, value):
+ if key in self.__dict__:
+ self.__dict__[key] = value
+ else:
+ self[key] = value
+
+
+def parse_config(cfg_file):
+ """Load a config file into AttrDict"""
+ with open(cfg_file, 'r') as fopen:
+ yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
+ create_attr_dict(yaml_config)
+ return yaml_config
+
+
+def create_attr_dict(yaml_config):
+ from ast import literal_eval
+ for key, value in yaml_config.items():
+ if type(value) is dict:
+ yaml_config[key] = value = AttrDict(value)
+ if isinstance(value, str):
+ try:
+ value = literal_eval(value)
+ except BaseException:
+ pass
+ if isinstance(value, AttrDict):
+ create_attr_dict(yaml_config[key])
+ else:
+ yaml_config[key] = value
+ return
+
+
+def merge_configs(cfg, sec, args_dict):
+ assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
+ sec_dict = getattr(cfg, sec.upper())
+ for k, v in args_dict.items():
+ if v is None:
+ continue
+ try:
+ if hasattr(sec_dict, k):
+ setattr(sec_dict, k, v)
+ except:
+ pass
+ return cfg
+
+
+def print_configs(cfg, mode):
+ logger.info("---------------- {:>5} Arguments ----------------".format(
+ mode))
+ for sec, sec_items in cfg.items():
+ logger.info("{}:".format(sec))
+ for k, v in sec_items.items():
+ logger.info(" {}:{}".format(k, v))
+ logger.info("-------------------------------------------------")
diff --git a/bmn/eval.py b/bmn/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25fc5c79d21fd55743def09445db5821e3e93af
--- /dev/null
+++ b/bmn/eval.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import argparse
+import os
+import sys
+import logging
+import paddle.fluid as fluid
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_metric import BmnMetric
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("BMN test for performance evaluation.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode, only support dynamic mode at present time")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='gpu',
+ help='gpu or cpu, default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/final",
+ help='weight path, None to automatically download weights provided by Paddle.'
+ )
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Performance Evaluation
+def test_bmn(args):
+ # only support dynamic mode at present time
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ config = parse_config(args.config_file)
+ eval_cfg = merge_configs(config, 'test', vars(args))
+ if not os.path.isdir(config.TEST.output_path):
+ os.makedirs(config.TEST.output_path)
+ if not os.path.isdir(config.TEST.result_path):
+ os.makedirs(config.TEST.result_path)
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ gt_iou_map = Input(
+ [None, config.MODEL.dscale, config.MODEL.tscale],
+ 'float32',
+ name='gt_iou_map')
+ gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
+ gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
+ video_idx = Input([None, 1], 'int64', name='video_idx')
+ labels = [gt_iou_map, gt_start, gt_end, video_idx]
+
+ #data
+ eval_dataset = BmnDataset(eval_cfg, 'test')
+
+ #model
+ model = BMN(config, args.dynamic)
+ model.prepare(
+ loss_function=BmnLoss(config),
+ metrics=BmnMetric(
+ config, mode='test'),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ #load checkpoint
+ if args.weights:
+ assert os.path.exists(args.weights + '.pdparams'), \
+ "Given weight dir {} not exist.".format(args.weights)
+ logger.info('load test weights from {}'.format(args.weights))
+ model.load(args.weights)
+
+ model.evaluate(
+ eval_data=eval_dataset,
+ batch_size=eval_cfg.TEST.batch_size,
+ num_workers=eval_cfg.TEST.num_workers,
+ log_freq=args.log_interval)
+
+ logger.info("[EVAL] eval finished")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ test_bmn(args)
diff --git a/bmn/eval_anet_prop.py b/bmn/eval_anet_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a91c6effff8083ad54180c0931585731f2fa6d8
--- /dev/null
+++ b/bmn/eval_anet_prop.py
@@ -0,0 +1,110 @@
+'''
+Calculate AR@N and AUC;
+Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)
+'''
+
+import sys
+sys.path.append('./Evaluation')
+
+from eval_proposal import ANETproposal
+import numpy as np
+import argparse
+import os
+
+parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
+parser.add_argument(
+ '--eval_file',
+ type=str,
+ default='bmn_results_validation.json',
+ help='name of results file to eval')
+
+
+def run_evaluation(ground_truth_filename,
+ proposal_filename,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation'):
+
+ anet_proposal = ANETproposal(
+ ground_truth_filename,
+ proposal_filename,
+ tiou_thresholds=tiou_thresholds,
+ max_avg_nr_proposals=max_avg_nr_proposals,
+ subset=subset,
+ verbose=True,
+ check_status=False)
+ anet_proposal.evaluate()
+ recall = anet_proposal.recall
+ average_recall = anet_proposal.avg_recall
+ average_nr_proposals = anet_proposal.proposals_per_video
+
+ return (average_nr_proposals, average_recall, recall)
+
+
+def plot_metric(average_nr_proposals,
+ average_recall,
+ recall,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10)):
+ fn_size = 14
+ plt.figure(num=None, figsize=(12, 8))
+ ax = plt.subplot(1, 1, 1)
+
+ colors = [
+ 'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
+ ]
+ area_under_curve = np.zeros_like(tiou_thresholds)
+ for i in range(recall.shape[0]):
+ area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
+
+ for idx, tiou in enumerate(tiou_thresholds[::2]):
+ ax.plot(
+ average_nr_proposals,
+ recall[2 * idx, :],
+ color=colors[idx + 1],
+ label="tiou=[" + str(tiou) + "], area=" + str(
+ int(area_under_curve[2 * idx] * 100) / 100.),
+ linewidth=4,
+ linestyle='--',
+ marker=None)
+
+ # Plots Average Recall vs Average number of proposals.
+ ax.plot(
+ average_nr_proposals,
+ average_recall,
+ color=colors[0],
+ label="tiou = 0.5:0.05:0.95," + " area=" + str(
+ int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
+ linewidth=4,
+ linestyle='-',
+ marker=None)
+
+ handles, labels = ax.get_legend_handles_labels()
+ ax.legend(
+ [handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
+
+ plt.ylabel('Average Recall', fontsize=fn_size)
+ plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
+ plt.grid(b=True, which="both")
+ plt.ylim([0, 1.0])
+ plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
+ plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
+ plt.show()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ eval_file = args.eval_file
+ eval_file_path = os.path.join("evaluate_results", eval_file)
+ uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
+ "./Evaluation/data/activity_net_1_3_new.json",
+ eval_file_path,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation')
+
+ print("AR@1; AR@5; AR@10; AR@100")
+ print("%.02f %.02f %.02f %.02f" %
+ (100 * np.mean(uniform_recall_valid[:, 0]),
+ 100 * np.mean(uniform_recall_valid[:, 4]),
+ 100 * np.mean(uniform_recall_valid[:, 9]),
+ 100 * np.mean(uniform_recall_valid[:, -1])))
diff --git a/bmn/infer.list b/bmn/infer.list
new file mode 100644
index 0000000000000000000000000000000000000000..44768f089e70e40913d9787571ae0a7151232558
--- /dev/null
+++ b/bmn/infer.list
@@ -0,0 +1 @@
+{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
diff --git a/bmn/predict.py b/bmn/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52927b60562425a1f03cfea12ab6cb21e76b3ef
--- /dev/null
+++ b/bmn/predict.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import argparse
+import sys
+import os
+import logging
+import paddle.fluid as fluid
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_metric import BmnMetric
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("BMN inference.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode, only support dynamic mode at present time")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--device', type=str, default='GPU', help='default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/final",
+ help='weight path, None to automatically download weights provided by Paddle.'
+ )
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="predict_results/",
+ help='output dir path, default to use ./predict_results/')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Prediction
+def infer_bmn(args):
+ # only support dynamic mode at present time
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ config = parse_config(args.config_file)
+ infer_cfg = merge_configs(config, 'infer', vars(args))
+
+ if not os.path.isdir(config.INFER.output_path):
+ os.makedirs(config.INFER.output_path)
+ if not os.path.isdir(config.INFER.result_path):
+ os.makedirs(config.INFER.result_path)
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ labels = [Input([None, 1], 'int64', name='video_idx')]
+
+ #data
+ infer_dataset = BmnDataset(infer_cfg, 'infer')
+
+ model = BMN(config, args.dynamic)
+ model.prepare(
+ metrics=BmnMetric(
+ config, mode='infer'),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ # load checkpoint
+ if args.weights:
+ assert os.path.exists(
+ args.weights +
+ ".pdparams"), "Given weight dir {} not exist.".format(args.weights)
+ logger.info('load test weights from {}'.format(args.weights))
+ model.load(args.weights)
+
+ # here use model.eval instead of model.test, as post process is required in our case
+ model.evaluate(
+ eval_data=infer_dataset,
+ batch_size=infer_cfg.TEST.batch_size,
+ num_workers=infer_cfg.TEST.num_workers,
+ log_freq=args.log_interval)
+
+ logger.info("[INFER] infer finished")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ infer_bmn(args)
diff --git a/bmn/reader.py b/bmn/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c1da592e932b6ddbd19476e994f4267ae2f927
--- /dev/null
+++ b/bmn/reader.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import numpy as np
+import json
+import logging
+import os
+import sys
+
+sys.path.append('../')
+
+from distributed import DistributedBatchSampler
+from paddle.fluid.io import Dataset, DataLoader
+
+logger = logging.getLogger(__name__)
+
+from config_utils import *
+from bmn_utils import iou_with_anchors, ioa_with_anchors
+
+DATATYPE = "float32"
+
+
+class BmnDataset(Dataset):
+ def __init__(self, cfg, mode):
+ self.mode = mode
+ self.tscale = cfg.MODEL.tscale # 100
+ self.dscale = cfg.MODEL.dscale # 100
+ self.anno_file = cfg.MODEL.anno_file
+ self.feat_path = cfg.MODEL.feat_path
+ self.file_list = cfg.INFER.filelist
+ self.subset = cfg[mode.upper()]['subset']
+ self.tgap = 1. / self.tscale
+
+ self.get_dataset_dict()
+ self.get_match_map()
+
+ def __getitem__(self, index):
+ video_name = self.video_list[index]
+ video_idx = self.video_list.index(video_name)
+ video_feat = self.load_file(video_name)
+ if self.mode == 'infer':
+ return video_feat, video_idx
+ else:
+ gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
+ if self.mode == 'train' or self.mode == 'valid':
+ return video_feat, gt_iou_map, gt_start, gt_end
+ elif self.mode == 'test':
+ return video_feat, gt_iou_map, gt_start, gt_end, video_idx
+
+ def __len__(self):
+ return len(self.video_list)
+
+ def get_dataset_dict(self):
+ assert (
+ os.path.exists(self.feat_path)), "Input feature path not exists"
+ assert (os.listdir(self.feat_path)), "No feature file in feature path"
+ self.video_dict = {}
+ if self.mode == "infer":
+ annos = json.load(open(self.file_list))
+ for video_name in annos.keys():
+ self.video_dict[video_name] = annos[video_name]
+ else:
+ annos = json.load(open(self.anno_file))
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if self.subset in video_subset:
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
+ print("%s subset video numbers: %d" %
+ (self.subset, len(self.video_list)))
+ video_name_set = set(
+ [video_name + '.npy' for video_name in self.video_list])
+ assert (video_name_set.intersection(set(os.listdir(self.feat_path))) ==
+ video_name_set), "Input feature not exists in feature path"
+
+ def get_match_map(self):
+ match_map = []
+ for idx in range(self.tscale):
+ tmp_match_window = []
+ xmin = self.tgap * idx
+ for jdx in range(1, self.tscale + 1):
+ xmax = xmin + self.tgap * jdx
+ tmp_match_window.append([xmin, xmax])
+ match_map.append(tmp_match_window)
+ match_map = np.array(match_map)
+ match_map = np.transpose(match_map, [1, 0, 2])
+ match_map = np.reshape(match_map, [-1, 2])
+ self.match_map = match_map
+ self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
+ self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
+
+ def get_video_label(self, video_name):
+ video_info = self.video_dict[video_name]
+ video_second = video_info['duration_second']
+ video_labels = video_info['annotations']
+
+ gt_bbox = []
+ gt_iou_map = []
+ for gt in video_labels:
+ tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
+ tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
+ gt_bbox.append([tmp_start, tmp_end])
+ tmp_gt_iou_map = iou_with_anchors(
+ self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
+ tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
+ [self.dscale, self.tscale])
+ gt_iou_map.append(tmp_gt_iou_map)
+ gt_iou_map = np.array(gt_iou_map)
+ gt_iou_map = np.max(gt_iou_map, axis=0)
+
+ gt_bbox = np.array(gt_bbox)
+ gt_xmins = gt_bbox[:, 0]
+ gt_xmaxs = gt_bbox[:, 1]
+ gt_len_small = 3 * self.tgap
+ gt_start_bboxs = np.stack(
+ (gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
+ gt_end_bboxs = np.stack(
+ (gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
+
+ match_score_start = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_start.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
+ match_score_end = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_end.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
+
+ gt_start = np.array(match_score_start)
+ gt_end = np.array(match_score_end)
+ return gt_iou_map.astype(DATATYPE), gt_start.astype(
+ DATATYPE), gt_end.astype(DATATYPE)
+
+ def load_file(self, video_name):
+ file_name = video_name + ".npy"
+ file_path = os.path.join(self.feat_path, file_name)
+ video_feat = np.load(file_path)
+ video_feat = video_feat.T
+ video_feat = video_feat.astype("float32")
+ return video_feat
diff --git a/bmn/run.sh b/bmn/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..24fd8e3da991c74628f6d345badf7bfe2e67c35d
--- /dev/null
+++ b/bmn/run.sh
@@ -0,0 +1,3 @@
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+
+python -m paddle.distributed.launch train.py
diff --git a/bmn/train.py b/bmn/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe46f6a607c6ab8f93be45ffeee11478ef862eb6
--- /dev/null
+++ b/bmn/train.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+import argparse
+import logging
+import sys
+import os
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("Paddle high level api of BMN.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=None,
+ help='training batch size. None to use config file setting.')
+ parser.add_argument(
+ '--learning_rate',
+ type=float,
+ default=0.001,
+ help='learning rate use for training. None to use config file setting.')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='filename to resume training based on previous checkpoints. '
+ 'None for not resuming any checkpoints.')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='gpu',
+ help='gpu or cpu, default use gpu.')
+ parser.add_argument(
+ '--epoch',
+ type=int,
+ default=9,
+ help='epoch number, 0 for read from config file')
+ parser.add_argument(
+ '--valid_interval',
+ type=int,
+ default=1,
+ help='validation epoch interval, 0 for no validation.')
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="checkpoint",
+ help='path to save train snapshoot')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=10,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Optimizer
+def optimizer(cfg, parameter_list):
+ bd = [cfg.TRAIN.lr_decay_iter]
+ base_lr = cfg.TRAIN.learning_rate
+ lr_decay = cfg.TRAIN.learning_rate_decay
+ l2_weight_decay = cfg.TRAIN.l2_weight_decay
+ lr = [base_lr, base_lr * lr_decay]
+ optimizer = fluid.optimizer.Adam(
+ fluid.layers.piecewise_decay(
+ boundaries=bd, values=lr),
+ parameter_list=parameter_list,
+ regularization=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=l2_weight_decay))
+ return optimizer
+
+
+# TRAIN
+def train_bmn(args):
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ if not os.path.isdir(args.save_dir):
+ os.makedirs(args.save_dir)
+
+ config = parse_config(args.config_file)
+ train_cfg = merge_configs(config, 'train', vars(args))
+ val_cfg = merge_configs(config, 'valid', vars(args))
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ gt_iou_map = Input(
+ [None, config.MODEL.dscale, config.MODEL.tscale],
+ 'float32',
+ name='gt_iou_map')
+ gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
+ gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
+ labels = [gt_iou_map, gt_start, gt_end]
+
+ # data
+ train_dataset = BmnDataset(train_cfg, 'train')
+ val_dataset = BmnDataset(val_cfg, 'valid')
+
+ # model
+ model = BMN(config, args.dynamic)
+ optim = optimizer(config, parameter_list=model.parameters())
+ model.prepare(
+ optimizer=optim,
+ loss_function=BmnLoss(config),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ # if resume weights is given, load resume weights directly
+ if args.resume is not None:
+ model.load(args.resume)
+
+ model.fit(train_data=train_dataset,
+ eval_data=val_dataset,
+ batch_size=train_cfg.TRAIN.batch_size,
+ epochs=args.epoch,
+ eval_freq=args.valid_interval,
+ log_freq=args.log_interval,
+ save_dir=args.save_dir,
+ shuffle=train_cfg.TRAIN.use_shuffle,
+ num_workers=train_cfg.TRAIN.num_workers,
+ drop_last=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ train_bmn(args)
diff --git a/model.py b/model.py
index a52cb8299b64af5ef55be34e183362a93734e55e..da5cebd669b83868a46dc1a6b03ad9df8f3b238a 100644
--- a/model.py
+++ b/model.py
@@ -410,7 +410,8 @@ class StaticGraphAdapter(object):
and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
- self.model._optimizer._learning_rate_map[prog] = lr_var
+ new_lr_var = prog.global_block().vars[lr_var.name]
+ self.model._optimizer._learning_rate_map[prog] = new_lr_var
losses = []
metrics = []