diff --git a/dygraph/bmn/README.md b/dygraph/bmn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1408cb5dd746b1308cb09ac703df4ff0e04ef46b
--- /dev/null
+++ b/dygraph/bmn/README.md
@@ -0,0 +1,129 @@
+# BMN Video Action Localization Model (Dygraph Implementation)
+
+---
+## Contents
+
+- [Model Overview](#model-overview)
+- [Code Structure](#code-structure)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Inference](#inference)
+- [Reference](#reference)
+
+
+## Model Overview
+
+BMN is a model developed at Baidu and the winning solution of the 2019 ActivityNet challenge. It provides an efficient way to generate proposals for temporal action localization and is open-sourced on PaddlePaddle for the first time. The model introduces a Boundary-Matching (BM) mechanism to score proposal confidence: all candidate proposals are arranged into a two-dimensional BM confidence map according to the position of their start boundary and their length, and the value at each point of the map is the confidence score of the corresponding proposal. The network consists of three modules: the base module acts as the backbone and processes the input feature sequence, the TEM module predicts, for every temporal location, the probability that it is an action start or an action end, and the PEM module generates the BM confidence map.
+
+<p align="center">

+BMN Overview
+</p>
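+
+The BM confidence map is the key output described above: entry (d, t) of the D×T map scores the proposal that starts at snippet t and lasts d+1 snippets. Below is a minimal NumPy sketch (illustrative only) of reading such a map out into (xmin, xmax, score) triplets; the actual read-out used by this model lives in `gen_props` in eval.py, which additionally multiplies in the TEM start/end probabilities:
+
+```python
+import numpy as np
+
+tscale, dscale = 100, 100                # temporal scale and max duration (see bmn.yaml)
+bm_map = np.random.rand(dscale, tscale)  # dummy stand-in for the network's BM confidence map
+
+proposals = []
+for d in range(dscale):                  # d: duration index
+    for t in range(tscale):              # t: start index
+        end = t + d
+        if end < tscale:                 # keep proposals that end inside the video
+            xmin = t / float(tscale)            # normalized start boundary
+            xmax = (end + 1) / float(tscale)    # normalized end boundary
+            proposals.append((xmin, xmax, bm_map[d, t]))
+
+proposals.sort(key=lambda p: p[2], reverse=True)
+print(proposals[:5])                     # top-5 candidate segments
+```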
+
+For the dynamic-graph (dygraph) documentation, please refer to [Dygraph](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/user_guides/howto/dygraph/DyGraph.html)
+
+
+## Code Structure
+```
+├── bmn.yaml           # network configuration file, for convenient parameter setting
+├── run.sh             # quick-start script that launches multi-GPU training
+├── train.py           # training code, including network-structure related code
+├── eval.py            # evaluation code, measures network performance
+├── predict.py         # prediction code, produces results for arbitrary input
+├── model.py           # network structure and loss-function definitions
+├── reader.py          # data reader
+├── eval_anet_prop.py  # computes the accuracy metrics
+├── bmn_utils.py       # model utility code
+├── config_utils.py    # configuration utility code
+└── infer.list         # file list for inference
+```
+
+
+## Data Preparation
+
+BMN is trained on the dataset provided by ActivityNet 1.3. We provide pre-extracted video features: please download [bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz), extract it, and update the feature path feat\_path in bmn.yaml accordingly.
+
+
+## Training
+
+Once the data is ready, training can be started in either of the following two ways.
+
+Training uses 4 GPUs by default and is launched with:
+
+    bash run.sh
+
+For single-GPU training, launch with:
+
+    export CUDA_VISIBLE_DEVICES=0
+    python train.py
+
+- Running the code requires pandas to be installed first.
+
+- Training starts from scratch with either of the commands above; no pretrained model is needed.
+
+**Training strategy:**
+
+* Adam optimizer with an initial learning\_rate of 0.001
+* Weight-decay coefficient of 1e-4
+* The learning rate is decayed once, by a factor of 0.1, after 4200 iterations
+
+- The table below lists the approximate training time of this model (in minutes), measured on P40 GPUs with CUDA 8.0 and cuDNN 7.2.
+
+|  | 1 GPU | 4 GPUs |
+| :---: | :---: | :---: |
+| static graph | 79 | 27 |
+| dynamic graph | 98 | 31 |
+
+## Evaluation
+
+After training finishes, the model can be evaluated with:
+
+    python eval.py --weights=$PATH_TO_WEIGHTS
+
+- The `weights` argument specifies the weights to evaluate; if it is not set, the default parameter file checkpoint/bmn\_paddle\_dy\_final.pdparams is used.
+
+- The program above writes its intermediate results to the output/EVAL/BMN\_results folder; the evaluation results are saved to the evaluate\_results/bmn\_results\_validation.json file.
+
+- To evaluate on CPU, set `use_gpu` to False in the command line above.
+
+- Note: the loss may show up as nan during evaluation. Evaluation uses single samples, and a sample may contain no proposal with iou > 0.6, which yields nan; this does not affect the final evaluation results.
+
+
+AR@AN and AUC can be computed with the official ActivityNet evaluation scripts, as follows:
+
+- Usage details for the ActivityNet dataset are available on its [official website](http://activity-net.org)
+
+- Download the metric-evaluation code from the [ActivityNet GitHub repository](https://github.com/activitynet/ActivityNet.git) and copy its Evaluation folder into the models/dygraph/bmn directory. (Note: with python3 the print statements need parentheses; please adapt the .py files under the Evaluation directory accordingly.)
+
+- Download the [activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json) file and place it in the models/dygraph/bmn directory. Compared with the original activity\_net.v1-3.min.json file, it filters out some invalid video entries.
+
+- Compute the accuracy metrics:
+
+  ```python eval_anet_prop.py```
+
+
+Accuracy on the ActivityNet 1.3 dataset:
+
+| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
+| :---: | :---: | :---: | :---: | :---: |
+| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
+
+
+## Inference
+
+Inference can be started with:
+
+    python predict.py --weights=$PATH_TO_WEIGHTS \
+                      --filelist=$FILELIST
+
+- When launching from the command line, `--filelist` specifies the list of files to run inference on; if it is not set, ./infer.list is used by default. `--weights` points to the trained weights; if it is not set, the default parameter file checkpoint/bmn\_paddle\_dy\_final.pdparams is used.
+
+- The program above writes its intermediate results to the output/INFER/BMN\_results folder; the prediction results are saved to the predict\_results/bmn\_results\_test.json file.
+
+- To run inference on CPU, set `use_gpu` to False in the command line.
+
+
+## Reference
+
+- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
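+
+Both evaluation and inference write their final proposals through `bmn_post_processing` into a single JSON file (see the output paths listed above). A minimal sketch for inspecting such a file, assuming the default inference output path:
+
+```python
+import json
+
+# Proposal file written by predict.py (default path; adjust if changed).
+with open("predict_results/bmn_results_test.json") as f:
+    results = json.load(f)["results"]   # {video_name: [{"score": ..., "segment": [start_s, end_s]}, ...]}
+
+for video_name, proposals in results.items():
+    best = max(proposals, key=lambda p: p["score"])
+    start, end = best["segment"]        # segment boundaries in seconds
+    print("%s: best proposal %.1fs-%.1fs (score %.3f)"
+          % (video_name, start, end, best["score"]))
+```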
diff --git a/dygraph/bmn/bmn.yaml b/dygraph/bmn/bmn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e16fa53bc01c4a94fcfb9781091443fa4e55775c --- /dev/null +++ b/dygraph/bmn/bmn.yaml @@ -0,0 +1,50 @@ +MODEL: + name: "BMN" + tscale: 100 + dscale: 100 + feat_dim: 400 + prop_boundary_ratio: 0.5 + num_sample: 32 + num_sample_perbin: 3 + anno_file: "../../PaddleCV/PaddleVideo/data/dataset/bmn/activitynet_1.3_annotations.json" + feat_path: './fix_feat_100' + +TRAIN: + subset: "train" + epoch: 9 + batch_size: 16 + num_threads: 8 + use_gpu: True + num_gpus: 4 + learning_rate: 0.001 + learning_rate_decay: 0.1 + lr_decay_iter: 4200 + l2_weight_decay: 1e-4 + +VALID: + subset: "validation" + batch_size: 16 + num_threads: 8 + use_gpu: True + num_gpus: 4 + +TEST: + subset: "validation" + batch_size: 1 + num_threads: 1 + snms_alpha: 0.001 + snms_t1: 0.5 + snms_t2: 0.9 + output_path: "output/EVAL/BMN_results" + result_path: "evaluate_results" + +INFER: + subset: "test" + batch_size: 1 + num_threads: 1 + snms_alpha: 0.4 + snms_t1: 0.5 + snms_t2: 0.9 + filelist: './infer.list' + output_path: "output/INFER/BMN_results" + result_path: "predict_results" diff --git a/dygraph/bmn/bmn_utils.py b/dygraph/bmn/bmn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02a960841e6c97d3484f870be260e28ca123e566 --- /dev/null +++ b/dygraph/bmn/bmn_utils.py @@ -0,0 +1,217 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import numpy as np +import pandas as pd +import multiprocessing as mp +import json +import os +import math + + +def iou_with_anchors(anchors_min, anchors_max, box_min, box_max): + """Compute jaccard score between a box and the anchors. + """ + len_anchors = anchors_max - anchors_min + int_xmin = np.maximum(anchors_min, box_min) + int_xmax = np.minimum(anchors_max, box_max) + inter_len = np.maximum(int_xmax - int_xmin, 0.) + union_len = len_anchors - inter_len + box_max - box_min + jaccard = np.divide(inter_len, union_len) + return jaccard + + +def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max): + """Compute intersection between score a box and the anchors. + """ + len_anchors = anchors_max - anchors_min + int_xmin = np.maximum(anchors_min, box_min) + int_xmax = np.minimum(anchors_max, box_max) + inter_len = np.maximum(int_xmax - int_xmin, 0.) 
+ scores = np.divide(inter_len, len_anchors) + return scores + + +def boundary_choose(score_list): + max_score = max(score_list) + mask_high = (score_list > max_score * 0.5) + score_list = list(score_list) + score_middle = np.array([0.0] + score_list + [0.0]) + score_front = np.array([0.0, 0.0] + score_list) + score_back = np.array(score_list + [0.0, 0.0]) + mask_peak = ((score_middle > score_front) & (score_middle > score_back)) + mask_peak = mask_peak[1:-1] + mask = (mask_high | mask_peak).astype('float32') + return mask + + +def soft_nms(df, alpha, t1, t2): + ''' + df: proposals generated by network; + alpha: alpha value of Gaussian decaying function; + t1, t2: threshold for soft nms. + ''' + df = df.sort_values(by="score", ascending=False) + tstart = list(df.xmin.values[:]) + tend = list(df.xmax.values[:]) + tscore = list(df.score.values[:]) + + rstart = [] + rend = [] + rscore = [] + + while len(tscore) > 1 and len(rscore) < 101: + max_index = tscore.index(max(tscore)) + tmp_iou_list = iou_with_anchors( + np.array(tstart), + np.array(tend), tstart[max_index], tend[max_index]) + for idx in range(0, len(tscore)): + if idx != max_index: + tmp_iou = tmp_iou_list[idx] + tmp_width = tend[max_index] - tstart[max_index] + if tmp_iou > t1 + (t2 - t1) * tmp_width: + tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) / + alpha) + + rstart.append(tstart[max_index]) + rend.append(tend[max_index]) + rscore.append(tscore[max_index]) + tstart.pop(max_index) + tend.pop(max_index) + tscore.pop(max_index) + + newDf = pd.DataFrame() + newDf['score'] = rscore + newDf['xmin'] = rstart + newDf['xmax'] = rend + return newDf + + +def video_process(video_list, + video_dict, + output_path, + result_dict, + snms_alpha=0.4, + snms_t1=0.55, + snms_t2=0.9): + + for video_name in video_list: + print("Processing video........" 
+ video_name) + df = pd.read_csv(os.path.join(output_path, video_name + ".csv")) + if len(df) > 1: + df = soft_nms(df, snms_alpha, snms_t1, snms_t2) + + video_duration = video_dict[video_name]["duration_second"] + proposal_list = [] + for idx in range(min(100, len(df))): + tmp_prop={"score":df.score.values[idx], \ + "segment":[max(0,df.xmin.values[idx])*video_duration, \ + min(1,df.xmax.values[idx])*video_duration]} + proposal_list.append(tmp_prop) + result_dict[video_name[2:]] = proposal_list + + +def bmn_post_processing(video_dict, subset, output_path, result_path): + video_list = video_dict.keys() + video_list = list(video_dict.keys()) + global result_dict + result_dict = mp.Manager().dict() + pp_num = 12 + + num_videos = len(video_list) + num_videos_per_thread = int(num_videos / pp_num) + processes = [] + for tid in range(pp_num - 1): + tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) * + num_videos_per_thread] + p = mp.Process( + target=video_process, + args=(tmp_video_list, video_dict, output_path, result_dict)) + p.start() + processes.append(p) + tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:] + p = mp.Process( + target=video_process, + args=(tmp_video_list, video_dict, output_path, result_dict)) + p.start() + processes.append(p) + for p in processes: + p.join() + + result_dict = dict(result_dict) + output_dict = { + "version": "VERSION 1.3", + "results": result_dict, + "external_data": {} + } + outfile = open( + os.path.join(result_path, "bmn_results_%s.json" % subset), "w") + + json.dump(output_dict, outfile) + outfile.close() + + +def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample, + num_sample_perbin): + """ generate sample mask for a boundary-matching pair """ + plen = float(seg_xmax - seg_xmin) + plen_sample = plen / (num_sample * num_sample_perbin - 1.0) + total_samples = [ + seg_xmin + plen_sample * ii + for ii in range(num_sample * num_sample_perbin) + ] + p_mask = [] + for idx in range(num_sample): + bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) * + num_sample_perbin] + bin_vector = np.zeros([tscale]) + for sample in bin_samples: + sample_upper = math.ceil(sample) + sample_decimal, sample_down = math.modf(sample) + if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0: + bin_vector[int(sample_down)] += 1 - sample_decimal + if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0: + bin_vector[int(sample_upper)] += sample_decimal + bin_vector = 1.0 / num_sample_perbin * bin_vector + p_mask.append(bin_vector) + p_mask = np.stack(p_mask, axis=1) + return p_mask + + +def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample, + num_sample_perbin): + """ generate sample mask for each point in Boundary-Matching Map """ + mask_mat = [] + for start_index in range(tscale): + mask_mat_vector = [] + for duration_index in range(dscale): + if start_index + duration_index < tscale: + p_xmin = start_index + p_xmax = start_index + duration_index + center_len = float(p_xmax - p_xmin) + 1 + sample_xmin = p_xmin - center_len * prop_boundary_ratio + sample_xmax = p_xmax + center_len * prop_boundary_ratio + p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax, + tscale, num_sample, + num_sample_perbin) + else: + p_mask = np.zeros([tscale, num_sample]) + mask_mat_vector.append(p_mask) + mask_mat_vector = np.stack(mask_mat_vector, axis=2) + mask_mat.append(mask_mat_vector) + mask_mat = np.stack(mask_mat, axis=3) + mask_mat = mask_mat.astype(np.float32) + + sample_mask = np.reshape(mask_mat, [tscale, 
-1]) + return sample_mask diff --git a/dygraph/bmn/config_utils.py b/dygraph/bmn/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb59a25c316d3e435a9f74c453b61e0362f85d38 --- /dev/null +++ b/dygraph/bmn/config_utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import yaml +import logging + +logger = logging.getLogger(__name__) + +CONFIG_SECS = [ + 'train', + 'valid', + 'test', + 'infer', +] + + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value + + +def parse_config(cfg_file): + """Load a config file into AttrDict""" + with open(cfg_file, 'r') as fopen: + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader)) + create_attr_dict(yaml_config) + return yaml_config + + +def create_attr_dict(yaml_config): + from ast import literal_eval + for key, value in yaml_config.items(): + if type(value) is dict: + yaml_config[key] = value = AttrDict(value) + if isinstance(value, str): + try: + value = literal_eval(value) + except BaseException: + pass + if isinstance(value, AttrDict): + create_attr_dict(yaml_config[key]) + else: + yaml_config[key] = value + return + + +def merge_configs(cfg, sec, args_dict): + assert sec in CONFIG_SECS, "invalid config section {}".format(sec) + sec_dict = getattr(cfg, sec.upper()) + for k, v in args_dict.items(): + if v is None: + continue + try: + if hasattr(sec_dict, k): + setattr(sec_dict, k, v) + except: + pass + return cfg + + +def print_configs(cfg, mode): + logger.info("---------------- {:>5} Arguments ----------------".format( + mode)) + for sec, sec_items in cfg.items(): + logger.info("{}:".format(sec)) + for k, v in sec_items.items(): + logger.info(" {}:{}".format(k, v)) + logger.info("-------------------------------------------------") diff --git a/dygraph/bmn/eval.py b/dygraph/bmn/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f96780691a0575d7d668f7114655309552854ffc --- /dev/null +++ b/dygraph/bmn/eval.py @@ -0,0 +1,220 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +import paddle +import paddle.fluid as fluid +import numpy as np +import argparse +import pandas as pd +import os +import sys +import ast +import json +import logging + +from reader import BMNReader +from model import BMN, bmn_loss_func +from bmn_utils import boundary_choose, bmn_post_processing +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("BMN test for performance evaluation.") + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='training batch size. None to use config file setting.') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + parser.add_argument( + '--weights', + type=str, + default="checkpoint/bmn_paddle_dy_final", + help='weight path, None to automatically download weights provided by Paddle.' + ) + parser.add_argument( + '--save_dir', + type=str, + default="evaluate_results/", + help='output dir path, default to use ./evaluate_results/') + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +def get_dataset_dict(cfg): + anno_file = cfg.MODEL.anno_file + annos = json.load(open(anno_file)) + subset = cfg.TEST.subset + video_dict = {} + for video_name in annos.keys(): + video_subset = annos[video_name]["subset"] + if subset in video_subset: + video_dict[video_name] = annos[video_name] + video_list = list(video_dict.keys()) + video_list.sort() + return video_dict, video_list + + +def gen_props(pred_bm, pred_start, pred_end, fid, video_list, cfg, mode='test'): + if mode == 'infer': + output_path = cfg.INFER.output_path + else: + output_path = cfg.TEST.output_path + tscale = cfg.MODEL.tscale + dscale = cfg.MODEL.dscale + snippet_xmins = [1.0 / tscale * i for i in range(tscale)] + snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)] + cols = ["xmin", "xmax", "score"] + + video_name = video_list[fid] + pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :] + start_mask = boundary_choose(pred_start) + start_mask[0] = 1. + end_mask = boundary_choose(pred_end) + end_mask[-1] = 1. 
+ score_vector_list = [] + for idx in range(dscale): + for jdx in range(tscale): + start_index = jdx + end_index = start_index + idx + if end_index < tscale and start_mask[start_index] == 1 and end_mask[ + end_index] == 1: + xmin = snippet_xmins[start_index] + xmax = snippet_xmaxs[end_index] + xmin_score = pred_start[start_index] + xmax_score = pred_end[end_index] + bm_score = pred_bm[idx, jdx] + conf_score = xmin_score * xmax_score * bm_score + score_vector_list.append([xmin, xmax, conf_score]) + + score_vector_list = np.stack(score_vector_list) + video_df = pd.DataFrame(score_vector_list, columns=cols) + video_df.to_csv( + os.path.join(output_path, "%s.csv" % video_name), index=False) + + +# Performance Evaluation +def test_bmn(args): + config = parse_config(args.config_file) + test_config = merge_configs(config, 'test', vars(args)) + print_configs(test_config, "Test") + + if not os.path.isdir(test_config.TEST.output_path): + os.makedirs(test_config.TEST.output_path) + if not os.path.isdir(test_config.TEST.result_path): + os.makedirs(test_config.TEST.result_path) + place = fluid.CUDAPlace(0) + with fluid.dygraph.guard(place): + bmn = BMN(test_config) + + # load checkpoint + if args.weights: + assert os.path.exists(args.weights + '.pdparams' + ), "Given weight dir {} not exist.".format( + args.weights) + + logger.info('load test weights from {}'.format(args.weights)) + model_dict, _ = fluid.load_dygraph(args.weights) + bmn.set_dict(model_dict) + + reader = BMNReader(mode="test", cfg=test_config) + test_reader = reader.create_reader() + + aggr_loss = 0.0 + aggr_tem_loss = 0.0 + aggr_pem_reg_loss = 0.0 + aggr_pem_cls_loss = 0.0 + aggr_batch_size = 0 + video_dict, video_list = get_dataset_dict(test_config) + + bmn.eval() + for batch_id, data in enumerate(test_reader()): + video_feat = np.array([item[0] for item in data]).astype(DATATYPE) + gt_iou_map = np.array([item[1] for item in data]).astype(DATATYPE) + gt_start = np.array([item[2] for item in data]).astype(DATATYPE) + gt_end = np.array([item[3] for item in data]).astype(DATATYPE) + video_idx = [item[4] for item in data][0] #batch_size=1 by default + + x_data = fluid.dygraph.base.to_variable(video_feat) + gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map) + gt_start = fluid.dygraph.base.to_variable(gt_start) + gt_end = fluid.dygraph.base.to_variable(gt_end) + gt_iou_map.stop_gradient = True + gt_start.stop_gradient = True + gt_end.stop_gradient = True + + pred_bm, pred_start, pred_end = bmn(x_data) + loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func( + pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end, + test_config) + + pred_bm = pred_bm.numpy() + pred_start = pred_start[0].numpy() + pred_end = pred_end[0].numpy() + aggr_loss += np.mean(loss.numpy()) + aggr_tem_loss += np.mean(tem_loss.numpy()) + aggr_pem_reg_loss += np.mean(pem_reg_loss.numpy()) + aggr_pem_cls_loss += np.mean(pem_cls_loss.numpy()) + aggr_batch_size += 1 + + logger.info("Processing................ 
batch {}".format(batch_id)) + gen_props( + pred_bm, + pred_start, + pred_end, + video_idx, + video_list, + test_config, + mode='test') + + avg_loss = aggr_loss / aggr_batch_size + avg_tem_loss = aggr_tem_loss / aggr_batch_size + avg_pem_reg_loss = aggr_pem_reg_loss / aggr_batch_size + avg_pem_cls_loss = aggr_pem_cls_loss / aggr_batch_size + + logger.info('[EVAL] \tAvg_oss = {}, \tAvg_tem_loss = {}, \tAvg_pem_reg_loss = {}, \tAvg_pem_cls_loss = {}'.format( + '%.04f' % avg_loss, '%.04f' % avg_tem_loss, \ + '%.04f' % avg_pem_reg_loss, '%.04f' % avg_pem_cls_loss)) + + logger.info("Post_processing....This may take a while") + bmn_post_processing(video_dict, test_config.TEST.subset, + test_config.TEST.output_path, + test_config.TEST.result_path) + logger.info("[EVAL] eval finished") + + +if __name__ == '__main__': + args = parse_args() + test_bmn(args) diff --git a/dygraph/bmn/eval_anet_prop.py b/dygraph/bmn/eval_anet_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..2a91c6effff8083ad54180c0931585731f2fa6d8 --- /dev/null +++ b/dygraph/bmn/eval_anet_prop.py @@ -0,0 +1,110 @@ +''' +Calculate AR@N and AUC; +Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git) +''' + +import sys +sys.path.append('./Evaluation') + +from eval_proposal import ANETproposal +import numpy as np +import argparse +import os + +parser = argparse.ArgumentParser("Eval AR vs AN of proposal") +parser.add_argument( + '--eval_file', + type=str, + default='bmn_results_validation.json', + help='name of results file to eval') + + +def run_evaluation(ground_truth_filename, + proposal_filename, + max_avg_nr_proposals=100, + tiou_thresholds=np.linspace(0.5, 0.95, 10), + subset='validation'): + + anet_proposal = ANETproposal( + ground_truth_filename, + proposal_filename, + tiou_thresholds=tiou_thresholds, + max_avg_nr_proposals=max_avg_nr_proposals, + subset=subset, + verbose=True, + check_status=False) + anet_proposal.evaluate() + recall = anet_proposal.recall + average_recall = anet_proposal.avg_recall + average_nr_proposals = anet_proposal.proposals_per_video + + return (average_nr_proposals, average_recall, recall) + + +def plot_metric(average_nr_proposals, + average_recall, + recall, + tiou_thresholds=np.linspace(0.5, 0.95, 10)): + fn_size = 14 + plt.figure(num=None, figsize=(12, 8)) + ax = plt.subplot(1, 1, 1) + + colors = [ + 'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo' + ] + area_under_curve = np.zeros_like(tiou_thresholds) + for i in range(recall.shape[0]): + area_under_curve[i] = np.trapz(recall[i], average_nr_proposals) + + for idx, tiou in enumerate(tiou_thresholds[::2]): + ax.plot( + average_nr_proposals, + recall[2 * idx, :], + color=colors[idx + 1], + label="tiou=[" + str(tiou) + "], area=" + str( + int(area_under_curve[2 * idx] * 100) / 100.), + linewidth=4, + linestyle='--', + marker=None) + + # Plots Average Recall vs Average number of proposals. 
+ ax.plot( + average_nr_proposals, + average_recall, + color=colors[0], + label="tiou = 0.5:0.05:0.95," + " area=" + str( + int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.), + linewidth=4, + linestyle='-', + marker=None) + + handles, labels = ax.get_legend_handles_labels() + ax.legend( + [handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best') + + plt.ylabel('Average Recall', fontsize=fn_size) + plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size) + plt.grid(b=True, which="both") + plt.ylim([0, 1.0]) + plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size) + plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size) + plt.show() + + +if __name__ == "__main__": + args = parser.parse_args() + eval_file = args.eval_file + eval_file_path = os.path.join("evaluate_results", eval_file) + uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation( + "./Evaluation/data/activity_net_1_3_new.json", + eval_file_path, + max_avg_nr_proposals=100, + tiou_thresholds=np.linspace(0.5, 0.95, 10), + subset='validation') + + print("AR@1; AR@5; AR@10; AR@100") + print("%.02f %.02f %.02f %.02f" % + (100 * np.mean(uniform_recall_valid[:, 0]), + 100 * np.mean(uniform_recall_valid[:, 4]), + 100 * np.mean(uniform_recall_valid[:, 9]), + 100 * np.mean(uniform_recall_valid[:, -1]))) diff --git a/dygraph/bmn/infer.list b/dygraph/bmn/infer.list new file mode 100644 index 0000000000000000000000000000000000000000..44768f089e70e40913d9787571ae0a7151232558 --- /dev/null +++ b/dygraph/bmn/infer.list @@ -0,0 +1 @@ +{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}} \ No newline at end of file diff --git a/dygraph/bmn/model.py b/dygraph/bmn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f77e8e0e95bc4e0d397ba7247327c5bf7038c8e4 --- /dev/null +++ b/dygraph/bmn/model.py @@ -0,0 +1,339 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np +import math + +from bmn_utils import get_interp1d_mask + +DATATYPE = 'float32' + + +# Net +class Conv1D(fluid.dygraph.Layer): + def __init__(self, + prefix, + num_channels=256, + num_filters=256, + size_k=3, + padding=1, + groups=1, + act="relu"): + super(Conv1D, self).__init__() + fan_in = num_channels * size_k * 1 + k = 1. / math.sqrt(fan_in) + param_attr = ParamAttr( + name=prefix + "_w", + initializer=fluid.initializer.Uniform( + low=-k, high=k)) + bias_attr = ParamAttr( + name=prefix + "_b", + initializer=fluid.initializer.Uniform( + low=-k, high=k)) + + self._conv2d = fluid.dygraph.Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=(1, size_k), + stride=1, + padding=(0, padding), + groups=groups, + act=act, + param_attr=param_attr, + bias_attr=bias_attr) + + def forward(self, x): + x = fluid.layers.unsqueeze(input=x, axes=[2]) + x = self._conv2d(x) + x = fluid.layers.squeeze(input=x, axes=[2]) + return x + + +class BMN(fluid.dygraph.Layer): + def __init__(self, cfg): + super(BMN, self).__init__() + + #init config + self.tscale = cfg.MODEL.tscale + self.dscale = cfg.MODEL.dscale + self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio + self.num_sample = cfg.MODEL.num_sample + self.num_sample_perbin = cfg.MODEL.num_sample_perbin + + self.hidden_dim_1d = 256 + self.hidden_dim_2d = 128 + self.hidden_dim_3d = 512 + + # Base Module + self.b_conv1 = Conv1D( + prefix="Base_1", + num_channels=400, + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.b_conv2 = Conv1D( + prefix="Base_2", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + + # Temporal Evaluation Module + self.ts_conv1 = Conv1D( + prefix="TEM_s1", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.ts_conv2 = Conv1D( + prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid") + self.te_conv1 = Conv1D( + prefix="TEM_e1", + num_filters=self.hidden_dim_1d, + size_k=3, + padding=1, + groups=4, + act="relu") + self.te_conv2 = Conv1D( + prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid") + + #Proposal Evaluation Module + self.p_conv1 = Conv1D( + prefix="PEM_1d", + num_filters=self.hidden_dim_2d, + size_k=3, + padding=1, + act="relu") + + # init to speed up + sample_mask = get_interp1d_mask(self.tscale, self.dscale, + self.prop_boundary_ratio, + self.num_sample, self.num_sample_perbin) + self.sample_mask = fluid.dygraph.base.to_variable(sample_mask) + self.sample_mask.stop_gradient = True + + self.p_conv3d1 = fluid.dygraph.Conv3D( + num_channels=128, + num_filters=self.hidden_dim_3d, + filter_size=(self.num_sample, 1, 1), + stride=(self.num_sample, 1, 1), + padding=0, + act="relu", + param_attr=ParamAttr(name="PEM_3d1_w"), + bias_attr=ParamAttr(name="PEM_3d1_b")) + + self.p_conv2d1 = fluid.dygraph.Conv2D( + num_channels=512, + num_filters=self.hidden_dim_2d, + filter_size=1, + stride=1, + padding=0, + act="relu", + param_attr=ParamAttr(name="PEM_2d1_w"), + 
bias_attr=ParamAttr(name="PEM_2d1_b")) + self.p_conv2d2 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=self.hidden_dim_2d, + filter_size=3, + stride=1, + padding=1, + act="relu", + param_attr=ParamAttr(name="PEM_2d2_w"), + bias_attr=ParamAttr(name="PEM_2d2_b")) + self.p_conv2d3 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=self.hidden_dim_2d, + filter_size=3, + stride=1, + padding=1, + act="relu", + param_attr=ParamAttr(name="PEM_2d3_w"), + bias_attr=ParamAttr(name="PEM_2d3_b")) + self.p_conv2d4 = fluid.dygraph.Conv2D( + num_channels=128, + num_filters=2, + filter_size=1, + stride=1, + padding=0, + act="sigmoid", + param_attr=ParamAttr(name="PEM_2d4_w"), + bias_attr=ParamAttr(name="PEM_2d4_b")) + + def forward(self, x): + #Base Module + x = self.b_conv1(x) + x = self.b_conv2(x) + + #TEM + xs = self.ts_conv1(x) + xs = self.ts_conv2(xs) + xs = fluid.layers.squeeze(xs, axes=[1]) + xe = self.te_conv1(x) + xe = self.te_conv2(xe) + xe = fluid.layers.squeeze(xe, axes=[1]) + + #PEM + xp = self.p_conv1(x) + #BM layer + xp = fluid.layers.matmul(xp, self.sample_mask) + xp = fluid.layers.reshape( + xp, shape=[0, 0, -1, self.dscale, self.tscale]) + + xp = self.p_conv3d1(xp) + xp = fluid.layers.squeeze(xp, axes=[2]) + xp = self.p_conv2d1(xp) + xp = self.p_conv2d2(xp) + xp = self.p_conv2d3(xp) + xp = self.p_conv2d4(xp) + return xp, xs, xe + + +def bmn_loss_func(pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end, + cfg): + def _get_mask(cfg): + dscale = cfg.MODEL.dscale + tscale = cfg.MODEL.tscale + bm_mask = [] + for idx in range(dscale): + mask_vector = [1 for i in range(tscale - idx) + ] + [0 for i in range(idx)] + bm_mask.append(mask_vector) + bm_mask = np.array(bm_mask, dtype=np.float32) + self_bm_mask = fluid.layers.create_global_var( + shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True) + fluid.layers.assign(bm_mask, self_bm_mask) + self_bm_mask.stop_gradient = True + return self_bm_mask + + def tem_loss_func(pred_start, pred_end, gt_start, gt_end): + def bi_loss(pred_score, gt_label): + pred_score = fluid.layers.reshape( + x=pred_score, shape=[-1], inplace=False) + gt_label = fluid.layers.reshape( + x=gt_label, shape=[-1], inplace=False) + gt_label.stop_gradient = True + pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE) + num_entries = fluid.layers.cast( + fluid.layers.shape(pmask), dtype=DATATYPE) + num_positive = fluid.layers.cast( + fluid.layers.reduce_sum(pmask), dtype=DATATYPE) + ratio = num_entries / num_positive + coef_0 = 0.5 * ratio / (ratio - 1) + coef_1 = 0.5 * ratio + epsilon = 0.000001 + temp = fluid.layers.log(pred_score + epsilon) + loss_pos = fluid.layers.elementwise_mul( + fluid.layers.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos) + loss_neg = fluid.layers.elementwise_mul( + fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask)) + loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg) + loss = -1 * (loss_pos + loss_neg) + return loss + + loss_start = bi_loss(pred_start, gt_start) + loss_end = bi_loss(pred_end, gt_end) + loss = loss_start + loss_end + return loss + + def pem_reg_loss_func(pred_score, gt_iou_map, mask): + + gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + + u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE) + u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3) + u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE) + u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.) 
+ u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE) + u_lmask = fluid.layers.elementwise_mul(u_lmask, mask) + + num_h = fluid.layers.cast( + fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE) + num_m = fluid.layers.cast( + fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE) + num_l = fluid.layers.cast( + fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE) + + r_m = num_h / num_m + u_smmask = fluid.layers.uniform_random( + shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], + dtype=DATATYPE, + min=0.0, + max=1.0) + u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask) + u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE) + + r_l = num_h / num_l + u_slmask = fluid.layers.uniform_random( + shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], + dtype=DATATYPE, + min=0.0, + max=1.0) + u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask) + u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE) + + weights = u_hmask + u_smmask + u_slmask + weights.stop_gradient = True + loss = fluid.layers.square_error_cost(pred_score, gt_iou_map) + loss = fluid.layers.elementwise_mul(loss, weights) + loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum( + weights) + + return loss + + def pem_cls_loss_func(pred_score, gt_iou_map, mask): + gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + gt_iou_map.stop_gradient = True + pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE) + nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE) + nmask = fluid.layers.elementwise_mul(nmask, mask) + + num_positive = fluid.layers.reduce_sum(pmask) + num_entries = num_positive + fluid.layers.reduce_sum(nmask) + ratio = num_entries / num_positive + coef_0 = 0.5 * ratio / (ratio - 1) + coef_1 = 0.5 * ratio + epsilon = 0.000001 + loss_pos = fluid.layers.elementwise_mul( + fluid.layers.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos) + loss_neg = fluid.layers.elementwise_mul( + fluid.layers.log(1.0 - pred_score + epsilon), nmask) + loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg) + loss = -1 * (loss_pos + loss_neg) / num_entries + return loss + + pred_bm_reg = fluid.layers.squeeze( + fluid.layers.slice( + pred_bm, axes=[1], starts=[0], ends=[1]), axes=[1]) + pred_bm_cls = fluid.layers.squeeze( + fluid.layers.slice( + pred_bm, axes=[1], starts=[1], ends=[2]), axes=[1]) + + bm_mask = _get_mask(cfg) + + pem_reg_loss = pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask) + pem_cls_loss = pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask) + + tem_loss = tem_loss_func(pred_start, pred_end, gt_start, gt_end) + + loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss + return loss, tem_loss, pem_reg_loss, pem_cls_loss diff --git a/dygraph/bmn/predict.py b/dygraph/bmn/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..363e15b0d36c8065436531432109dbc60296b8ee --- /dev/null +++ b/dygraph/bmn/predict.py @@ -0,0 +1,147 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import argparse +import sys +import os +import ast +import json + +from model import BMN +from eval import gen_props +from reader import BMNReader +from bmn_utils import bmn_post_processing +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("BMN test for performance evaluation.") + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='training batch size. None to use config file setting.') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + parser.add_argument( + '--weights', + type=str, + default="checkpoint/bmn_paddle_dy_final", + help='weight path, None to automatically download weights provided by Paddle.' + ) + parser.add_argument( + '--save_dir', + type=str, + default="predict_results/", + help='output dir path, default to use ./predict_results/') + parser.add_argument( + '--log_interval', + type=int, + default=1, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +def get_dataset_dict(cfg): + file_list = cfg.INFER.filelist + annos = json.load(open(file_list)) + video_dict = {} + for video_name in annos.keys(): + video_dict[video_name] = annos[video_name] + video_list = list(video_dict.keys()) + video_list.sort() + return video_dict, video_list + + +# Prediction +def infer_bmn(args): + config = parse_config(args.config_file) + infer_config = merge_configs(config, 'infer', vars(args)) + print_configs(infer_config, "Infer") + + if not os.path.isdir(infer_config.INFER.output_path): + os.makedirs(infer_config.INFER.output_path) + if not os.path.isdir(infer_config.INFER.result_path): + os.makedirs(infer_config.INFER.result_path) + place = fluid.CUDAPlace(0) + with fluid.dygraph.guard(place): + bmn = BMN(infer_config) + # load checkpoint + if args.weights: + assert os.path.exists(args.weights + ".pdparams" + ), "Given weight dir {} not exist.".format( + args.weights) + + logger.info('load test weights from {}'.format(args.weights)) + model_dict, _ = fluid.load_dygraph(args.weights) + bmn.set_dict(model_dict) + + reader = BMNReader(mode="infer", cfg=infer_config) + infer_reader = reader.create_reader() + + video_dict, video_list = get_dataset_dict(infer_config) + + bmn.eval() + for batch_id, data in enumerate(infer_reader()): + video_feat = np.array([item[0] for item in data]).astype(DATATYPE) + video_idx = [item[1] for item in data][0] #batch_size=1 by default + + x_data = fluid.dygraph.base.to_variable(video_feat) + + pred_bm, pred_start, pred_end = bmn(x_data) + + pred_bm = pred_bm.numpy() + pred_start = pred_start[0].numpy() + pred_end = pred_end[0].numpy() + + logger.info("Processing................ 
batch {}".format(batch_id)) + gen_props( + pred_bm, + pred_start, + pred_end, + video_idx, + video_list, + infer_config, + mode='infer') + + logger.info("Post_processing....This may take a while") + bmn_post_processing(video_dict, infer_config.INFER.subset, + infer_config.INFER.output_path, + infer_config.INFER.result_path) + logger.info("[INFER] infer finished. Results saved in {}".format( + args.save_dir) + "bmn_results_test.json") + + +if __name__ == '__main__': + args = parse_args() + infer_bmn(args) diff --git a/dygraph/bmn/reader.py b/dygraph/bmn/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..c0877c5907337bfaf13d30b328305d07b03fbe4b --- /dev/null +++ b/dygraph/bmn/reader.py @@ -0,0 +1,287 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import numpy as np +import random +import json +import multiprocessing +import functools +import logging +import platform +import os + +logger = logging.getLogger(__name__) + +from bmn_utils import iou_with_anchors, ioa_with_anchors + + +class BMNReader(): + def __init__(self, mode, cfg): + self.mode = mode + self.tscale = cfg.MODEL.tscale # 100 + self.dscale = cfg.MODEL.dscale # 100 + self.anno_file = cfg.MODEL.anno_file + self.file_list = cfg.INFER.filelist + self.subset = cfg[mode.upper()]['subset'] + self.tgap = 1. 
/ self.tscale + self.feat_path = cfg.MODEL.feat_path + + self.get_dataset_dict() + self.get_match_map() + + self.batch_size = cfg[mode.upper()]['batch_size'] + self.num_threads = cfg[mode.upper()]['num_threads'] + if (mode == 'test') or (mode == 'infer'): + self.num_threads = 1 # set num_threads as 1 for test and infer + + def get_dataset_dict(self): + self.video_dict = {} + if self.mode == "infer": + annos = json.load(open(self.file_list)) + for video_name in annos.keys(): + self.video_dict[video_name] = annos[video_name] + else: + annos = json.load(open(self.anno_file)) + for video_name in annos.keys(): + video_subset = annos[video_name]["subset"] + if self.subset in video_subset: + self.video_dict[video_name] = annos[video_name] + self.video_list = list(self.video_dict.keys()) + self.video_list.sort() + print("%s subset video numbers: %d" % + (self.subset, len(self.video_list))) + + def get_match_map(self): + match_map = [] + for idx in range(self.tscale): + tmp_match_window = [] + xmin = self.tgap * idx + for jdx in range(1, self.tscale + 1): + xmax = xmin + self.tgap * jdx + tmp_match_window.append([xmin, xmax]) + match_map.append(tmp_match_window) + match_map = np.array(match_map) + match_map = np.transpose(match_map, [1, 0, 2]) + match_map = np.reshape(match_map, [-1, 2]) + self.match_map = match_map + self.anchor_xmin = [self.tgap * i for i in range(self.tscale)] + self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)] + + def get_video_label(self, video_name): + video_info = self.video_dict[video_name] + video_second = video_info['duration_second'] + video_labels = video_info['annotations'] + + gt_bbox = [] + gt_iou_map = [] + for gt in video_labels: + tmp_start = max(min(1, gt["segment"][0] / video_second), 0) + tmp_end = max(min(1, gt["segment"][1] / video_second), 0) + gt_bbox.append([tmp_start, tmp_end]) + tmp_gt_iou_map = iou_with_anchors( + self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end) + tmp_gt_iou_map = np.reshape(tmp_gt_iou_map, + [self.dscale, self.tscale]) + gt_iou_map.append(tmp_gt_iou_map) + gt_iou_map = np.array(gt_iou_map) + gt_iou_map = np.max(gt_iou_map, axis=0) + + gt_bbox = np.array(gt_bbox) + gt_xmins = gt_bbox[:, 0] + gt_xmaxs = gt_bbox[:, 1] + gt_len_small = 3 * self.tgap + gt_start_bboxs = np.stack( + (gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1) + gt_end_bboxs = np.stack( + (gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1) + + match_score_start = [] + for jdx in range(len(self.anchor_xmin)): + match_score_start.append( + np.max( + ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[ + jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1]))) + match_score_end = [] + for jdx in range(len(self.anchor_xmin)): + match_score_end.append( + np.max( + ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[ + jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1]))) + + gt_start = np.array(match_score_start) + gt_end = np.array(match_score_end) + return gt_iou_map, gt_start, gt_end + + def load_file(self, video_name): + file_name = video_name + ".npy" + file_path = os.path.join(self.feat_path, file_name) + video_feat = np.load(file_path) + video_feat = video_feat.T + video_feat = video_feat.astype("float32") + return video_feat + + def create_reader(self): + """reader creator for bmn model""" + if self.mode == 'infer': + return self.make_infer_reader() + if self.num_threads == 1: + return self.make_reader() + else: + sysstr = platform.system() + if sysstr == 'Windows': + return 
self.make_multithread_reader() + else: + return self.make_multiprocess_reader() + + def make_infer_reader(self): + """reader for inference""" + + def reader(): + batch_out = [] + for video_name in self.video_list: + video_idx = self.video_list.index(video_name) + video_feat = self.load_file(video_name) + batch_out.append((video_feat, video_idx)) + + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + + return reader + + def make_reader(self): + """single process reader""" + + def reader(): + video_list = self.video_list + if self.mode == 'train': + random.shuffle(video_list) + + batch_out = [] + for video_name in video_list: + video_idx = video_list.index(video_name) + video_feat = self.load_file(video_name) + gt_iou_map, gt_start, gt_end = self.get_video_label(video_name) + + if self.mode == 'train' or self.mode == 'valid': + batch_out.append((video_feat, gt_iou_map, gt_start, gt_end)) + elif self.mode == 'test': + batch_out.append( + (video_feat, gt_iou_map, gt_start, gt_end, video_idx)) + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + + return reader + + def make_multithread_reader(self): + def reader(): + if self.mode == 'train': + random.shuffle(self.video_list) + for video_name in self.video_list: + video_idx = self.video_list.index(video_name) + yield [video_name, video_idx] + + def process_data(sample, mode): + video_name = sample[0] + video_idx = sample[1] + video_feat = self.load_file(video_name) + gt_iou_map, gt_start, gt_end = self.get_video_label(video_name) + if mode == 'train' or mode == 'valid': + return (video_feat, gt_iou_map, gt_start, gt_end) + elif mode == 'test': + return (video_feat, gt_iou_map, gt_start, gt_end, video_idx) + else: + raise NotImplementedError('mode {} not implemented'.format( + mode)) + + mapper = functools.partial(process_data, mode=self.mode) + + def batch_reader(): + xreader = paddle.reader.xmap_readers(mapper, reader, + self.num_threads, 1024) + batch = [] + for item in xreader(): + batch.append(item) + if len(batch) == self.batch_size: + yield batch + batch = [] + + return batch_reader + + def make_multiprocess_reader(self): + """multiprocess reader""" + + def read_into_queue(video_list, queue): + + batch_out = [] + for video_name in video_list: + video_idx = video_list.index(video_name) + video_feat = self.load_file(video_name) + gt_iou_map, gt_start, gt_end = self.get_video_label(video_name) + + if self.mode == 'train' or self.mode == 'valid': + batch_out.append((video_feat, gt_iou_map, gt_start, gt_end)) + elif self.mode == 'test': + batch_out.append( + (video_feat, gt_iou_map, gt_start, gt_end, video_idx)) + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + + if len(batch_out) == self.batch_size: + queue.put(batch_out) + batch_out = [] + queue.put(None) + + def queue_reader(): + video_list = self.video_list + if self.mode == 'train': + random.shuffle(video_list) + + n = self.num_threads + queue_size = 20 + reader_lists = [None] * n + file_num = int(len(video_list) // n) + for i in range(n): + if i < len(reader_lists) - 1: + tmp_list = video_list[i * file_num:(i + 1) * file_num] + else: + tmp_list = video_list[i * file_num:] + reader_lists[i] = tmp_list + + queue = multiprocessing.Queue(queue_size) + p_list = [None] * len(reader_lists) + for i in range(len(reader_lists)): + reader_list = reader_lists[i] + p_list[i] = multiprocessing.Process( + target=read_into_queue, 
args=(reader_list, queue)) + p_list[i].start() + reader_num = len(reader_lists) + finish_num = 0 + while finish_num < reader_num: + sample = queue.get() + if sample is None: + finish_num += 1 + else: + yield sample + for i in range(len(p_list)): + if p_list[i].is_alive(): + p_list[i].join() + + return queue_reader diff --git a/dygraph/bmn/run.sh b/dygraph/bmn/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..b426012056b92f9524271d127595775f281f789a --- /dev/null +++ b/dygraph/bmn/run.sh @@ -0,0 +1,5 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch \ + --selected_gpus=0,1,2,3 \ + --log_dir ./mylog \ + train.py --use_data_parallel True diff --git a/dygraph/bmn/train.py b/dygraph/bmn/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0171830e2a3b7e02ee3ef3a95a410a803293ec1e --- /dev/null +++ b/dygraph/bmn/train.py @@ -0,0 +1,243 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import argparse +import ast +import logging +import sys +import os + +from model import BMN, bmn_loss_func +from reader import BMNReader +from config_utils import * + +DATATYPE = 'float32' + +logging.root.handlers = [] +FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser("Paddle dynamic graph mode of BMN.") + parser.add_argument( + "--use_data_parallel", + type=ast.literal_eval, + default=False, + help="The flag indicating whether to use data parallel mode to train the model." + ) + parser.add_argument( + '--config_file', + type=str, + default='bmn.yaml', + help='path to config file of model') + parser.add_argument( + '--batch_size', + type=int, + default=None, + help='training batch size. None to use config file setting.') + parser.add_argument( + '--learning_rate', + type=float, + default=0.001, + help='learning rate use for training. None to use config file setting.') + parser.add_argument( + '--resume', + type=str, + default=None, + help='filename to resume training based on previous checkpoints. 
' + 'None for not resuming any checkpoints.') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='default use gpu.') + parser.add_argument( + '--epoch', + type=int, + default=9, + help='epoch number, 0 for read from config file') + parser.add_argument( + '--valid_interval', + type=int, + default=1, + help='validation epoch interval, 0 for no validation.') + parser.add_argument( + '--save_dir', + type=str, + default="checkpoint", + help='path to save train snapshoot') + parser.add_argument( + '--log_interval', + type=int, + default=10, + help='mini-batch interval to log.') + args = parser.parse_args() + return args + + +# Optimizer +def optimizer(cfg, parameter_list): + bd = [cfg.TRAIN.lr_decay_iter] + base_lr = cfg.TRAIN.learning_rate + lr_decay = cfg.TRAIN.learning_rate_decay + l2_weight_decay = cfg.TRAIN.l2_weight_decay + lr = [base_lr, base_lr * lr_decay] + optimizer = fluid.optimizer.Adam( + fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + parameter_list=parameter_list, + regularization=fluid.regularizer.L2DecayRegularizer( + regularization_coeff=l2_weight_decay)) + return optimizer + + +# Validation +def val_bmn(model, config, args): + reader = BMNReader(mode="valid", cfg=config) + val_reader = reader.create_reader() + for batch_id, data in enumerate(val_reader()): + video_feat = np.array([item[0] for item in data]).astype(DATATYPE) + gt_iou_map = np.array([item[1] for item in data]).astype(DATATYPE) + gt_start = np.array([item[2] for item in data]).astype(DATATYPE) + gt_end = np.array([item[3] for item in data]).astype(DATATYPE) + + x_data = fluid.dygraph.base.to_variable(video_feat) + gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map) + gt_start = fluid.dygraph.base.to_variable(gt_start) + gt_end = fluid.dygraph.base.to_variable(gt_end) + gt_iou_map.stop_gradient = True + gt_start.stop_gradient = True + gt_end.stop_gradient = True + + pred_bm, pred_start, pred_end = model(x_data) + + loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func( + pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end, config) + avg_loss = fluid.layers.mean(loss) + + if args.log_interval > 0 and (batch_id % args.log_interval == 0): + logger.info('[VALID] iter {} '.format(batch_id) + + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format( + '%.04f' % avg_loss.numpy()[0], '%.04f' % tem_loss.numpy()[0], \ + '%.04f' % pem_reg_loss.numpy()[0], '%.04f' % pem_cls_loss.numpy()[0])) + + +# TRAIN +def train_bmn(args): + config = parse_config(args.config_file) + train_config = merge_configs(config, 'train', vars(args)) + valid_config = merge_configs(config, 'valid', vars(args)) + + if not args.use_gpu: + place = fluid.CPUPlace() + elif not args.use_data_parallel: + place = fluid.CUDAPlace(0) + else: + place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) + + with fluid.dygraph.guard(place): + if args.use_data_parallel: + strategy = fluid.dygraph.parallel.prepare_context() + bmn = BMN(train_config) + adam = optimizer(train_config, parameter_list=bmn.parameters()) + + if args.use_data_parallel: + bmn = fluid.dygraph.parallel.DataParallel(bmn, strategy) + + if args.resume: + # if resume weights is given, load resume weights directly + assert os.path.exists(args.resume + ".pdparams"), \ + "Given resume weight dir {} not exist.".format(args.resume) + + model, _ = fluid.dygraph.load_dygraph(args.resume) + bmn.set_dict(model) + + reader = BMNReader(mode="train", cfg=train_config) + train_reader = reader.create_reader() + if 
args.use_data_parallel: + train_reader = fluid.contrib.reader.distributed_batch_reader( + train_reader) + + for epoch in range(args.epoch): + for batch_id, data in enumerate(train_reader()): + video_feat = np.array( + [item[0] for item in data]).astype(DATATYPE) + gt_iou_map = np.array( + [item[1] for item in data]).astype(DATATYPE) + gt_start = np.array([item[2] for item in data]).astype(DATATYPE) + gt_end = np.array([item[3] for item in data]).astype(DATATYPE) + + x_data = fluid.dygraph.base.to_variable(video_feat) + gt_iou_map = fluid.dygraph.base.to_variable(gt_iou_map) + gt_start = fluid.dygraph.base.to_variable(gt_start) + gt_end = fluid.dygraph.base.to_variable(gt_end) + gt_iou_map.stop_gradient = True + gt_start.stop_gradient = True + gt_end.stop_gradient = True + + pred_bm, pred_start, pred_end = bmn(x_data) + + loss, tem_loss, pem_reg_loss, pem_cls_loss = bmn_loss_func( + pred_bm, pred_start, pred_end, gt_iou_map, gt_start, gt_end, + train_config) + avg_loss = fluid.layers.mean(loss) + + if args.use_data_parallel: + avg_loss = bmn.scale_loss(avg_loss) + avg_loss.backward() + bmn.apply_collective_grads() + else: + avg_loss.backward() + + adam.minimize(avg_loss) + + bmn.clear_gradients() + + if args.log_interval > 0 and ( + batch_id % args.log_interval == 0): + logger.info('[TRAIN] Epoch {}, iter {} '.format(epoch, batch_id) + + '\tLoss = {}, \ttem_loss = {}, \tpem_reg_loss = {}, \tpem_cls_loss = {}'.format( + '%.04f' % avg_loss.numpy()[0], '%.04f' % tem_loss.numpy()[0], \ + '%.04f' % pem_reg_loss.numpy()[0], '%.04f' % pem_cls_loss.numpy()[0])) + + logger.info('[TRAIN] Epoch {} training finished'.format(epoch)) + if not os.path.isdir(args.save_dir): + os.makedirs(args.save_dir) + save_model_name = os.path.join( + args.save_dir, "bmn_paddle_dy" + "_epoch{}".format(epoch)) + fluid.dygraph.save_dygraph(bmn.state_dict(), save_model_name) + + # validation + if args.valid_interval > 0 and (epoch + 1 + ) % args.valid_interval == 0: + bmn.eval() + val_bmn(bmn, valid_config, args) + bmn.train() + + #save final results + if fluid.dygraph.parallel.Env().local_rank == 0: + save_model_name = os.path.join(args.save_dir, + "bmn_paddle_dy" + "_final") + fluid.dygraph.save_dygraph(bmn.state_dict(), save_model_name) + logger.info('[TRAIN] training finished') + + +if __name__ == "__main__": + args = parse_args() + train_bmn(args)
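
As a quick sanity check of the network definition added above, the model can be built from bmn.yaml and run on a random input to verify the output shapes. A minimal sketch, assuming PaddlePaddle 1.x with dygraph support and that it is run from the dygraph/bmn directory:

```python
import numpy as np
import paddle.fluid as fluid

from config_utils import parse_config
from model import BMN

cfg = parse_config("bmn.yaml")
with fluid.dygraph.guard(fluid.CPUPlace()):
    bmn = BMN(cfg)
    bmn.eval()
    # one dummy clip: [batch, feat_dim, tscale] = [1, 400, 100]
    feat = np.random.rand(1, cfg.MODEL.feat_dim,
                          cfg.MODEL.tscale).astype("float32")
    pred_bm, pred_start, pred_end = bmn(fluid.dygraph.base.to_variable(feat))
    # expected shapes: [1, 2, dscale, tscale], [1, tscale], [1, tscale]
    print(pred_bm.shape, pred_start.shape, pred_end.shape)
```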