提交 58a0ff67 编写于 作者: D dengkaipeng

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into add_tsm

# BMN 视频动作定位模型高层API实现
---
## 内容
- [模型简介](#模型简介)
- [代码结构](#代码结构)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)
- [参考论文](#参考论文)
## 模型简介
BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。
<p align="center">
<img src="./BMN.png" height=300 width=500 hspace='10'/> <br />
BMN Overview
</p>
## 代码结构
```
├── bmn.yaml # 网络配置文件,快速配置参数
├── run.sh # 快速运行脚本,可直接开始多卡训练
├── train.py # 训练代码,训练网络
├── eval.py # 评估代码,评估网络性能
├── predict.py # 预测代码,针对任意输入预测结果
├── bmn_model.py # 网络结构与损失函数定义
├── bmn_metric.py # 精度评估方法定义
├── reader.py # 数据reader,构造Dataset和Dataloader
├── bmn_utils.py # 模型细节相关代码
├── config_utils.py # 配置细节相关代码
├── eval_anet_prop.py # 计算精度评估指标
└── infer.list # 推断文件列表
```
## 数据准备
BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理好的视频特征,请下载[bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz)数据后解压,同时相应的修改bmn.yaml中的特征路径feat\_path。对应的标签文件请下载[label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json)并修改bmn.yaml中的标签文件路径anno\_file。
## 模型训练
数据准备完成后,可通过如下两种方式启动训练:
默认使用4卡训练,启动方式如下:
bash run.sh
若使用单卡训练,启动方式如下:
export CUDA_VISIBLE_DEVICES=0
python train.py
- 代码运行需要先安装pandas
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 单卡训练时,请将配置文件中的batch_size调整为16
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.001
* 权重衰减系数为1e-4
* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
## 模型评估
训练完成后,可通过如下方式进行模型评估:
python eval.py --weights=$PATH_TO_WEIGHTS
- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/EVAL/BMN\_results文件夹下,测试结果保存在evaluate\_results/bmn\_results\_validation.json文件中。
- 注:评估时可能会出现loss为nan的情况。这是由于评估时用的是单个样本,可能存在没有iou>0.6的样本,所以为nan,对最终的评估结果没有影响。
使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程如下:
- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org)
- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至models/dygraph/bmn目录下。(注:由于第三方评估代码不支持python3,此处建议使用python2进行评估;若使用python3,print函数需要添加括号,请对Evaluation目录下的.py文件做相应修改。)
- 请下载[activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json)文件,并将其放置在models/dygraph/bmn/Evaluation/data目录下,相较于原始的activity\_net.v1-3.min.json文件,我们过滤了其中一些失效的视频条目。
- 计算精度指标
```python eval_anet_prop.py```
在ActivityNet1.3数据集下评估精度如下:
| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
| :---: | :---: | :---: | :---: | :---: |
| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
## 模型推断
可通过如下方式启动模型推断:
python predict.py --weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/INFER/BMN\_results文件夹下,测试结果保存在predict\_results/bmn\_results\_test.json文件中。
## 参考论文
- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
MODEL:
name: "BMN"
tscale: 100
dscale: 100
feat_dim: 400
prop_boundary_ratio: 0.5
num_sample: 32
num_sample_perbin: 3
anno_file: "./activitynet_1.3_annotations.json"
feat_path: './fix_feat_100'
TRAIN:
subset: "train"
epoch: 9
batch_size: 4
num_workers: 4
use_shuffle: True
device: "gpu"
num_gpus: 4
learning_rate: 0.001
learning_rate_decay: 0.1
lr_decay_iter: 4200
l2_weight_decay: 1e-4
VALID:
subset: "validation"
TEST:
subset: "validation"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.001
snms_t1: 0.5
snms_t2: 0.9
output_path: "output/EVAL/BMN_results"
result_path: "evaluate_results"
INFER:
subset: "test"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.4
snms_t1: 0.5
snms_t2: 0.9
filelist: './infer.list'
output_path: "output/INFER/BMN_results"
result_path: "predict_results"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import os
import sys
import json
sys.path.append('../')
from metrics import Metric
from bmn_utils import boundary_choose, bmn_post_processing
class BmnMetric(Metric):
"""
only support update with batch_size=1
"""
def __init__(self, cfg, mode):
super(BmnMetric, self).__init__()
self.cfg = cfg
self.mode = mode
#get video_dict and video_list
if self.mode == 'test':
self.get_test_dataset_dict()
elif self.mode == 'infer':
self.get_infer_dataset_dict()
def add_metric_op(self, preds, label):
pred_bm, pred_start, pred_en = preds
video_index = label[-1]
return [pred_bm, pred_start, pred_en, video_index] #return list
def update(self, pred_bm, pred_start, pred_end, fid):
# generate proposals
pred_start = pred_start[0]
pred_end = pred_end[0]
fid = fid[0]
if self.mode == 'infer':
output_path = self.cfg.INFER.output_path
else:
output_path = self.cfg.TEST.output_path
tscale = self.cfg.MODEL.tscale
dscale = self.cfg.MODEL.dscale
snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
cols = ["xmin", "xmax", "score"]
video_name = self.video_list[fid]
pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
start_mask = boundary_choose(pred_start)
start_mask[0] = 1.
end_mask = boundary_choose(pred_end)
end_mask[-1] = 1.
score_vector_list = []
for idx in range(dscale):
for jdx in range(tscale):
start_index = jdx
end_index = start_index + idx
if end_index < tscale and start_mask[
start_index] == 1 and end_mask[end_index] == 1:
xmin = snippet_xmins[start_index]
xmax = snippet_xmaxs[end_index]
xmin_score = pred_start[start_index]
xmax_score = pred_end[end_index]
bm_score = pred_bm[idx, jdx]
conf_score = xmin_score * xmax_score * bm_score
score_vector_list.append([xmin, xmax, conf_score])
score_vector_list = np.stack(score_vector_list)
video_df = pd.DataFrame(score_vector_list, columns=cols)
video_df.to_csv(
os.path.join(output_path, "%s.csv" % video_name), index=False)
return 0 # result has saved in output path
def accumulate(self):
return 'post_processing is required...' # required method
def reset(self):
print("Post_processing....This may take a while")
if self.mode == 'test':
bmn_post_processing(self.video_dict, self.cfg.TEST.subset,
self.cfg.TEST.output_path,
self.cfg.TEST.result_path)
elif self.mode == 'infer':
bmn_post_processing(self.video_dict, self.cfg.INFER.subset,
self.cfg.INFER.output_path,
self.cfg.INFER.result_path)
def name(self):
return 'bmn_metric'
def get_test_dataset_dict(self):
anno_file = self.cfg.MODEL.anno_file
annos = json.load(open(anno_file))
subset = self.cfg.TEST.subset
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def get_infer_dataset_dict(self):
file_list = self.cfg.INFER.filelist
annos = json.load(open(file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
import math
from bmn_utils import get_interp1d_mask
from model import Model, Loss
DATATYPE = 'float32'
# Net
class Conv1D(fluid.dygraph.Layer):
def __init__(self,
prefix,
num_channels=256,
num_filters=256,
size_k=3,
padding=1,
groups=1,
act="relu"):
super(Conv1D, self).__init__()
fan_in = num_channels * size_k * 1
k = 1. / math.sqrt(fan_in)
param_attr = ParamAttr(
name=prefix + "_w",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
bias_attr = ParamAttr(
name=prefix + "_b",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
self._conv2d = fluid.dygraph.Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=(1, size_k),
stride=1,
padding=(0, padding),
groups=groups,
act=act,
param_attr=param_attr,
bias_attr=bias_attr)
def forward(self, x):
x = fluid.layers.unsqueeze(input=x, axes=[2])
x = self._conv2d(x)
x = fluid.layers.squeeze(input=x, axes=[2])
return x
class BMN(Model):
def __init__(self, cfg, is_dygraph=True):
super(BMN, self).__init__()
#init config
self.tscale = cfg.MODEL.tscale
self.dscale = cfg.MODEL.dscale
self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
self.num_sample = cfg.MODEL.num_sample
self.num_sample_perbin = cfg.MODEL.num_sample_perbin
self.is_dygraph = is_dygraph
self.hidden_dim_1d = 256
self.hidden_dim_2d = 128
self.hidden_dim_3d = 512
# Base Module
self.b_conv1 = Conv1D(
prefix="Base_1",
num_channels=400,
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.b_conv2 = Conv1D(
prefix="Base_2",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
# Temporal Evaluation Module
self.ts_conv1 = Conv1D(
prefix="TEM_s1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.ts_conv2 = Conv1D(
prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
self.te_conv1 = Conv1D(
prefix="TEM_e1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.te_conv2 = Conv1D(
prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
#Proposal Evaluation Module
self.p_conv1 = Conv1D(
prefix="PEM_1d",
num_filters=self.hidden_dim_2d,
size_k=3,
padding=1,
act="relu")
# init to speed up
sample_mask_array = get_interp1d_mask(
self.tscale, self.dscale, self.prop_boundary_ratio,
self.num_sample, self.num_sample_perbin)
if self.is_dygraph:
self.sample_mask = fluid.dygraph.base.to_variable(
sample_mask_array)
else: # static
self.sample_mask = fluid.layers.create_parameter(
shape=[
self.tscale, self.num_sample * self.dscale * self.tscale
],
dtype=DATATYPE,
attr=fluid.ParamAttr(
name="sample_mask", trainable=False),
default_initializer=fluid.initializer.NumpyArrayInitializer(
sample_mask_array))
self.sample_mask.stop_gradient = True
self.p_conv3d1 = fluid.dygraph.Conv3D(
num_channels=128,
num_filters=self.hidden_dim_3d,
filter_size=(self.num_sample, 1, 1),
stride=(self.num_sample, 1, 1),
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_3d1_w"),
bias_attr=ParamAttr(name="PEM_3d1_b"))
self.p_conv2d1 = fluid.dygraph.Conv2D(
num_channels=512,
num_filters=self.hidden_dim_2d,
filter_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_2d1_w"),
bias_attr=ParamAttr(name="PEM_2d1_b"))
self.p_conv2d2 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d2_w"),
bias_attr=ParamAttr(name="PEM_2d2_b"))
self.p_conv2d3 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d3_w"),
bias_attr=ParamAttr(name="PEM_2d3_b"))
self.p_conv2d4 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=2,
filter_size=1,
stride=1,
padding=0,
act="sigmoid",
param_attr=ParamAttr(name="PEM_2d4_w"),
bias_attr=ParamAttr(name="PEM_2d4_b"))
def forward(self, x):
#Base Module
x = self.b_conv1(x)
x = self.b_conv2(x)
#TEM
xs = self.ts_conv1(x)
xs = self.ts_conv2(xs)
xs = fluid.layers.squeeze(xs, axes=[1])
xe = self.te_conv1(x)
xe = self.te_conv2(xe)
xe = fluid.layers.squeeze(xe, axes=[1])
#PEM
xp = self.p_conv1(x)
#BM layer
xp = fluid.layers.matmul(xp, self.sample_mask)
xp = fluid.layers.reshape(
xp, shape=[0, 0, -1, self.dscale, self.tscale])
xp = self.p_conv3d1(xp)
xp = fluid.layers.squeeze(xp, axes=[2])
xp = self.p_conv2d1(xp)
xp = self.p_conv2d2(xp)
xp = self.p_conv2d3(xp)
xp = self.p_conv2d4(xp)
return xp, xs, xe
class BmnLoss(Loss):
def __init__(self, cfg):
super(BmnLoss, self).__init__()
self.cfg = cfg
def _get_mask(self):
dscale = self.cfg.MODEL.dscale
tscale = self.cfg.MODEL.tscale
bm_mask = []
for idx in range(dscale):
mask_vector = [1 for i in range(tscale - idx)
] + [0 for i in range(idx)]
bm_mask.append(mask_vector)
bm_mask = np.array(bm_mask, dtype=np.float32)
self_bm_mask = fluid.layers.create_global_var(
shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
fluid.layers.assign(bm_mask, self_bm_mask)
self_bm_mask.stop_gradient = True
return self_bm_mask
def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
def bi_loss(pred_score, gt_label):
pred_score = fluid.layers.reshape(
x=pred_score, shape=[-1], inplace=False)
gt_label = fluid.layers.reshape(
x=gt_label, shape=[-1], inplace=False)
gt_label.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
num_entries = fluid.layers.cast(
fluid.layers.shape(pmask), dtype=DATATYPE)
num_positive = fluid.layers.cast(
fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
temp = fluid.layers.log(pred_score + epsilon)
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
loss = -1 * (loss_pos + loss_neg)
return loss
loss_start = bi_loss(pred_start, gt_start)
loss_end = bi_loss(pred_end, gt_end)
loss = loss_start + loss_end
return loss
def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
num_h = fluid.layers.cast(
fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
num_m = fluid.layers.cast(
fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
num_l = fluid.layers.cast(
fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
r_m = num_h / num_m
u_smmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
r_l = num_h / num_l
u_slmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
weights = u_hmask + u_smmask + u_slmask
weights.stop_gradient = True
loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
loss = fluid.layers.elementwise_mul(loss, weights)
loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
weights)
return loss
def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
gt_iou_map.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
nmask = fluid.layers.elementwise_mul(nmask, mask)
num_positive = fluid.layers.reduce_sum(pmask)
num_entries = num_positive + fluid.layers.reduce_sum(nmask)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), nmask)
loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
loss = -1 * (loss_pos + loss_neg) / num_entries
return loss
def forward(self, outputs, labels):
pred_bm, pred_start, pred_end = outputs
if len(labels) == 3:
gt_iou_map, gt_start, gt_end = labels
elif len(labels) == 4: # video_index used in eval mode
gt_iou_map, gt_start, gt_end, video_index = labels
pred_bm_reg = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[0], ends=[1]),
axes=[1])
pred_bm_cls = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[1], ends=[2]),
axes=[1])
bm_mask = self._get_mask()
pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import multiprocessing as mp
import json
import os
import math
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute jaccard score between a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
union_len = len_anchors - inter_len + box_max - box_min
jaccard = np.divide(inter_len, union_len)
return jaccard
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute intersection between score a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
scores = np.divide(inter_len, len_anchors)
return scores
def boundary_choose(score_list):
max_score = max(score_list)
mask_high = (score_list > max_score * 0.5)
score_list = list(score_list)
score_middle = np.array([0.0] + score_list + [0.0])
score_front = np.array([0.0, 0.0] + score_list)
score_back = np.array(score_list + [0.0, 0.0])
mask_peak = ((score_middle > score_front) & (score_middle > score_back))
mask_peak = mask_peak[1:-1]
mask = (mask_high | mask_peak).astype('float32')
return mask
def soft_nms(df, alpha, t1, t2):
'''
df: proposals generated by network;
alpha: alpha value of Gaussian decaying function;
t1, t2: threshold for soft nms.
'''
df = df.sort_values(by="score", ascending=False)
tstart = list(df.xmin.values[:])
tend = list(df.xmax.values[:])
tscore = list(df.score.values[:])
rstart = []
rend = []
rscore = []
while len(tscore) > 1 and len(rscore) < 101:
max_index = tscore.index(max(tscore))
tmp_iou_list = iou_with_anchors(
np.array(tstart),
np.array(tend), tstart[max_index], tend[max_index])
for idx in range(0, len(tscore)):
if idx != max_index:
tmp_iou = tmp_iou_list[idx]
tmp_width = tend[max_index] - tstart[max_index]
if tmp_iou > t1 + (t2 - t1) * tmp_width:
tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
alpha)
rstart.append(tstart[max_index])
rend.append(tend[max_index])
rscore.append(tscore[max_index])
tstart.pop(max_index)
tend.pop(max_index)
tscore.pop(max_index)
newDf = pd.DataFrame()
newDf['score'] = rscore
newDf['xmin'] = rstart
newDf['xmax'] = rend
return newDf
def video_process(video_list,
video_dict,
output_path,
result_dict,
snms_alpha=0.4,
snms_t1=0.55,
snms_t2=0.9):
for video_name in video_list:
print("Processing video........" + video_name)
df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
if len(df) > 1:
df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
video_duration = video_dict[video_name]["duration_second"]
proposal_list = []
for idx in range(min(100, len(df))):
tmp_prop={"score":df.score.values[idx], \
"segment":[max(0,df.xmin.values[idx])*video_duration, \
min(1,df.xmax.values[idx])*video_duration]}
proposal_list.append(tmp_prop)
result_dict[video_name[2:]] = proposal_list
def bmn_post_processing(video_dict, subset, output_path, result_path):
video_list = video_dict.keys()
video_list = list(video_dict.keys())
global result_dict
result_dict = mp.Manager().dict()
pp_num = 12
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pp_num)
processes = []
for tid in range(pp_num - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
for p in processes:
p.join()
result_dict = dict(result_dict)
output_dict = {
"version": "VERSION 1.3",
"results": result_dict,
"external_data": {}
}
outfile = open(
os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
json.dump(output_dict, outfile)
outfile.close()
def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
num_sample_perbin):
""" generate sample mask for a boundary-matching pair """
plen = float(seg_xmax - seg_xmin)
plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
total_samples = [
seg_xmin + plen_sample * ii
for ii in range(num_sample * num_sample_perbin)
]
p_mask = []
for idx in range(num_sample):
bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
num_sample_perbin]
bin_vector = np.zeros([tscale])
for sample in bin_samples:
sample_upper = math.ceil(sample)
sample_decimal, sample_down = math.modf(sample)
if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
bin_vector[int(sample_down)] += 1 - sample_decimal
if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
bin_vector[int(sample_upper)] += sample_decimal
bin_vector = 1.0 / num_sample_perbin * bin_vector
p_mask.append(bin_vector)
p_mask = np.stack(p_mask, axis=1)
return p_mask
def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
num_sample_perbin):
""" generate sample mask for each point in Boundary-Matching Map """
mask_mat = []
for start_index in range(tscale):
mask_mat_vector = []
for duration_index in range(dscale):
if start_index + duration_index < tscale:
p_xmin = start_index
p_xmax = start_index + duration_index
center_len = float(p_xmax - p_xmin) + 1
sample_xmin = p_xmin - center_len * prop_boundary_ratio
sample_xmax = p_xmax + center_len * prop_boundary_ratio
p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
tscale, num_sample,
num_sample_perbin)
else:
p_mask = np.zeros([tscale, num_sample])
mask_mat_vector.append(p_mask)
mask_mat_vector = np.stack(mask_mat_vector, axis=2)
mask_mat.append(mask_mat_vector)
mask_mat = np.stack(mask_mat, axis=3)
mask_mat = mask_mat.astype(np.float32)
sample_mask = np.reshape(mask_mat, [tscale, -1])
return sample_mask
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import os
import sys
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN test for performance evaluation.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Performance Evaluation
def test_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
eval_cfg = merge_configs(config, 'test', vars(args))
if not os.path.isdir(config.TEST.output_path):
os.makedirs(config.TEST.output_path)
if not os.path.isdir(config.TEST.result_path):
os.makedirs(config.TEST.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
video_idx = Input([None, 1], 'int64', name='video_idx')
labels = [gt_iou_map, gt_start, gt_end, video_idx]
#data
eval_dataset = BmnDataset(eval_cfg, 'test')
#model
model = BMN(config, args.dynamic)
model.prepare(
loss_function=BmnLoss(config),
metrics=BmnMetric(
config, mode='test'),
inputs=inputs,
labels=labels,
device=device)
#load checkpoint
if args.weights:
assert os.path.exists(args.weights + '.pdparams'), \
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
model.evaluate(
eval_data=eval_dataset,
batch_size=eval_cfg.TEST.batch_size,
num_workers=eval_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[EVAL] eval finished")
if __name__ == '__main__':
args = parse_args()
test_bmn(args)
'''
Calculate AR@N and AUC;
Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)
'''
import sys
sys.path.append('./Evaluation')
from eval_proposal import ANETproposal
import numpy as np
import argparse
import os
parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
parser.add_argument(
'--eval_file',
type=str,
default='bmn_results_validation.json',
help='name of results file to eval')
def run_evaluation(ground_truth_filename,
proposal_filename,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation'):
anet_proposal = ANETproposal(
ground_truth_filename,
proposal_filename,
tiou_thresholds=tiou_thresholds,
max_avg_nr_proposals=max_avg_nr_proposals,
subset=subset,
verbose=True,
check_status=False)
anet_proposal.evaluate()
recall = anet_proposal.recall
average_recall = anet_proposal.avg_recall
average_nr_proposals = anet_proposal.proposals_per_video
return (average_nr_proposals, average_recall, recall)
def plot_metric(average_nr_proposals,
average_recall,
recall,
tiou_thresholds=np.linspace(0.5, 0.95, 10)):
fn_size = 14
plt.figure(num=None, figsize=(12, 8))
ax = plt.subplot(1, 1, 1)
colors = [
'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
]
area_under_curve = np.zeros_like(tiou_thresholds)
for i in range(recall.shape[0]):
area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
for idx, tiou in enumerate(tiou_thresholds[::2]):
ax.plot(
average_nr_proposals,
recall[2 * idx, :],
color=colors[idx + 1],
label="tiou=[" + str(tiou) + "], area=" + str(
int(area_under_curve[2 * idx] * 100) / 100.),
linewidth=4,
linestyle='--',
marker=None)
# Plots Average Recall vs Average number of proposals.
ax.plot(
average_nr_proposals,
average_recall,
color=colors[0],
label="tiou = 0.5:0.05:0.95," + " area=" + str(
int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
linewidth=4,
linestyle='-',
marker=None)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
[handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
plt.ylabel('Average Recall', fontsize=fn_size)
plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
plt.grid(b=True, which="both")
plt.ylim([0, 1.0])
plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
plt.show()
if __name__ == "__main__":
args = parser.parse_args()
eval_file = args.eval_file
eval_file_path = os.path.join("evaluate_results", eval_file)
uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
"./Evaluation/data/activity_net_1_3_new.json",
eval_file_path,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation')
print("AR@1; AR@5; AR@10; AR@100")
print("%.02f %.02f %.02f %.02f" %
(100 * np.mean(uniform_recall_valid[:, 0]),
100 * np.mean(uniform_recall_valid[:, 4]),
100 * np.mean(uniform_recall_valid[:, 9]),
100 * np.mean(uniform_recall_valid[:, -1])))
{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import sys
import os
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN inference.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device', type=str, default='GPU', help='default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--save_dir',
type=str,
default="predict_results/",
help='output dir path, default to use ./predict_results/')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Prediction
def infer_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
infer_cfg = merge_configs(config, 'infer', vars(args))
if not os.path.isdir(config.INFER.output_path):
os.makedirs(config.INFER.output_path)
if not os.path.isdir(config.INFER.result_path):
os.makedirs(config.INFER.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
labels = [Input([None, 1], 'int64', name='video_idx')]
#data
infer_dataset = BmnDataset(infer_cfg, 'infer')
model = BMN(config, args.dynamic)
model.prepare(
metrics=BmnMetric(
config, mode='infer'),
inputs=inputs,
labels=labels,
device=device)
# load checkpoint
if args.weights:
assert os.path.exists(
args.weights +
".pdparams"), "Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
# here use model.eval instead of model.test, as post process is required in our case
model.evaluate(
eval_data=infer_dataset,
batch_size=infer_cfg.TEST.batch_size,
num_workers=infer_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[INFER] infer finished")
if __name__ == '__main__':
args = parse_args()
infer_bmn(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import numpy as np
import json
import logging
import os
import sys
sys.path.append('../')
from distributed import DistributedBatchSampler
from paddle.fluid.io import Dataset, DataLoader
logger = logging.getLogger(__name__)
from config_utils import *
from bmn_utils import iou_with_anchors, ioa_with_anchors
DATATYPE = "float32"
class BmnDataset(Dataset):
def __init__(self, cfg, mode):
self.mode = mode
self.tscale = cfg.MODEL.tscale # 100
self.dscale = cfg.MODEL.dscale # 100
self.anno_file = cfg.MODEL.anno_file
self.feat_path = cfg.MODEL.feat_path
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
self.tgap = 1. / self.tscale
self.get_dataset_dict()
self.get_match_map()
def __getitem__(self, index):
video_name = self.video_list[index]
video_idx = self.video_list.index(video_name)
video_feat = self.load_file(video_name)
if self.mode == 'infer':
return video_feat, video_idx
else:
gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
return video_feat, gt_iou_map, gt_start, gt_end
elif self.mode == 'test':
return video_feat, gt_iou_map, gt_start, gt_end, video_idx
def __len__(self):
return len(self.video_list)
def get_dataset_dict(self):
assert (
os.path.exists(self.feat_path)), "Input feature path not exists"
assert (os.listdir(self.feat_path)), "No feature file in feature path"
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
video_name_set = set(
[video_name + '.npy' for video_name in self.video_list])
assert (video_name_set.intersection(set(os.listdir(self.feat_path))) ==
video_name_set), "Input feature not exists in feature path"
def get_match_map(self):
match_map = []
for idx in range(self.tscale):
tmp_match_window = []
xmin = self.tgap * idx
for jdx in range(1, self.tscale + 1):
xmax = xmin + self.tgap * jdx
tmp_match_window.append([xmin, xmax])
match_map.append(tmp_match_window)
match_map = np.array(match_map)
match_map = np.transpose(match_map, [1, 0, 2])
match_map = np.reshape(match_map, [-1, 2])
self.match_map = match_map
self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
def get_video_label(self, video_name):
video_info = self.video_dict[video_name]
video_second = video_info['duration_second']
video_labels = video_info['annotations']
gt_bbox = []
gt_iou_map = []
for gt in video_labels:
tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
gt_bbox.append([tmp_start, tmp_end])
tmp_gt_iou_map = iou_with_anchors(
self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
[self.dscale, self.tscale])
gt_iou_map.append(tmp_gt_iou_map)
gt_iou_map = np.array(gt_iou_map)
gt_iou_map = np.max(gt_iou_map, axis=0)
gt_bbox = np.array(gt_bbox)
gt_xmins = gt_bbox[:, 0]
gt_xmaxs = gt_bbox[:, 1]
gt_len_small = 3 * self.tgap
gt_start_bboxs = np.stack(
(gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
gt_end_bboxs = np.stack(
(gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
match_score_start = []
for jdx in range(len(self.anchor_xmin)):
match_score_start.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
match_score_end = []
for jdx in range(len(self.anchor_xmin)):
match_score_end.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
gt_start = np.array(match_score_start)
gt_end = np.array(match_score_end)
return gt_iou_map.astype(DATATYPE), gt_start.astype(
DATATYPE), gt_end.astype(DATATYPE)
def load_file(self, video_name):
file_name = video_name + ".npy"
file_path = os.path.join(self.feat_path, file_name)
video_feat = np.load(file_path)
video_feat = video_feat.T
video_feat = video_feat.astype("float32")
return video_feat
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch train.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
import argparse
import logging
import sys
import os
sys.path.append('../')
from model import set_device, Input
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("Paddle high level api of BMN.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='training batch size. None to use config file setting.')
parser.add_argument(
'--learning_rate',
type=float,
default=0.001,
help='learning rate use for training. None to use config file setting.')
parser.add_argument(
'--resume',
type=str,
default=None,
help='filename to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--epoch',
type=int,
default=9,
help='epoch number, 0 for read from config file')
parser.add_argument(
'--valid_interval',
type=int,
default=1,
help='validation epoch interval, 0 for no validation.')
parser.add_argument(
'--save_dir',
type=str,
default="checkpoint",
help='path to save train snapshoot')
parser.add_argument(
'--log_interval',
type=int,
default=10,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Optimizer
def optimizer(cfg, parameter_list):
bd = [cfg.TRAIN.lr_decay_iter]
base_lr = cfg.TRAIN.learning_rate
lr_decay = cfg.TRAIN.learning_rate_decay
l2_weight_decay = cfg.TRAIN.l2_weight_decay
lr = [base_lr, base_lr * lr_decay]
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
parameter_list=parameter_list,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
# TRAIN
def train_bmn(args):
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
config = parse_config(args.config_file)
train_cfg = merge_configs(config, 'train', vars(args))
val_cfg = merge_configs(config, 'valid', vars(args))
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
labels = [gt_iou_map, gt_start, gt_end]
# data
train_dataset = BmnDataset(train_cfg, 'train')
val_dataset = BmnDataset(val_cfg, 'valid')
# model
model = BMN(config, args.dynamic)
optim = optimizer(config, parameter_list=model.parameters())
model.prepare(
optimizer=optim,
loss_function=BmnLoss(config),
inputs=inputs,
labels=labels,
device=device)
# if resume weights is given, load resume weights directly
if args.resume is not None:
model.load(args.resume)
model.fit(train_data=train_dataset,
eval_data=val_dataset,
batch_size=train_cfg.TRAIN.batch_size,
epochs=args.epoch,
eval_freq=args.valid_interval,
log_freq=args.log_interval,
save_dir=args.save_dir,
shuffle=train_cfg.TRAIN.use_shuffle,
num_workers=train_cfg.TRAIN.num_workers,
drop_last=True)
if __name__ == "__main__":
args = parse_args()
train_bmn(args)
import os
import sys
import cv2
from paddle.fluid.io import Dataset
def has_valid_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
return has_valid_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
class DatasetFolder(Dataset):
"""A generic data loader where the samples are arranged in this way:
root/class_a/1.ext
root/class_a/2.ext
root/class_a/3.ext
root/class_b/123.ext
root/class_b/456.ext
root/class_b/789.ext
Args:
root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self,
root,
loader=None,
extensions=None,
transform=None,
target_transform=None,
is_valid_file=None):
self.root = root
if extensions is None:
extensions = IMG_EXTENSIONS
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions,
is_valid_file)
if len(samples) == 0:
raise (RuntimeError(
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir),
and class_to_idx is a dictionary.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [
d for d in os.listdir(dir)
if os.path.isdir(os.path.join(dir, d))
]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def cv2_loader(path):
return cv2.imread(path)
# 高级api图像分类
## 数据集准备
在开始训练前,请确保已经下载解压好[ImageNet数据集](http://image-net.org/download),并放在合适的目录下,准备好的数据集的目录结构如下所示:
```bash
/path/to/imagenet
train
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
val
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
```
## 训练
### 单卡训练
执行如下命令进行训练
```bash
python -u main.py --arch resnet50 /path/to/imagenet -d
```
### 多卡训练
执行如下命令进行训练
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
```
## 预测
### 单卡预测
执行如下命令进行预测
```bash
python -u main.py --arch resnet50 -d --evaly-only /path/to/imagenet
```
### 多卡预测
执行如下命令进行多卡预测
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --evaly-only /path/to/imagenet
```
## 参数说明
* **arch**: 要训练或预测的模型名称
* **device**: 训练使用的设备,'gpu'或'cpu',默认值:'gpu'
* **dynamic**: 是否使用动态图模式训练
* **epoch**: 训练的轮数,默认值:120
* **learning-rate**: 学习率,默认值:0.1
* **batch-size**: 每张卡的batch size,默认值:64
* **output-dir**: 模型文件保存的文件夹,默认值:'output'
* **num-workers**: dataloader的进程数,默认值:4
* **resume**: 恢复训练的模型路径,默认值:None
* **eval-only**: 仅仅进行预测,默认值:False
## 模型
| 模型 | top1 acc | top5 acc |
| --- | --- | --- |
| ResNet50 | 76.28 | 93.04 |
import os
import cv2
import math
import random
import numpy as np
from datasets.folder import DatasetFolder
def center_crop_resize(img):
h, w = img.shape[:2]
c = int(224 / 256 * min((h, w)))
i = (h + 1 - c) // 2
j = (w + 1 - c) // 2
img = img[i:i + c, j:j + c, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i:i + h, j:j + w, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
return img
def normalize_permute(img):
# transpose and convert to RGB from BGR
img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
class ImageNetDataset(DatasetFolder):
def __init__(self, path, mode='train'):
super(ImageNetDataset, self).__init__(path)
self.mode = mode
if self.mode == 'train':
self.transform = compose([
cv2.imread, random_crop_resize, random_flip, normalize_permute
])
else:
self.transform = compose(
[cv2.imread, center_crop_resize, normalize_permute])
def __getitem__(self, idx):
img, label = self.samples[idx]
return self.transform((img, [label]))
def __len__(self):
return len(self.samples)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import sys
sys.path.append('../')
import time
import math
import numpy as np
import models
import paddle.fluid as fluid
from model import CrossEntropy, Input, set_device
from imagenet_dataset import ImageNetDataset
from distributed import DistributedBatchSampler
from paddle.fluid.dygraph.parallel import ParallelEnv
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr
momentum = 0.9
weight_decay = 1e-4
boundaries = [step_per_epoch * e for e in [30, 60, 80]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
not FLAGS.resume)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
train_dataset = ImageNetDataset(
os.path.join(FLAGS.data, 'train'), mode='train')
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
optim = make_optimizer(
np.ceil(
len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
if FLAGS.eval_only:
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
time.strftime('%Y-%m-%d-%H-%M',
time.localtime()))
if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
os.makedirs(output_dir)
model.fit(train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument(
'data',
metavar='DIR',
help='path to dataset '
'(should have subdirectories named "train" and "val"')
parser.add_argument(
"--arch", type=str, default='resnet50', help="model name")
parser.add_argument(
"--device", type=str, default='gpu', help="device to run, cpu or gpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=90, type=int, help="number of epoch")
parser.add_argument(
'--lr',
'--learning-rate',
default=0.1,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch-size", default=64, type=int, help="batch size")
parser.add_argument(
"-n", "--num-workers", default=4, type=int, help="dataloader workers")
parser.add_argument(
"--output-dir", type=str, default='output', help="save dir")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
parser.add_argument(
"--eval-only", action='store_true', help="enable dygraph mode")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
...@@ -410,7 +410,8 @@ class StaticGraphAdapter(object): ...@@ -410,7 +410,8 @@ class StaticGraphAdapter(object):
and self.model._optimizer._learning_rate_map: and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue # HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog] lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var new_lr_var = prog.global_block().vars[lr_var.name]
self.model._optimizer._learning_rate_map[prog] = new_lr_var
losses = [] losses = []
metrics = [] metrics = []
......
from .resnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import shutil
import requests
import tqdm
import hashlib
import time
from paddle.fluid.dygraph.parallel import ParallelEnv
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def get_weights_path(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
"""
path, _ = get_path(url, WEIGHTS_HOME, md5sum)
return path
def map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
exist_flag = True
if ParallelEnv().local_rank == 0:
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if ParallelEnv().local_rank == 0:
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model
from .download import get_weights_path
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'0884c9087266496c41c60d14a96f8530')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
# return fluid.layers.relu(x)
class ResNet(Model):
def __init__(self, Block, depth=50, num_classes=1000):
super(ResNet, self).__init__()
layer_config = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
assert depth in layer_config.keys(), \
"supported depth are {} but input layer is {}".format(
layer_config.keys(), depth)
layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
block = Block(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
num_filters=num_out[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer("layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained):
model = ResNet(Block, depth)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def resnet50(pretrained=False):
return _resnet('resnet50', BottleneckBlock, 50, pretrained)
def resnet101(pretrained=False):
return _resnet('resnet101', BottleneckBlock, 101, pretrained)
def resnet152(pretrained=False):
return _resnet('resnet152', BottleneckBlock, 152, pretrained)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册