diff --git a/dygraph/bmn/README.md b/dygraph/bmn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1408cb5dd746b1308cb09ac703df4ff0e04ef46b
--- /dev/null
+++ b/dygraph/bmn/README.md
@@ -0,0 +1,129 @@
+# BMN Video Action Localization Model (Dygraph Implementation)
+
+---
+## Contents
+
+- [Introduction](#introduction)
+- [Code Structure](#code-structure)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Inference](#inference)
+- [Reference](#reference)
+
+
+## Introduction
+
+BMN, developed by Baidu, is the winning solution of the 2019 ActivityNet challenge. It provides an efficient solution for proposal generation in video action localization and is open-sourced on PaddlePaddle for the first time. The model introduces the Boundary-Matching (BM) mechanism to evaluate proposal confidence: all possible proposals are arranged into a two-dimensional BM confidence map indexed by the position of the proposal's start boundary and its duration, where the value at each point is the confidence score of the corresponding proposal. The network consists of three modules: the base module serves as the backbone and processes the input feature sequence, the TEM module predicts the probability of each temporal position being an action start or an action end, and the PEM module generates the BM confidence map.
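+
+With `tscale` temporal positions, the BM confidence map is a tscale x tscale grid indexed by (duration, start): entry (d, s) corresponds to the proposal that starts at position s and spans d + 1 positions, and entries with s + d >= tscale fall outside the video and are masked out. A minimal sketch of this indexing, using numpy only:
+
+```python
+import numpy as np
+
+tscale = 100  # number of temporal positions (tscale in bmn.yaml)
+# stand-in for a predicted BM confidence map, laid out as [duration, start]
+bm_map = np.random.rand(tscale, tscale)
+
+duration, start = 20, 30
+assert start + duration < tscale  # the proposal must end inside the video
+score = bm_map[duration, start]   # confidence of the proposal covering
+                                  # positions 30 through 50 inclusive
+```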
+
+
+
+BMN Overview
+
+
+For documentation on the dygraph (dynamic graph) mode, please refer to [Dygraph](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/user_guides/howto/dygraph/DyGraph.html).
+
+
+## Code Structure
+```
+├── bmn.yaml           # network configuration file; parameters can be configured easily
+├── run.sh             # quick-start script that directly launches multi-GPU training
+├── train.py           # training code
+├── eval.py            # evaluation code that measures network performance
+├── predict.py         # prediction code that runs inference on arbitrary input
+├── model.py           # definition of the network structure and loss functions
+├── reader.py          # data reader
+├── eval_anet_prop.py  # computation of accuracy metrics
+├── bmn_utils.py       # model utility functions
+├── config_utils.py    # configuration utility functions
+└── infer.list         # file list for inference
+```
+
+
+## Data Preparation
+
+BMN is trained on the ActivityNet 1.3 dataset. We provide pre-processed video features: please download the [bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz) data, extract it, and update the feature path feat\_path in bmn.yaml accordingly.
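+
+Each video corresponds to one .npy file under `feat_path`. A quick sanity check (illustrative; the raw on-disk shape (tscale, feat\_dim) = (100, 400) is inferred from the transpose in reader.py's load\_file):
+
+```python
+import os
+import numpy as np
+
+feat_path = './fix_feat_100'   # feat_path from bmn.yaml
+video_name = 'v_4Lu8ECLHvK4'   # example video id taken from infer.list
+feat = np.load(os.path.join(feat_path, video_name + '.npy'))
+print(feat.shape)              # expected: (100, 400), i.e. (tscale, feat_dim)
+```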
+
+
+## Training
+
+After the data is prepared, training can be started in either of the following two ways:
+
+Training uses 4 GPUs by default and is launched as follows:
+
+ bash run.sh
+
+For single-GPU training, launch as follows:
+
+ export CUDA_VISIBLE_DEVICES=0
+ python train.py
+
+- Running the code requires pandas to be installed first.
+
+- Training starts from scratch with the command line or script above; no pre-trained model is required.
+
+**Training strategy** (a sketch of the schedule is shown below):
+
+* The Adam optimizer is used with an initial learning\_rate of 0.001.
+* The L2 weight decay coefficient is 1e-4.
+* The learning rate is decayed once, by a factor of 0.1, when the number of iterations reaches 4200.
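+
+This schedule matches the optimizer construction in train.py; a condensed sketch:
+
+```python
+import paddle.fluid as fluid
+
+def build_optimizer(parameter_list):
+    # one-step piecewise decay: lr = 0.001 before iteration 4200, 0.0001 after
+    lr = fluid.layers.piecewise_decay(boundaries=[4200], values=[0.001, 0.0001])
+    return fluid.optimizer.Adam(
+        lr,
+        parameter_list=parameter_list,
+        regularization=fluid.regularizer.L2DecayRegularizer(
+            regularization_coeff=1e-4))
+```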
+
+- The table below lists the approximate training time of this model (in minutes), measured on P40 GPUs with CUDA 8.0 and cuDNN 7.2.
+
+| | single GPU | 4 GPUs |
+| :---: | :---: | :---: |
+| static graph | 79 | 27 |
+| dygraph | 98 | 31 |
+
+## Evaluation
+
+After training, the model can be evaluated as follows:
+
+ python eval.py --weights=$PATH_TO_WEIGHTS
+
+- The `weights` argument specifies the weights to evaluate; if it is not set, the default parameter file checkpoint/bmn\_paddle\_dy\_final.pdparams is used.
+
+- The program saves intermediate results in the output/EVAL/BMN\_results folder and the evaluation results in the evaluate\_results/bmn\_results\_validation.json file.
+
+- To evaluate on CPU, set `use_gpu` to False in the command line above.
+
+- Note: the loss may become nan during evaluation. This is because evaluation runs on individual samples, and a sample may contain no entry with iou > 0.6, which yields nan; it does not affect the final evaluation results.
+
+
+AR@AN and AUC can be computed with the official evaluation script provided by ActivityNet. The procedure is as follows:
+
+- For details on using the ActivityNet dataset, please refer to its [official website](http://activity-net.org).
+
+- Download the evaluation code from the [ActivityNet GitHub repository](https://github.com/activitynet/ActivityNet.git) and copy its Evaluation folder into the models/dygraph/bmn directory. (Note: under Python 3 the print statements need parentheses; please modify the .py files in the Evaluation directory accordingly.)
+
+- Please download [activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json) and place it in the models/dygraph/bmn directory. Compared with the original activity\_net.v1-3.min.json file, we filtered out some invalid video entries.
+
+- Compute the accuracy metrics (the results file can be specified with the `--eval_file` argument, which defaults to bmn\_results\_validation.json):
+
+        python eval_anet_prop.py
+
+
+Evaluation accuracy on the ActivityNet 1.3 dataset:
+
+| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
+| :---: | :---: | :---: | :---: | :---: |
+| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
+
+
+## Inference
+
+Model inference can be started as follows:
+
+ python predict.py --weights=$PATH_TO_WEIGHTS \
+ --filelist=$FILELIST
+
+- When launching from the python command line, the `--filelist` argument specifies the list of files to run inference on, defaulting to ./infer.list. The `--weights` argument is the path to the trained weights; if it is not set, the default parameter file checkpoint/bmn\_paddle\_dy\_final.pdparams is used.
+
+- The program saves intermediate results in the output/INFER/BMN\_results folder and the prediction results in the predict\_results/bmn\_results\_test.json file.
+
+- To run inference on CPU, set `use_gpu` to False in the command line.
+
+
+## Reference
+
+- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
diff --git a/dygraph/bmn/bmn.yaml b/dygraph/bmn/bmn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e16fa53bc01c4a94fcfb9781091443fa4e55775c
--- /dev/null
+++ b/dygraph/bmn/bmn.yaml
@@ -0,0 +1,50 @@
+MODEL:
+ name: "BMN"
+ tscale: 100
+ dscale: 100
+ feat_dim: 400
+ prop_boundary_ratio: 0.5
+ num_sample: 32
+ num_sample_perbin: 3
+ anno_file: "../../PaddleCV/PaddleVideo/data/dataset/bmn/activitynet_1.3_annotations.json"
+ feat_path: './fix_feat_100'
+
+TRAIN:
+ subset: "train"
+ epoch: 9
+ batch_size: 16
+ num_threads: 8
+ use_gpu: True
+ num_gpus: 4
+ learning_rate: 0.001
+ learning_rate_decay: 0.1
+ lr_decay_iter: 4200
+ l2_weight_decay: 1e-4
+
+VALID:
+ subset: "validation"
+ batch_size: 16
+ num_threads: 8
+ use_gpu: True
+ num_gpus: 4
+
+TEST:
+ subset: "validation"
+ batch_size: 1
+ num_threads: 1
+ snms_alpha: 0.001
+ snms_t1: 0.5
+ snms_t2: 0.9
+ output_path: "output/EVAL/BMN_results"
+ result_path: "evaluate_results"
+
+INFER:
+ subset: "test"
+ batch_size: 1
+ num_threads: 1
+ snms_alpha: 0.4
+ snms_t1: 0.5
+ snms_t2: 0.9
+ filelist: './infer.list'
+ output_path: "output/INFER/BMN_results"
+ result_path: "predict_results"
diff --git a/dygraph/bmn/bmn_utils.py b/dygraph/bmn/bmn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a960841e6c97d3484f870be260e28ca123e566
--- /dev/null
+++ b/dygraph/bmn/bmn_utils.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import numpy as np
+import pandas as pd
+import multiprocessing as mp
+import json
+import os
+import math
+
+
+def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute jaccard score between a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ union_len = len_anchors - inter_len + box_max - box_min
+ jaccard = np.divide(inter_len, union_len)
+ return jaccard
+
+
+def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute intersection between score a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ scores = np.divide(inter_len, len_anchors)
+ return scores
+
+
+def boundary_choose(score_list):
+ max_score = max(score_list)
+ mask_high = (score_list > max_score * 0.5)
+ score_list = list(score_list)
+ score_middle = np.array([0.0] + score_list + [0.0])
+ score_front = np.array([0.0, 0.0] + score_list)
+ score_back = np.array(score_list + [0.0, 0.0])
+ mask_peak = ((score_middle > score_front) & (score_middle > score_back))
+ mask_peak = mask_peak[1:-1]
+ mask = (mask_high | mask_peak).astype('float32')
+ return mask
+
+
+def soft_nms(df, alpha, t1, t2):
+ '''
+ df: proposals generated by network;
+ alpha: alpha value of Gaussian decaying function;
+ t1, t2: threshold for soft nms.
+ '''
+ df = df.sort_values(by="score", ascending=False)
+ tstart = list(df.xmin.values[:])
+ tend = list(df.xmax.values[:])
+ tscore = list(df.score.values[:])
+
+ rstart = []
+ rend = []
+ rscore = []
+
+ while len(tscore) > 1 and len(rscore) < 101:
+ max_index = tscore.index(max(tscore))
+ tmp_iou_list = iou_with_anchors(
+ np.array(tstart),
+ np.array(tend), tstart[max_index], tend[max_index])
+ for idx in range(0, len(tscore)):
+ if idx != max_index:
+ tmp_iou = tmp_iou_list[idx]
+ tmp_width = tend[max_index] - tstart[max_index]
+ if tmp_iou > t1 + (t2 - t1) * tmp_width:
+ tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
+ alpha)
+
+ rstart.append(tstart[max_index])
+ rend.append(tend[max_index])
+ rscore.append(tscore[max_index])
+ tstart.pop(max_index)
+ tend.pop(max_index)
+ tscore.pop(max_index)
+
+ newDf = pd.DataFrame()
+ newDf['score'] = rscore
+ newDf['xmin'] = rstart
+ newDf['xmax'] = rend
+ return newDf
+
+
+def video_process(video_list,
+ video_dict,
+ output_path,
+ result_dict,
+ snms_alpha=0.4,
+ snms_t1=0.55,
+ snms_t2=0.9):
+
+ for video_name in video_list:
+ print("Processing video........" + video_name)
+ df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
+ if len(df) > 1:
+ df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
+
+ video_duration = video_dict[video_name]["duration_second"]
+ proposal_list = []
+ for idx in range(min(100, len(df))):
+ tmp_prop={"score":df.score.values[idx], \
+ "segment":[max(0,df.xmin.values[idx])*video_duration, \
+ min(1,df.xmax.values[idx])*video_duration]}
+ proposal_list.append(tmp_prop)
+ result_dict[video_name[2:]] = proposal_list
+
+
+def bmn_post_processing(video_dict, subset, output_path, result_path):
+    video_list = list(video_dict.keys())
+ global result_dict
+ result_dict = mp.Manager().dict()
+ pp_num = 12
+
+ num_videos = len(video_list)
+ num_videos_per_thread = int(num_videos / pp_num)
+ processes = []
+ for tid in range(pp_num - 1):
+ tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
+ num_videos_per_thread]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ for p in processes:
+ p.join()
+
+ result_dict = dict(result_dict)
+ output_dict = {
+ "version": "VERSION 1.3",
+ "results": result_dict,
+ "external_data": {}
+ }
+ outfile = open(
+ os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
+
+ json.dump(output_dict, outfile)
+ outfile.close()
+
+
+def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for a boundary-matching pair """
+ plen = float(seg_xmax - seg_xmin)
+ plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
+ total_samples = [
+ seg_xmin + plen_sample * ii
+ for ii in range(num_sample * num_sample_perbin)
+ ]
+ p_mask = []
+ for idx in range(num_sample):
+ bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
+ num_sample_perbin]
+ bin_vector = np.zeros([tscale])
+ for sample in bin_samples:
+ sample_upper = math.ceil(sample)
+ sample_decimal, sample_down = math.modf(sample)
+ if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
+ bin_vector[int(sample_down)] += 1 - sample_decimal
+ if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
+ bin_vector[int(sample_upper)] += sample_decimal
+ bin_vector = 1.0 / num_sample_perbin * bin_vector
+ p_mask.append(bin_vector)
+ p_mask = np.stack(p_mask, axis=1)
+ return p_mask
+
+
+def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for each point in Boundary-Matching Map """
+ mask_mat = []
+ for start_index in range(tscale):
+ mask_mat_vector = []
+ for duration_index in range(dscale):
+ if start_index + duration_index < tscale:
+ p_xmin = start_index
+ p_xmax = start_index + duration_index
+ center_len = float(p_xmax - p_xmin) + 1
+ sample_xmin = p_xmin - center_len * prop_boundary_ratio
+ sample_xmax = p_xmax + center_len * prop_boundary_ratio
+ p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
+ tscale, num_sample,
+ num_sample_perbin)
+ else:
+ p_mask = np.zeros([tscale, num_sample])
+ mask_mat_vector.append(p_mask)
+ mask_mat_vector = np.stack(mask_mat_vector, axis=2)
+ mask_mat.append(mask_mat_vector)
+ mask_mat = np.stack(mask_mat, axis=3)
+ mask_mat = mask_mat.astype(np.float32)
+
+ sample_mask = np.reshape(mask_mat, [tscale, -1])
+ return sample_mask
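+
+
+if __name__ == "__main__":
+    # Minimal sanity checks (illustrative only). With the defaults from
+    # bmn.yaml (tscale=100, dscale=100, prop_boundary_ratio=0.5,
+    # num_sample=32, num_sample_perbin=3), the flattened sample mask has
+    # shape (tscale, num_sample * dscale * tscale) = (100, 320000).
+    mask = get_interp1d_mask(100, 100, 0.5, 32, 3)
+    print(mask.shape)
+
+    # soft_nms expects a DataFrame with "xmin", "xmax" and "score" columns,
+    # with coordinates normalized to [0, 1] as elsewhere in this repo.
+    df = pd.DataFrame({"xmin": [0.10, 0.12, 0.60],
+                       "xmax": [0.50, 0.55, 0.90],
+                       "score": [0.90, 0.80, 0.70]})
+    print(soft_nms(df, alpha=0.4, t1=0.55, t2=0.9))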
diff --git a/dygraph/bmn/config_utils.py b/dygraph/bmn/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb59a25c316d3e435a9f74c453b61e0362f85d38
--- /dev/null
+++ b/dygraph/bmn/config_utils.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import yaml
+import logging
+
+logger = logging.getLogger(__name__)
+
+CONFIG_SECS = [
+ 'train',
+ 'valid',
+ 'test',
+ 'infer',
+]
+
+
+class AttrDict(dict):
+ def __getattr__(self, key):
+ return self[key]
+
+ def __setattr__(self, key, value):
+ if key in self.__dict__:
+ self.__dict__[key] = value
+ else:
+ self[key] = value
+
+
+def parse_config(cfg_file):
+ """Load a config file into AttrDict"""
+ with open(cfg_file, 'r') as fopen:
+ yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
+ create_attr_dict(yaml_config)
+ return yaml_config
+
+
+def create_attr_dict(yaml_config):
+ from ast import literal_eval
+ for key, value in yaml_config.items():
+ if type(value) is dict:
+ yaml_config[key] = value = AttrDict(value)
+ if isinstance(value, str):
+ try:
+ value = literal_eval(value)
+ except BaseException:
+ pass
+ if isinstance(value, AttrDict):
+ create_attr_dict(yaml_config[key])
+ else:
+ yaml_config[key] = value
+ return
+
+
+def merge_configs(cfg, sec, args_dict):
+ assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
+ sec_dict = getattr(cfg, sec.upper())
+ for k, v in args_dict.items():
+ if v is None:
+ continue
+ try:
+ if hasattr(sec_dict, k):
+ setattr(sec_dict, k, v)
+        except Exception:
+ pass
+ return cfg
+
+
+def print_configs(cfg, mode):
+ logger.info("---------------- {:>5} Arguments ----------------".format(
+ mode))
+ for sec, sec_items in cfg.items():
+ logger.info("{}:".format(sec))
+ for k, v in sec_items.items():
+ logger.info(" {}:{}".format(k, v))
+ logger.info("-------------------------------------------------")
diff --git a/dygraph/bmn/eval.py b/dygraph/bmn/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..f96780691a0575d7d668f7114655309552854ffc
--- /dev/null
+++ b/dygraph/bmn/eval.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import numpy as np
+import argparse
+import pandas as pd
+import os
+import sys
+import ast
+import json
+import logging
+
+from reader import BMNReader
+from model import BMN, bmn_loss_func
+from bmn_utils import boundary_choose, bmn_post_processing
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("BMN test for performance evaluation.")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=None,
+ help='training batch size. None to use config file setting.')
+ parser.add_argument(
+ '--use_gpu',
+ type=ast.literal_eval,
+ default=True,
+ help='default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/bmn_paddle_dy_final",
+        help='weight path of the trained model.'
+ )
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="evaluate_results/",
+ help='output dir path, default to use ./evaluate_results/')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+def get_dataset_dict(cfg):
+ anno_file = cfg.MODEL.anno_file
+ annos = json.load(open(anno_file))
+ subset = cfg.TEST.subset
+ video_dict = {}
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if subset in video_subset:
+ video_dict[video_name] = annos[video_name]
+ video_list = list(video_dict.keys())
+ video_list.sort()
+ return video_dict, video_list
+
+
+def gen_props(pred_bm, pred_start, pred_end, fid, video_list, cfg, mode='test'):
+ if mode == 'infer':
+ output_path = cfg.INFER.output_path
+ else:
+ output_path = cfg.TEST.output_path
+ tscale = cfg.MODEL.tscale
+ dscale = cfg.MODEL.dscale
+ snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
+ snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
+ cols = ["xmin", "xmax", "score"]
+
+ video_name = video_list[fid]
+ pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
+ start_mask = boundary_choose(pred_start)
+ start_mask[0] = 1.
+ end_mask = boundary_choose(pred_end)
+ end_mask[-1] = 1.
+ score_vector_list = []
+ for idx in range(dscale):
+ for jdx in range(tscale):
+ start_index = jdx
+ end_index = start_index + idx
+ if end_index < tscale and start_mask[start_index] == 1 and end_mask[
+ end_index] == 1:
+ xmin = snippet_xmins[start_index]
+ xmax = snippet_xmaxs[end_index]
+ xmin_score = pred_start[start_index]
+ xmax_score = pred_end[end_index]
+ bm_score = pred_bm[idx, jdx]
+ conf_score = xmin_score * xmax_score * bm_score
+ score_vector_list.append([xmin, xmax, conf_score])
+
+ score_vector_list = np.stack(score_vector_list)
+ video_df = pd.DataFrame(score_vector_list, columns=cols)
+ video_df.to_csv(
+ os.path.join(output_path, "%s.csv" % video_name), index=False)
+
+
+# Performance Evaluation
+def test_bmn(args):
+ config = parse_config(args.config_file)
+ test_config = merge_configs(config, 'test', vars(args))
+ print_configs(test_config, "Test")
+
+ if not os.path.isdir(test_config.TEST.output_path):
+ os.makedirs(test_config.TEST.output_path)
+ if not os.path.isdir(test_config.TEST.result_path):
+ os.makedirs(test_config.TEST.result_path)
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ with fluid.dygraph.guard(place):
+ bmn = BMN(test_config)
+
+ # load checkpoint
+ if args.weights:
+ assert os.path.exists(args.weights + '.pdparams'
+ ), "Given weight dir {} not exist.".format(
+ args.weights)
+
+ logger.info('load test weights from {}'.format(args.weights))
+ model_dict, _ = fluid.load_dygraph(args.weights)
+ bmn.set_dict(model_dict)
+
+ reader = BMNReader(mode="test", cfg=test_config)
+ test_reader = reader.create_reader()
+
+ aggr_loss = 0.0
+ aggr_tem_loss = 0.0
+ aggr_pem_reg_loss = 0.0
+ aggr_pem_cls_loss = 0.0
+ aggr_batch_size = 0
+ video_dict, video_list = get_dataset_dict(test_config)
+
+ bmn.eval()
+ for batch_id, data in enumerate(test_reader()):
+ video_feat = np.array([item[0] for item in data]).astype(DATATYPE)
+ gt_iou_map = np.array([item[1] for item in data]).astype(DATATYPE)
+ gt_start = np.array([item[2] for item in data]).astype(DATATYPE)
+ gt_end = np.array([item[3] for item in data]).astype(DATATYPE)
+ video_idx = [item[4] for item in data][0] #batch_size=1 by default
+
+ x_data = fluid.dygraph.base.to_variable(video_feat)
+ gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map)
+ gt_start = fluid.dygraph.base.to_variable(gt_start)
+ gt_end = fluid.dygraph.base.to_variable(gt_end)
+ gt_iou_map.stop_gradient = True
+ gt_start.stop_gradient = True
+ gt_end.stop_gradient = True
+
+ pred_bm, pred_start, pred_end = bmn(x_data)
+ loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func(
+ pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end,
+ test_config)
+
+ pred_bm = pred_bm.numpy()
+ pred_start = pred_start[0].numpy()
+ pred_end = pred_end[0].numpy()
+ aggr_loss += np.mean(loss.numpy())
+ aggr_tem_loss += np.mean(tem_loss.numpy())
+ aggr_pem_reg_loss += np.mean(pem_reg_loss.numpy())
+ aggr_pem_cls_loss += np.mean(pem_cls_loss.numpy())
+ aggr_batch_size += 1
+
+ logger.info("Processing................ batch {}".format(batch_id))
+ gen_props(
+ pred_bm,
+ pred_start,
+ pred_end,
+ video_idx,
+ video_list,
+ test_config,
+ mode='test')
+
+ avg_loss = aggr_loss / aggr_batch_size
+ avg_tem_loss = aggr_tem_loss / aggr_batch_size
+ avg_pem_reg_loss = aggr_pem_reg_loss / aggr_batch_size
+ avg_pem_cls_loss = aggr_pem_cls_loss / aggr_batch_size
+
+    logger.info('[EVAL] \tAvg_loss = {}, \tAvg_tem_loss = {}, \tAvg_pem_reg_loss = {}, \tAvg_pem_cls_loss = {}'.format(
+ '%.04f' % avg_loss, '%.04f' % avg_tem_loss, \
+ '%.04f' % avg_pem_reg_loss, '%.04f' % avg_pem_cls_loss))
+
+ logger.info("Post_processing....This may take a while")
+ bmn_post_processing(video_dict, test_config.TEST.subset,
+ test_config.TEST.output_path,
+ test_config.TEST.result_path)
+ logger.info("[EVAL] eval finished")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ test_bmn(args)
diff --git a/dygraph/bmn/eval_anet_prop.py b/dygraph/bmn/eval_anet_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a91c6effff8083ad54180c0931585731f2fa6d8
--- /dev/null
+++ b/dygraph/bmn/eval_anet_prop.py
@@ -0,0 +1,110 @@
+'''
+Calculate AR@N and AUC;
+Modified from the ActivityNet GitHub repository (https://github.com/activitynet/ActivityNet.git)
+'''
+
+import sys
+sys.path.append('./Evaluation')
+
+from eval_proposal import ANETproposal
+import numpy as np
+import argparse
+import os
+
+parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
+parser.add_argument(
+ '--eval_file',
+ type=str,
+ default='bmn_results_validation.json',
+ help='name of results file to eval')
+
+
+def run_evaluation(ground_truth_filename,
+ proposal_filename,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation'):
+
+ anet_proposal = ANETproposal(
+ ground_truth_filename,
+ proposal_filename,
+ tiou_thresholds=tiou_thresholds,
+ max_avg_nr_proposals=max_avg_nr_proposals,
+ subset=subset,
+ verbose=True,
+ check_status=False)
+ anet_proposal.evaluate()
+ recall = anet_proposal.recall
+ average_recall = anet_proposal.avg_recall
+ average_nr_proposals = anet_proposal.proposals_per_video
+
+ return (average_nr_proposals, average_recall, recall)
+
+
+def plot_metric(average_nr_proposals,
+ average_recall,
+ recall,
+                tiou_thresholds=np.linspace(0.5, 0.95, 10)):
+    # matplotlib is only needed for this optional plotting helper
+    import matplotlib.pyplot as plt
+ fn_size = 14
+ plt.figure(num=None, figsize=(12, 8))
+ ax = plt.subplot(1, 1, 1)
+
+ colors = [
+ 'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
+ ]
+ area_under_curve = np.zeros_like(tiou_thresholds)
+ for i in range(recall.shape[0]):
+ area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
+
+ for idx, tiou in enumerate(tiou_thresholds[::2]):
+ ax.plot(
+ average_nr_proposals,
+ recall[2 * idx, :],
+ color=colors[idx + 1],
+ label="tiou=[" + str(tiou) + "], area=" + str(
+ int(area_under_curve[2 * idx] * 100) / 100.),
+ linewidth=4,
+ linestyle='--',
+ marker=None)
+
+ # Plots Average Recall vs Average number of proposals.
+ ax.plot(
+ average_nr_proposals,
+ average_recall,
+ color=colors[0],
+ label="tiou = 0.5:0.05:0.95," + " area=" + str(
+ int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
+ linewidth=4,
+ linestyle='-',
+ marker=None)
+
+ handles, labels = ax.get_legend_handles_labels()
+ ax.legend(
+ [handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
+
+ plt.ylabel('Average Recall', fontsize=fn_size)
+ plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
+ plt.grid(b=True, which="both")
+ plt.ylim([0, 1.0])
+    plt.setp(ax.get_xticklabels(), fontsize=fn_size)
+    plt.setp(ax.get_yticklabels(), fontsize=fn_size)
+ plt.show()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ eval_file = args.eval_file
+ eval_file_path = os.path.join("evaluate_results", eval_file)
+ uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
+ "./Evaluation/data/activity_net_1_3_new.json",
+ eval_file_path,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation')
+
+ print("AR@1; AR@5; AR@10; AR@100")
+ print("%.02f %.02f %.02f %.02f" %
+ (100 * np.mean(uniform_recall_valid[:, 0]),
+ 100 * np.mean(uniform_recall_valid[:, 4]),
+ 100 * np.mean(uniform_recall_valid[:, 9]),
+ 100 * np.mean(uniform_recall_valid[:, -1])))
diff --git a/dygraph/bmn/infer.list b/dygraph/bmn/infer.list
new file mode 100644
index 0000000000000000000000000000000000000000..44768f089e70e40913d9787571ae0a7151232558
--- /dev/null
+++ b/dygraph/bmn/infer.list
@@ -0,0 +1 @@
+{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
diff --git a/dygraph/bmn/model.py b/dygraph/bmn/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77e8e0e95bc4e0d397ba7247327c5bf7038c8e4
--- /dev/null
+++ b/dygraph/bmn/model.py
@@ -0,0 +1,339 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+import numpy as np
+import math
+
+from bmn_utils import get_interp1d_mask
+
+DATATYPE = 'float32'
+
+
+# Net
+class Conv1D(fluid.dygraph.Layer):
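+    """Temporal (1D) convolution implemented with Conv2D: the (N, C, T)
+    input is unsqueezed to (N, C, 1, T), convolved with a (1, size_k)
+    kernel, then squeezed back to (N, C, T)."""
+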
+ def __init__(self,
+ prefix,
+ num_channels=256,
+ num_filters=256,
+ size_k=3,
+ padding=1,
+ groups=1,
+ act="relu"):
+ super(Conv1D, self).__init__()
+ fan_in = num_channels * size_k * 1
+ k = 1. / math.sqrt(fan_in)
+ param_attr = ParamAttr(
+ name=prefix + "_w",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+ bias_attr = ParamAttr(
+ name=prefix + "_b",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+
+ self._conv2d = fluid.dygraph.Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=(1, size_k),
+ stride=1,
+ padding=(0, padding),
+ groups=groups,
+ act=act,
+ param_attr=param_attr,
+ bias_attr=bias_attr)
+
+ def forward(self, x):
+ x = fluid.layers.unsqueeze(input=x, axes=[2])
+ x = self._conv2d(x)
+ x = fluid.layers.squeeze(input=x, axes=[2])
+ return x
+
+
+class BMN(fluid.dygraph.Layer):
+ def __init__(self, cfg):
+ super(BMN, self).__init__()
+
+ #init config
+ self.tscale = cfg.MODEL.tscale
+ self.dscale = cfg.MODEL.dscale
+ self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
+ self.num_sample = cfg.MODEL.num_sample
+ self.num_sample_perbin = cfg.MODEL.num_sample_perbin
+
+ self.hidden_dim_1d = 256
+ self.hidden_dim_2d = 128
+ self.hidden_dim_3d = 512
+
+ # Base Module
+ self.b_conv1 = Conv1D(
+ prefix="Base_1",
+ num_channels=400,
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.b_conv2 = Conv1D(
+ prefix="Base_2",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+
+ # Temporal Evaluation Module
+ self.ts_conv1 = Conv1D(
+ prefix="TEM_s1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.ts_conv2 = Conv1D(
+ prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+ self.te_conv1 = Conv1D(
+ prefix="TEM_e1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.te_conv2 = Conv1D(
+ prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+
+ #Proposal Evaluation Module
+ self.p_conv1 = Conv1D(
+ prefix="PEM_1d",
+ num_filters=self.hidden_dim_2d,
+ size_k=3,
+ padding=1,
+ act="relu")
+
+ # init to speed up
+ sample_mask = get_interp1d_mask(self.tscale, self.dscale,
+ self.prop_boundary_ratio,
+ self.num_sample, self.num_sample_perbin)
+ self.sample_mask = fluid.dygraph.base.to_variable(sample_mask)
+ self.sample_mask.stop_gradient = True
+
+ self.p_conv3d1 = fluid.dygraph.Conv3D(
+ num_channels=128,
+ num_filters=self.hidden_dim_3d,
+ filter_size=(self.num_sample, 1, 1),
+ stride=(self.num_sample, 1, 1),
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_3d1_w"),
+ bias_attr=ParamAttr(name="PEM_3d1_b"))
+
+ self.p_conv2d1 = fluid.dygraph.Conv2D(
+ num_channels=512,
+ num_filters=self.hidden_dim_2d,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d1_w"),
+ bias_attr=ParamAttr(name="PEM_2d1_b"))
+ self.p_conv2d2 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d2_w"),
+ bias_attr=ParamAttr(name="PEM_2d2_b"))
+ self.p_conv2d3 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d3_w"),
+ bias_attr=ParamAttr(name="PEM_2d3_b"))
+ self.p_conv2d4 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=2,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="sigmoid",
+ param_attr=ParamAttr(name="PEM_2d4_w"),
+ bias_attr=ParamAttr(name="PEM_2d4_b"))
+
+ def forward(self, x):
+ #Base Module
+ x = self.b_conv1(x)
+ x = self.b_conv2(x)
+
+ #TEM
+ xs = self.ts_conv1(x)
+ xs = self.ts_conv2(xs)
+ xs = fluid.layers.squeeze(xs, axes=[1])
+ xe = self.te_conv1(x)
+ xe = self.te_conv2(xe)
+ xe = fluid.layers.squeeze(xe, axes=[1])
+
+ #PEM
+ xp = self.p_conv1(x)
+ #BM layer
+ xp = fluid.layers.matmul(xp, self.sample_mask)
+ xp = fluid.layers.reshape(
+ xp, shape=[0, 0, -1, self.dscale, self.tscale])
+
+ xp = self.p_conv3d1(xp)
+ xp = fluid.layers.squeeze(xp, axes=[2])
+ xp = self.p_conv2d1(xp)
+ xp = self.p_conv2d2(xp)
+ xp = self.p_conv2d3(xp)
+ xp = self.p_conv2d4(xp)
+ return xp, xs, xe
+
+
+def bmn_loss_func(pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end,
+ cfg):
+ def _get_mask(cfg):
+ dscale = cfg.MODEL.dscale
+ tscale = cfg.MODEL.tscale
+ bm_mask = []
+ for idx in range(dscale):
+ mask_vector = [1 for i in range(tscale - idx)
+ ] + [0 for i in range(idx)]
+ bm_mask.append(mask_vector)
+ bm_mask = np.array(bm_mask, dtype=np.float32)
+ self_bm_mask = fluid.layers.create_global_var(
+ shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
+ fluid.layers.assign(bm_mask, self_bm_mask)
+ self_bm_mask.stop_gradient = True
+ return self_bm_mask
+
+ def tem_loss_func(pred_start, pred_end, gt_start, gt_end):
+ def bi_loss(pred_score, gt_label):
+ pred_score = fluid.layers.reshape(
+ x=pred_score, shape=[-1], inplace=False)
+ gt_label = fluid.layers.reshape(
+ x=gt_label, shape=[-1], inplace=False)
+ gt_label.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
+ num_entries = fluid.layers.cast(
+ fluid.layers.shape(pmask), dtype=DATATYPE)
+ num_positive = fluid.layers.cast(
+ fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
+ ratio = num_entries / num_positive
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
+ loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
+ loss = -1 * (loss_pos + loss_neg)
+ return loss
+
+ loss_start = bi_loss(pred_start, gt_start)
+ loss_end = bi_loss(pred_end, gt_end)
+ loss = loss_start + loss_end
+ return loss
+
+ def pem_reg_loss_func(pred_score, gt_iou_map, mask):
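+        # Balance regression samples by IoU level: keep every high-IoU
+        # (> 0.7) point, and randomly subsample the medium ((0.3, 0.7]) and
+        # low ([0, 0.3]) IoU points so that each group contributes roughly
+        # num_h points to the weighted MSE below.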
+
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+
+ u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
+ u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
+ u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
+ u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
+
+ num_h = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
+ num_m = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
+ num_l = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
+
+ r_m = num_h / num_m
+ u_smmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
+ u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
+
+ r_l = num_h / num_l
+ u_slmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
+ u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
+
+ weights = u_hmask + u_smmask + u_slmask
+ weights.stop_gradient = True
+ loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
+ loss = fluid.layers.elementwise_mul(loss, weights)
+ loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
+ weights)
+
+ return loss
+
+ def pem_cls_loss_func(pred_score, gt_iou_map, mask):
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+ gt_iou_map.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.elementwise_mul(nmask, mask)
+
+ num_positive = fluid.layers.reduce_sum(pmask)
+ num_entries = num_positive + fluid.layers.reduce_sum(nmask)
+ ratio = num_entries / num_positive
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), nmask)
+ loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
+ loss = -1 * (loss_pos + loss_neg) / num_entries
+ return loss
+
+ pred_bm_reg = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[0], ends=[1]), axes=[1])
+ pred_bm_cls = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[1], ends=[2]), axes=[1])
+
+ bm_mask = _get_mask(cfg)
+
+ pem_reg_loss = pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
+ pem_cls_loss = pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
+
+ tem_loss = tem_loss_func(pred_start, pred_end, gt_start, gt_end)
+
+ loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
+ return loss, tem_loss, pem_reg_loss, pem_cls_loss
diff --git a/dygraph/bmn/predict.py b/dygraph/bmn/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..363e15b0d36c8065436531432109dbc60296b8ee
--- /dev/null
+++ b/dygraph/bmn/predict.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import numpy as np
+import argparse
+import sys
+import os
+import ast
+import json
+import logging
+
+from model import BMN
+from eval import gen_props
+from reader import BMNReader
+from bmn_utils import bmn_post_processing
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("BMN inference for proposal generation.")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=None,
+ help='training batch size. None to use config file setting.')
+ parser.add_argument(
+ '--use_gpu',
+ type=ast.literal_eval,
+ default=True,
+ help='default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/bmn_paddle_dy_final",
+        help='weight path of the trained model.'
+ )
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="predict_results/",
+ help='output dir path, default to use ./predict_results/')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+def get_dataset_dict(cfg):
+ file_list = cfg.INFER.filelist
+ annos = json.load(open(file_list))
+ video_dict = {}
+ for video_name in annos.keys():
+ video_dict[video_name] = annos[video_name]
+ video_list = list(video_dict.keys())
+ video_list.sort()
+ return video_dict, video_list
+
+
+# Prediction
+def infer_bmn(args):
+ config = parse_config(args.config_file)
+ infer_config = merge_configs(config, 'infer', vars(args))
+ print_configs(infer_config, "Infer")
+
+ if not os.path.isdir(infer_config.INFER.output_path):
+ os.makedirs(infer_config.INFER.output_path)
+ if not os.path.isdir(infer_config.INFER.result_path):
+ os.makedirs(infer_config.INFER.result_path)
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ with fluid.dygraph.guard(place):
+ bmn = BMN(infer_config)
+ # load checkpoint
+ if args.weights:
+ assert os.path.exists(args.weights + ".pdparams"
+ ), "Given weight dir {} not exist.".format(
+ args.weights)
+
+ logger.info('load test weights from {}'.format(args.weights))
+ model_dict, _ = fluid.load_dygraph(args.weights)
+ bmn.set_dict(model_dict)
+
+ reader = BMNReader(mode="infer", cfg=infer_config)
+ infer_reader = reader.create_reader()
+
+ video_dict, video_list = get_dataset_dict(infer_config)
+
+ bmn.eval()
+ for batch_id, data in enumerate(infer_reader()):
+ video_feat = np.array([item[0] for item in data]).astype(DATATYPE)
+ video_idx = [item[1] for item in data][0] #batch_size=1 by default
+
+ x_data = fluid.dygraph.base.to_variable(video_feat)
+
+ pred_bm, pred_start, pred_end = bmn(x_data)
+
+ pred_bm = pred_bm.numpy()
+ pred_start = pred_start[0].numpy()
+ pred_end = pred_end[0].numpy()
+
+ logger.info("Processing................ batch {}".format(batch_id))
+ gen_props(
+ pred_bm,
+ pred_start,
+ pred_end,
+ video_idx,
+ video_list,
+ infer_config,
+ mode='infer')
+
+ logger.info("Post_processing....This may take a while")
+ bmn_post_processing(video_dict, infer_config.INFER.subset,
+ infer_config.INFER.output_path,
+ infer_config.INFER.result_path)
+ logger.info("[INFER] infer finished. Results saved in {}".format(
+ args.save_dir) + "bmn_results_test.json")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ infer_bmn(args)
diff --git a/dygraph/bmn/reader.py b/dygraph/bmn/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0877c5907337bfaf13d30b328305d07b03fbe4b
--- /dev/null
+++ b/dygraph/bmn/reader.py
@@ -0,0 +1,287 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import numpy as np
+import random
+import json
+import multiprocessing
+import functools
+import logging
+import platform
+import os
+
+logger = logging.getLogger(__name__)
+
+from bmn_utils import iou_with_anchors, ioa_with_anchors
+
+
+class BMNReader():
+ def __init__(self, mode, cfg):
+ self.mode = mode
+ self.tscale = cfg.MODEL.tscale # 100
+ self.dscale = cfg.MODEL.dscale # 100
+ self.anno_file = cfg.MODEL.anno_file
+ self.file_list = cfg.INFER.filelist
+ self.subset = cfg[mode.upper()]['subset']
+ self.tgap = 1. / self.tscale
+ self.feat_path = cfg.MODEL.feat_path
+
+ self.get_dataset_dict()
+ self.get_match_map()
+
+ self.batch_size = cfg[mode.upper()]['batch_size']
+ self.num_threads = cfg[mode.upper()]['num_threads']
+ if (mode == 'test') or (mode == 'infer'):
+ self.num_threads = 1 # set num_threads as 1 for test and infer
+
+ def get_dataset_dict(self):
+ self.video_dict = {}
+ if self.mode == "infer":
+ annos = json.load(open(self.file_list))
+ for video_name in annos.keys():
+ self.video_dict[video_name] = annos[video_name]
+ else:
+ annos = json.load(open(self.anno_file))
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if self.subset in video_subset:
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
+ print("%s subset video numbers: %d" %
+ (self.subset, len(self.video_list)))
+
+ def get_match_map(self):
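+        # Enumerate the (xmin, xmax) pair of every candidate proposal; after
+        # the transpose and reshape below, row d * tscale + s of match_map
+        # holds the proposal with start index s and duration index d, which
+        # mirrors the [dscale, tscale] layout of the BM confidence map.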
+ match_map = []
+ for idx in range(self.tscale):
+ tmp_match_window = []
+ xmin = self.tgap * idx
+ for jdx in range(1, self.tscale + 1):
+ xmax = xmin + self.tgap * jdx
+ tmp_match_window.append([xmin, xmax])
+ match_map.append(tmp_match_window)
+ match_map = np.array(match_map)
+ match_map = np.transpose(match_map, [1, 0, 2])
+ match_map = np.reshape(match_map, [-1, 2])
+ self.match_map = match_map
+ self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
+ self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
+
+ def get_video_label(self, video_name):
+ video_info = self.video_dict[video_name]
+ video_second = video_info['duration_second']
+ video_labels = video_info['annotations']
+
+ gt_bbox = []
+ gt_iou_map = []
+ for gt in video_labels:
+ tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
+ tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
+ gt_bbox.append([tmp_start, tmp_end])
+ tmp_gt_iou_map = iou_with_anchors(
+ self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
+ tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
+ [self.dscale, self.tscale])
+ gt_iou_map.append(tmp_gt_iou_map)
+ gt_iou_map = np.array(gt_iou_map)
+ gt_iou_map = np.max(gt_iou_map, axis=0)
+
+ gt_bbox = np.array(gt_bbox)
+ gt_xmins = gt_bbox[:, 0]
+ gt_xmaxs = gt_bbox[:, 1]
+ gt_len_small = 3 * self.tgap
+ gt_start_bboxs = np.stack(
+ (gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
+ gt_end_bboxs = np.stack(
+ (gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
+
+ match_score_start = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_start.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
+ match_score_end = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_end.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
+
+ gt_start = np.array(match_score_start)
+ gt_end = np.array(match_score_end)
+ return gt_iou_map, gt_start, gt_end
+
+ def load_file(self, video_name):
+ file_name = video_name + ".npy"
+ file_path = os.path.join(self.feat_path, file_name)
+ video_feat = np.load(file_path)
+ video_feat = video_feat.T
+ video_feat = video_feat.astype("float32")
+ return video_feat
+
+ def create_reader(self):
+ """reader creator for bmn model"""
+ if self.mode == 'infer':
+ return self.make_infer_reader()
+ if self.num_threads == 1:
+ return self.make_reader()
+ else:
+ sysstr = platform.system()
+ if sysstr == 'Windows':
+ return self.make_multithread_reader()
+ else:
+ return self.make_multiprocess_reader()
+
+ def make_infer_reader(self):
+ """reader for inference"""
+
+ def reader():
+ batch_out = []
+ for video_name in self.video_list:
+ video_idx = self.video_list.index(video_name)
+ video_feat = self.load_file(video_name)
+ batch_out.append((video_feat, video_idx))
+
+ if len(batch_out) == self.batch_size:
+ yield batch_out
+ batch_out = []
+
+ return reader
+
+ def make_reader(self):
+ """single process reader"""
+
+ def reader():
+ video_list = self.video_list
+ if self.mode == 'train':
+ random.shuffle(video_list)
+
+ batch_out = []
+ for video_name in video_list:
+ video_idx = video_list.index(video_name)
+ video_feat = self.load_file(video_name)
+ gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
+
+ if self.mode == 'train' or self.mode == 'valid':
+ batch_out.append((video_feat, gt_iou_map, gt_start, gt_end))
+ elif self.mode == 'test':
+ batch_out.append(
+ (video_feat, gt_iou_map, gt_start, gt_end, video_idx))
+ else:
+ raise NotImplementedError('mode {} not implemented'.format(
+ self.mode))
+ if len(batch_out) == self.batch_size:
+ yield batch_out
+ batch_out = []
+
+ return reader
+
+ def make_multithread_reader(self):
+ def reader():
+ if self.mode == 'train':
+ random.shuffle(self.video_list)
+ for video_name in self.video_list:
+ video_idx = self.video_list.index(video_name)
+ yield [video_name, video_idx]
+
+ def process_data(sample, mode):
+ video_name = sample[0]
+ video_idx = sample[1]
+ video_feat = self.load_file(video_name)
+ gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
+ if mode == 'train' or mode == 'valid':
+ return (video_feat, gt_iou_map, gt_start, gt_end)
+ elif mode == 'test':
+ return (video_feat, gt_iou_map, gt_start, gt_end, video_idx)
+ else:
+ raise NotImplementedError('mode {} not implemented'.format(
+ mode))
+
+ mapper = functools.partial(process_data, mode=self.mode)
+
+ def batch_reader():
+ xreader = paddle.reader.xmap_readers(mapper, reader,
+ self.num_threads, 1024)
+ batch = []
+ for item in xreader():
+ batch.append(item)
+ if len(batch) == self.batch_size:
+ yield batch
+ batch = []
+
+ return batch_reader
+
+ def make_multiprocess_reader(self):
+ """multiprocess reader"""
+
+ def read_into_queue(video_list, queue):
+
+ batch_out = []
+ for video_name in video_list:
+ video_idx = video_list.index(video_name)
+ video_feat = self.load_file(video_name)
+ gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
+
+ if self.mode == 'train' or self.mode == 'valid':
+ batch_out.append((video_feat, gt_iou_map, gt_start, gt_end))
+ elif self.mode == 'test':
+ batch_out.append(
+ (video_feat, gt_iou_map, gt_start, gt_end, video_idx))
+ else:
+ raise NotImplementedError('mode {} not implemented'.format(
+ self.mode))
+
+ if len(batch_out) == self.batch_size:
+ queue.put(batch_out)
+ batch_out = []
+ queue.put(None)
+
+ def queue_reader():
+ video_list = self.video_list
+ if self.mode == 'train':
+ random.shuffle(video_list)
+
+ n = self.num_threads
+ queue_size = 20
+ reader_lists = [None] * n
+ file_num = int(len(video_list) // n)
+ for i in range(n):
+ if i < len(reader_lists) - 1:
+ tmp_list = video_list[i * file_num:(i + 1) * file_num]
+ else:
+ tmp_list = video_list[i * file_num:]
+ reader_lists[i] = tmp_list
+
+ queue = multiprocessing.Queue(queue_size)
+ p_list = [None] * len(reader_lists)
+ for i in range(len(reader_lists)):
+ reader_list = reader_lists[i]
+ p_list[i] = multiprocessing.Process(
+ target=read_into_queue, args=(reader_list, queue))
+ p_list[i].start()
+ reader_num = len(reader_lists)
+ finish_num = 0
+ while finish_num < reader_num:
+ sample = queue.get()
+ if sample is None:
+ finish_num += 1
+ else:
+ yield sample
+ for i in range(len(p_list)):
+ if p_list[i].is_alive():
+ p_list[i].join()
+
+ return queue_reader
diff --git a/dygraph/bmn/run.sh b/dygraph/bmn/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b426012056b92f9524271d127595775f281f789a
--- /dev/null
+++ b/dygraph/bmn/run.sh
@@ -0,0 +1,5 @@
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch \
+ --selected_gpus=0,1,2,3 \
+ --log_dir ./mylog \
+ train.py --use_data_parallel True
diff --git a/dygraph/bmn/train.py b/dygraph/bmn/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0171830e2a3b7e02ee3ef3a95a410a803293ec1e
--- /dev/null
+++ b/dygraph/bmn/train.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import numpy as np
+import argparse
+import ast
+import logging
+import sys
+import os
+
+from model import BMN, bmn_loss_func
+from reader import BMNReader
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("Paddle dynamic graph mode of BMN.")
+ parser.add_argument(
+ "--use_data_parallel",
+ type=ast.literal_eval,
+ default=False,
+ help="The flag indicating whether to use data parallel mode to train the model."
+ )
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=None,
+ help='training batch size. None to use config file setting.')
+ parser.add_argument(
+ '--learning_rate',
+ type=float,
+ default=0.001,
+        help='learning rate used for training. None to use config file setting.')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='filename to resume training based on previous checkpoints. '
+ 'None for not resuming any checkpoints.')
+ parser.add_argument(
+ '--use_gpu',
+ type=ast.literal_eval,
+ default=True,
+ help='default use gpu.')
+ parser.add_argument(
+ '--epoch',
+ type=int,
+ default=9,
+        help='number of training epochs; 0 to read from the config file')
+ parser.add_argument(
+ '--valid_interval',
+ type=int,
+ default=1,
+ help='validation epoch interval, 0 for no validation.')
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="checkpoint",
+ help='path to save train snapshoot')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=10,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Optimizer
+def optimizer(cfg, parameter_list):
+ bd = [cfg.TRAIN.lr_decay_iter]
+ base_lr = cfg.TRAIN.learning_rate
+ lr_decay = cfg.TRAIN.learning_rate_decay
+ l2_weight_decay = cfg.TRAIN.l2_weight_decay
+ lr = [base_lr, base_lr * lr_decay]
+ optimizer = fluid.optimizer.Adam(
+ fluid.layers.piecewise_decay(
+ boundaries=bd, values=lr),
+ parameter_list=parameter_list,
+ regularization=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=l2_weight_decay))
+ return optimizer
+
+
+# Validation
+def val_bmn(model, config, args):
+ reader = BMNReader(mode="valid", cfg=config)
+ val_reader = reader.create_reader()
+ for batch_id, data in enumerate(val_reader()):
+ video_feat = np.array([item[0] for item in data]).astype(DATATYPE)
+ gt_iou_map = np.array([item[1] for item in data]).astype(DATATYPE)
+ gt_start = np.array([item[2] for item in data]).astype(DATATYPE)
+ gt_end = np.array([item[3] for item in data]).astype(DATATYPE)
+
+ x_data = fluid.dygraph.base.to_variable(video_feat)
+ gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map)
+ gt_start = fluid.dygraph.base.to_variable(gt_start)
+ gt_end = fluid.dygraph.base.to_variable(gt_end)
+ gt_iou_map.stop_gradient = True
+ gt_start.stop_gradient = True
+ gt_end.stop_gradient = True
+
+ pred_bm, pred_start, pred_end = model(x_data)
+
+ loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func(
+ pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end, config)
+ avg_loss = fluid.layers.mean(loss)
+
+ if args.log_interval > 0 and (batch_id % args.log_interval == 0):
+ logger.info('[VALID] iter {} '.format(batch_id)
+ + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format(
+ '%.04f' % avg_loss.numpy()[0], '%.04f' % tem_loss.numpy()[0], \
+ '%.04f' % pem_reg_loss.numpy()[0], '%.04f' % pem_cls_loss.numpy()[0]))
+
+
+# TRAIN
+def train_bmn(args):
+ config = parse_config(args.config_file)
+ train_config = merge_configs(config, 'train', vars(args))
+ valid_config = merge_configs(config, 'valid', vars(args))
+
+ if not args.use_gpu:
+ place = fluid.CPUPlace()
+ elif not args.use_data_parallel:
+ place = fluid.CUDAPlace(0)
+ else:
+ place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+
+ with fluid.dygraph.guard(place):
+ if args.use_data_parallel:
+ strategy = fluid.dygraph.parallel.prepare_context()
+ bmn = BMN(train_config)
+ adam = optimizer(train_config, parameter_list=bmn.parameters())
+
+ if args.use_data_parallel:
+ bmn = fluid.dygraph.parallel.DataParallel(bmn, strategy)
+
+ if args.resume:
+ # if resume weights is given, load resume weights directly
+ assert os.path.exists(args.resume + ".pdparams"), \
+ "Given resume weight dir {} not exist.".format(args.resume)
+
+ model, _ = fluid.dygraph.load_dygraph(args.resume)
+ bmn.set_dict(model)
+
+ reader = BMNReader(mode="train", cfg=train_config)
+ train_reader = reader.create_reader()
+ if args.use_data_parallel:
+ train_reader = fluid.contrib.reader.distributed_batch_reader(
+ train_reader)
+
+ for epoch in range(args.epoch):
+ for batch_id, data in enumerate(train_reader()):
+ video_feat = np.array(
+ [item[0] for item in data]).astype(DATATYPE)
+ gt_iou_map = np.array(
+ [item[1] for item in data]).astype(DATATYPE)
+ gt_start = np.array([item[2] for item in data]).astype(DATATYPE)
+ gt_end = np.array([item[3] for item in data]).astype(DATATYPE)
+
+ x_data = fluid.dygraph.base.to_variable(video_feat)
+ gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map)
+ gt_start = fluid.dygraph.base.to_variable(gt_start)
+ gt_end = fluid.dygraph.base.to_variable(gt_end)
+ gt_iou_map.stop_gradient = True
+ gt_start.stop_gradient = True
+ gt_end.stop_gradient = True
+
+ pred_bm, pred_start, pred_end = bmn(x_data)
+
+ loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func(
+ pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end,
+ train_config)
+ avg_loss = fluid.layers.mean(loss)
+
+ if args.use_data_parallel:
+ avg_loss = bmn.scale_loss(avg_loss)
+ avg_loss.backward()
+ bmn.apply_collective_grads()
+ else:
+ avg_loss.backward()
+
+ adam.minimize(avg_loss)
+
+ bmn.clear_gradients()
+
+ if args.log_interval > 0 and (
+ batch_id % args.log_interval == 0):
+ logger.info('[TRAIN] Epoch {}, iter {} '.format(epoch, batch_id)
+ + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format(
+ '%.04f' % avg_loss.numpy()[0], '%.04f' % tem_loss.numpy()[0], \
+ '%.04f' % pem_reg_loss.numpy()[0], '%.04f' % pem_cls_loss.numpy()[0]))
+
+ logger.info('[TRAIN] Epoch {} training finished'.format(epoch))
+ if not os.path.isdir(args.save_dir):
+ os.makedirs(args.save_dir)
+ save_model_name = os.path.join(
+ args.save_dir, "bmn_paddle_dy" + "_epoch{}".format(epoch))
+ fluid.dygraph.save_dygraph(bmn.state_dict(), save_model_name)
+
+ # validation
+ if args.valid_interval > 0 and (epoch + 1
+ ) % args.valid_interval == 0:
+ bmn.eval()
+ val_bmn(bmn, valid_config, args)
+ bmn.train()
+
+ #save final results
+ if fluid.dygraph.parallel.Env().local_rank == 0:
+ save_model_name = os.path.join(args.save_dir,
+ "bmn_paddle_dy" + "_final")
+ fluid.dygraph.save_dygraph(bmn.state_dict(), save_model_name)
+ logger.info('[TRAIN] training finished')
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ train_bmn(args)