提交 8aca373d 编写于 作者: G guosheng

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into add-hapi-seq2seq

# BMN 视频动作定位模型高层API实现
---
## 内容
- [模型简介](#模型简介)
- [代码结构](#代码结构)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)
- [参考论文](#参考论文)
## 模型简介
BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。
<p align="center">
<img src="./BMN.png" height=300 width=500 hspace='10'/> <br />
BMN Overview
</p>
## 代码结构
```
├── bmn.yaml # 网络配置文件,快速配置参数
├── run.sh # 快速运行脚本,可直接开始多卡训练
├── train.py # 训练代码,训练网络
├── eval.py # 评估代码,评估网络性能
├── predict.py # 预测代码,针对任意输入预测结果
├── bmn_model.py # 网络结构与损失函数定义
├── bmn_metric.py # 精度评估方法定义
├── reader.py # 数据reader,构造Dataset和Dataloader
├── bmn_utils.py # 模型细节相关代码
├── config_utils.py # 配置细节相关代码
├── eval_anet_prop.py # 计算精度评估指标
└── infer.list # 推断文件列表
```
## 数据准备
BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理好的视频特征,请下载[bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz)数据后解压,同时相应的修改bmn.yaml中的特征路径feat\_path。对应的标签文件请下载[label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json)并修改bmn.yaml中的标签文件路径anno\_file。
## 模型训练
数据准备完成后,可通过如下两种方式启动训练:
默认使用4卡训练,启动方式如下:
bash run.sh
若使用单卡训练,启动方式如下:
export CUDA_VISIBLE_DEVICES=0
python train.py
- 代码运行需要先安装pandas
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 单卡训练时,请将配置文件中的batch_size调整为16
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.001
* 权重衰减系数为1e-4
* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
## 模型评估
训练完成后,可通过如下方式进行模型评估:
python eval.py --weights=$PATH_TO_WEIGHTS
- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/EVAL/BMN\_results文件夹下,测试结果保存在evaluate\_results/bmn\_results\_validation.json文件中。
- 注:评估时可能会出现loss为nan的情况。这是由于评估时用的是单个样本,可能存在没有iou>0.6的样本,所以为nan,对最终的评估结果没有影响。
使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程如下:
- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org)
- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至models/dygraph/bmn目录下。(注:由于第三方评估代码不支持python3,此处建议使用python2进行评估;若使用python3,print函数需要添加括号,请对Evaluation目录下的.py文件做相应修改。)
- 请下载[activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json)文件,并将其放置在models/dygraph/bmn/Evaluation/data目录下,相较于原始的activity\_net.v1-3.min.json文件,我们过滤了其中一些失效的视频条目。
- 计算精度指标
```python eval_anet_prop.py```
在ActivityNet1.3数据集下评估精度如下:
| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
| :---: | :---: | :---: | :---: | :---: |
| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
## 模型推断
可通过如下方式启动模型推断:
python predict.py --weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/INFER/BMN\_results文件夹下,测试结果保存在predict\_results/bmn\_results\_test.json文件中。
## 参考论文
- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
MODEL:
name: "BMN"
tscale: 100
dscale: 100
feat_dim: 400
prop_boundary_ratio: 0.5
num_sample: 32
num_sample_perbin: 3
anno_file: "./activitynet_1.3_annotations.json"
feat_path: './fix_feat_100'
TRAIN:
subset: "train"
epoch: 9
batch_size: 4
num_workers: 4
use_shuffle: True
device: "gpu"
num_gpus: 4
learning_rate: 0.001
learning_rate_decay: 0.1
lr_decay_iter: 4200
l2_weight_decay: 1e-4
VALID:
subset: "validation"
TEST:
subset: "validation"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.001
snms_t1: 0.5
snms_t2: 0.9
output_path: "output/EVAL/BMN_results"
result_path: "evaluate_results"
INFER:
subset: "test"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.4
snms_t1: 0.5
snms_t2: 0.9
filelist: './infer.list'
output_path: "output/INFER/BMN_results"
result_path: "predict_results"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import os
import sys
import json
sys.path.append('../')
from metrics import Metric
from bmn_utils import boundary_choose, bmn_post_processing
class BmnMetric(Metric):
"""
only support update with batch_size=1
"""
def __init__(self, cfg, mode):
super(BmnMetric, self).__init__()
self.cfg = cfg
self.mode = mode
#get video_dict and video_list
if self.mode == 'test':
self.get_test_dataset_dict()
elif self.mode == 'infer':
self.get_infer_dataset_dict()
def add_metric_op(self, preds, label):
pred_bm, pred_start, pred_en = preds
video_index = label[-1]
return [pred_bm, pred_start, pred_en, video_index] #return list
def update(self, pred_bm, pred_start, pred_end, fid):
# generate proposals
pred_start = pred_start[0]
pred_end = pred_end[0]
fid = fid[0]
if self.mode == 'infer':
output_path = self.cfg.INFER.output_path
else:
output_path = self.cfg.TEST.output_path
tscale = self.cfg.MODEL.tscale
dscale = self.cfg.MODEL.dscale
snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
cols = ["xmin", "xmax", "score"]
video_name = self.video_list[fid]
pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
start_mask = boundary_choose(pred_start)
start_mask[0] = 1.
end_mask = boundary_choose(pred_end)
end_mask[-1] = 1.
score_vector_list = []
for idx in range(dscale):
for jdx in range(tscale):
start_index = jdx
end_index = start_index + idx
if end_index < tscale and start_mask[
start_index] == 1 and end_mask[end_index] == 1:
xmin = snippet_xmins[start_index]
xmax = snippet_xmaxs[end_index]
xmin_score = pred_start[start_index]
xmax_score = pred_end[end_index]
bm_score = pred_bm[idx, jdx]
conf_score = xmin_score * xmax_score * bm_score
score_vector_list.append([xmin, xmax, conf_score])
score_vector_list = np.stack(score_vector_list)
video_df = pd.DataFrame(score_vector_list, columns=cols)
video_df.to_csv(
os.path.join(output_path, "%s.csv" % video_name), index=False)
return 0 # result has saved in output path
def accumulate(self):
return 'post_processing is required...' # required method
def reset(self):
print("Post_processing....This may take a while")
if self.mode == 'test':
bmn_post_processing(self.video_dict, self.cfg.TEST.subset,
self.cfg.TEST.output_path,
self.cfg.TEST.result_path)
elif self.mode == 'infer':
bmn_post_processing(self.video_dict, self.cfg.INFER.subset,
self.cfg.INFER.output_path,
self.cfg.INFER.result_path)
def name(self):
return 'bmn_metric'
def get_test_dataset_dict(self):
anno_file = self.cfg.MODEL.anno_file
annos = json.load(open(anno_file))
subset = self.cfg.TEST.subset
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def get_infer_dataset_dict(self):
file_list = self.cfg.INFER.filelist
annos = json.load(open(file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
import math
from bmn_utils import get_interp1d_mask
from model import Model, Loss
DATATYPE = 'float32'
# Net
class Conv1D(fluid.dygraph.Layer):
def __init__(self,
prefix,
num_channels=256,
num_filters=256,
size_k=3,
padding=1,
groups=1,
act="relu"):
super(Conv1D, self).__init__()
fan_in = num_channels * size_k * 1
k = 1. / math.sqrt(fan_in)
param_attr = ParamAttr(
name=prefix + "_w",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
bias_attr = ParamAttr(
name=prefix + "_b",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
self._conv2d = fluid.dygraph.Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=(1, size_k),
stride=1,
padding=(0, padding),
groups=groups,
act=act,
param_attr=param_attr,
bias_attr=bias_attr)
def forward(self, x):
x = fluid.layers.unsqueeze(input=x, axes=[2])
x = self._conv2d(x)
x = fluid.layers.squeeze(input=x, axes=[2])
return x
class BMN(Model):
def __init__(self, cfg, is_dygraph=True):
super(BMN, self).__init__()
#init config
self.tscale = cfg.MODEL.tscale
self.dscale = cfg.MODEL.dscale
self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
self.num_sample = cfg.MODEL.num_sample
self.num_sample_perbin = cfg.MODEL.num_sample_perbin
self.is_dygraph = is_dygraph
self.hidden_dim_1d = 256
self.hidden_dim_2d = 128
self.hidden_dim_3d = 512
# Base Module
self.b_conv1 = Conv1D(
prefix="Base_1",
num_channels=400,
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.b_conv2 = Conv1D(
prefix="Base_2",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
# Temporal Evaluation Module
self.ts_conv1 = Conv1D(
prefix="TEM_s1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.ts_conv2 = Conv1D(
prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
self.te_conv1 = Conv1D(
prefix="TEM_e1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.te_conv2 = Conv1D(
prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
#Proposal Evaluation Module
self.p_conv1 = Conv1D(
prefix="PEM_1d",
num_filters=self.hidden_dim_2d,
size_k=3,
padding=1,
act="relu")
# init to speed up
sample_mask_array = get_interp1d_mask(
self.tscale, self.dscale, self.prop_boundary_ratio,
self.num_sample, self.num_sample_perbin)
if self.is_dygraph:
self.sample_mask = fluid.dygraph.base.to_variable(
sample_mask_array)
else: # static
self.sample_mask = fluid.layers.create_parameter(
shape=[
self.tscale, self.num_sample * self.dscale * self.tscale
],
dtype=DATATYPE,
attr=fluid.ParamAttr(
name="sample_mask", trainable=False),
default_initializer=fluid.initializer.NumpyArrayInitializer(
sample_mask_array))
self.sample_mask.stop_gradient = True
self.p_conv3d1 = fluid.dygraph.Conv3D(
num_channels=128,
num_filters=self.hidden_dim_3d,
filter_size=(self.num_sample, 1, 1),
stride=(self.num_sample, 1, 1),
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_3d1_w"),
bias_attr=ParamAttr(name="PEM_3d1_b"))
self.p_conv2d1 = fluid.dygraph.Conv2D(
num_channels=512,
num_filters=self.hidden_dim_2d,
filter_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_2d1_w"),
bias_attr=ParamAttr(name="PEM_2d1_b"))
self.p_conv2d2 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d2_w"),
bias_attr=ParamAttr(name="PEM_2d2_b"))
self.p_conv2d3 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d3_w"),
bias_attr=ParamAttr(name="PEM_2d3_b"))
self.p_conv2d4 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=2,
filter_size=1,
stride=1,
padding=0,
act="sigmoid",
param_attr=ParamAttr(name="PEM_2d4_w"),
bias_attr=ParamAttr(name="PEM_2d4_b"))
def forward(self, x):
#Base Module
x = self.b_conv1(x)
x = self.b_conv2(x)
#TEM
xs = self.ts_conv1(x)
xs = self.ts_conv2(xs)
xs = fluid.layers.squeeze(xs, axes=[1])
xe = self.te_conv1(x)
xe = self.te_conv2(xe)
xe = fluid.layers.squeeze(xe, axes=[1])
#PEM
xp = self.p_conv1(x)
#BM layer
xp = fluid.layers.matmul(xp, self.sample_mask)
xp = fluid.layers.reshape(
xp, shape=[0, 0, -1, self.dscale, self.tscale])
xp = self.p_conv3d1(xp)
xp = fluid.layers.squeeze(xp, axes=[2])
xp = self.p_conv2d1(xp)
xp = self.p_conv2d2(xp)
xp = self.p_conv2d3(xp)
xp = self.p_conv2d4(xp)
return xp, xs, xe
class BmnLoss(Loss):
def __init__(self, cfg):
super(BmnLoss, self).__init__()
self.cfg = cfg
def _get_mask(self):
dscale = self.cfg.MODEL.dscale
tscale = self.cfg.MODEL.tscale
bm_mask = []
for idx in range(dscale):
mask_vector = [1 for i in range(tscale - idx)
] + [0 for i in range(idx)]
bm_mask.append(mask_vector)
bm_mask = np.array(bm_mask, dtype=np.float32)
self_bm_mask = fluid.layers.create_global_var(
shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
fluid.layers.assign(bm_mask, self_bm_mask)
self_bm_mask.stop_gradient = True
return self_bm_mask
def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
def bi_loss(pred_score, gt_label):
pred_score = fluid.layers.reshape(
x=pred_score, shape=[-1], inplace=False)
gt_label = fluid.layers.reshape(
x=gt_label, shape=[-1], inplace=False)
gt_label.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
num_entries = fluid.layers.cast(
fluid.layers.shape(pmask), dtype=DATATYPE)
num_positive = fluid.layers.cast(
fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
temp = fluid.layers.log(pred_score + epsilon)
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
loss = -1 * (loss_pos + loss_neg)
return loss
loss_start = bi_loss(pred_start, gt_start)
loss_end = bi_loss(pred_end, gt_end)
loss = loss_start + loss_end
return loss
def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
num_h = fluid.layers.cast(
fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
num_m = fluid.layers.cast(
fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
num_l = fluid.layers.cast(
fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
r_m = num_h / num_m
u_smmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
r_l = num_h / num_l
u_slmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
weights = u_hmask + u_smmask + u_slmask
weights.stop_gradient = True
loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
loss = fluid.layers.elementwise_mul(loss, weights)
loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
weights)
return loss
def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
gt_iou_map.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
nmask = fluid.layers.elementwise_mul(nmask, mask)
num_positive = fluid.layers.reduce_sum(pmask)
num_entries = num_positive + fluid.layers.reduce_sum(nmask)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), nmask)
loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
loss = -1 * (loss_pos + loss_neg) / num_entries
return loss
def forward(self, outputs, labels):
pred_bm, pred_start, pred_end = outputs
if len(labels) == 3:
gt_iou_map, gt_start, gt_end = labels
elif len(labels) == 4: # video_index used in eval mode
gt_iou_map, gt_start, gt_end, video_index = labels
pred_bm_reg = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[0], ends=[1]),
axes=[1])
pred_bm_cls = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[1], ends=[2]),
axes=[1])
bm_mask = self._get_mask()
pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import multiprocessing as mp
import json
import os
import math
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute jaccard score between a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
union_len = len_anchors - inter_len + box_max - box_min
jaccard = np.divide(inter_len, union_len)
return jaccard
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute intersection between score a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
scores = np.divide(inter_len, len_anchors)
return scores
def boundary_choose(score_list):
max_score = max(score_list)
mask_high = (score_list > max_score * 0.5)
score_list = list(score_list)
score_middle = np.array([0.0] + score_list + [0.0])
score_front = np.array([0.0, 0.0] + score_list)
score_back = np.array(score_list + [0.0, 0.0])
mask_peak = ((score_middle > score_front) & (score_middle > score_back))
mask_peak = mask_peak[1:-1]
mask = (mask_high | mask_peak).astype('float32')
return mask
def soft_nms(df, alpha, t1, t2):
'''
df: proposals generated by network;
alpha: alpha value of Gaussian decaying function;
t1, t2: threshold for soft nms.
'''
df = df.sort_values(by="score", ascending=False)
tstart = list(df.xmin.values[:])
tend = list(df.xmax.values[:])
tscore = list(df.score.values[:])
rstart = []
rend = []
rscore = []
while len(tscore) > 1 and len(rscore) < 101:
max_index = tscore.index(max(tscore))
tmp_iou_list = iou_with_anchors(
np.array(tstart),
np.array(tend), tstart[max_index], tend[max_index])
for idx in range(0, len(tscore)):
if idx != max_index:
tmp_iou = tmp_iou_list[idx]
tmp_width = tend[max_index] - tstart[max_index]
if tmp_iou > t1 + (t2 - t1) * tmp_width:
tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
alpha)
rstart.append(tstart[max_index])
rend.append(tend[max_index])
rscore.append(tscore[max_index])
tstart.pop(max_index)
tend.pop(max_index)
tscore.pop(max_index)
newDf = pd.DataFrame()
newDf['score'] = rscore
newDf['xmin'] = rstart
newDf['xmax'] = rend
return newDf
def video_process(video_list,
video_dict,
output_path,
result_dict,
snms_alpha=0.4,
snms_t1=0.55,
snms_t2=0.9):
for video_name in video_list:
print("Processing video........" + video_name)
df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
if len(df) > 1:
df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
video_duration = video_dict[video_name]["duration_second"]
proposal_list = []
for idx in range(min(100, len(df))):
tmp_prop={"score":df.score.values[idx], \
"segment":[max(0,df.xmin.values[idx])*video_duration, \
min(1,df.xmax.values[idx])*video_duration]}
proposal_list.append(tmp_prop)
result_dict[video_name[2:]] = proposal_list
def bmn_post_processing(video_dict, subset, output_path, result_path):
video_list = video_dict.keys()
video_list = list(video_dict.keys())
global result_dict
result_dict = mp.Manager().dict()
pp_num = 12
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pp_num)
processes = []
for tid in range(pp_num - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
for p in processes:
p.join()
result_dict = dict(result_dict)
output_dict = {
"version": "VERSION 1.3",
"results": result_dict,
"external_data": {}
}
outfile = open(
os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
json.dump(output_dict, outfile)
outfile.close()
def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
num_sample_perbin):
""" generate sample mask for a boundary-matching pair """
plen = float(seg_xmax - seg_xmin)
plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
total_samples = [
seg_xmin + plen_sample * ii
for ii in range(num_sample * num_sample_perbin)
]
p_mask = []
for idx in range(num_sample):
bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
num_sample_perbin]
bin_vector = np.zeros([tscale])
for sample in bin_samples:
sample_upper = math.ceil(sample)
sample_decimal, sample_down = math.modf(sample)
if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
bin_vector[int(sample_down)] += 1 - sample_decimal
if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
bin_vector[int(sample_upper)] += sample_decimal
bin_vector = 1.0 / num_sample_perbin * bin_vector
p_mask.append(bin_vector)
p_mask = np.stack(p_mask, axis=1)
return p_mask
def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
num_sample_perbin):
""" generate sample mask for each point in Boundary-Matching Map """
mask_mat = []
for start_index in range(tscale):
mask_mat_vector = []
for duration_index in range(dscale):
if start_index + duration_index < tscale:
p_xmin = start_index
p_xmax = start_index + duration_index
center_len = float(p_xmax - p_xmin) + 1
sample_xmin = p_xmin - center_len * prop_boundary_ratio
sample_xmax = p_xmax + center_len * prop_boundary_ratio
p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
tscale, num_sample,
num_sample_perbin)
else:
p_mask = np.zeros([tscale, num_sample])
mask_mat_vector.append(p_mask)
mask_mat_vector = np.stack(mask_mat_vector, axis=2)
mask_mat.append(mask_mat_vector)
mask_mat = np.stack(mask_mat, axis=3)
mask_mat = mask_mat.astype(np.float32)
sample_mask = np.reshape(mask_mat, [tscale, -1])
return sample_mask
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import os
import sys
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN test for performance evaluation.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Performance Evaluation
def test_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
eval_cfg = merge_configs(config, 'test', vars(args))
if not os.path.isdir(config.TEST.output_path):
os.makedirs(config.TEST.output_path)
if not os.path.isdir(config.TEST.result_path):
os.makedirs(config.TEST.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
video_idx = Input([None, 1], 'int64', name='video_idx')
labels = [gt_iou_map, gt_start, gt_end, video_idx]
#data
eval_dataset = BmnDataset(eval_cfg, 'test')
#model
model = BMN(config, args.dynamic)
model.prepare(
loss_function=BmnLoss(config),
metrics=BmnMetric(
config, mode='test'),
inputs=inputs,
labels=labels,
device=device)
#load checkpoint
if args.weights:
assert os.path.exists(args.weights + '.pdparams'), \
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
model.evaluate(
eval_data=eval_dataset,
batch_size=eval_cfg.TEST.batch_size,
num_workers=eval_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[EVAL] eval finished")
if __name__ == '__main__':
args = parse_args()
test_bmn(args)
'''
Calculate AR@N and AUC;
Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)
'''
import sys
sys.path.append('./Evaluation')
from eval_proposal import ANETproposal
import numpy as np
import argparse
import os
parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
parser.add_argument(
'--eval_file',
type=str,
default='bmn_results_validation.json',
help='name of results file to eval')
def run_evaluation(ground_truth_filename,
proposal_filename,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation'):
anet_proposal = ANETproposal(
ground_truth_filename,
proposal_filename,
tiou_thresholds=tiou_thresholds,
max_avg_nr_proposals=max_avg_nr_proposals,
subset=subset,
verbose=True,
check_status=False)
anet_proposal.evaluate()
recall = anet_proposal.recall
average_recall = anet_proposal.avg_recall
average_nr_proposals = anet_proposal.proposals_per_video
return (average_nr_proposals, average_recall, recall)
def plot_metric(average_nr_proposals,
average_recall,
recall,
tiou_thresholds=np.linspace(0.5, 0.95, 10)):
fn_size = 14
plt.figure(num=None, figsize=(12, 8))
ax = plt.subplot(1, 1, 1)
colors = [
'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
]
area_under_curve = np.zeros_like(tiou_thresholds)
for i in range(recall.shape[0]):
area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
for idx, tiou in enumerate(tiou_thresholds[::2]):
ax.plot(
average_nr_proposals,
recall[2 * idx, :],
color=colors[idx + 1],
label="tiou=[" + str(tiou) + "], area=" + str(
int(area_under_curve[2 * idx] * 100) / 100.),
linewidth=4,
linestyle='--',
marker=None)
# Plots Average Recall vs Average number of proposals.
ax.plot(
average_nr_proposals,
average_recall,
color=colors[0],
label="tiou = 0.5:0.05:0.95," + " area=" + str(
int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
linewidth=4,
linestyle='-',
marker=None)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
[handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
plt.ylabel('Average Recall', fontsize=fn_size)
plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
plt.grid(b=True, which="both")
plt.ylim([0, 1.0])
plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
plt.show()
if __name__ == "__main__":
args = parser.parse_args()
eval_file = args.eval_file
eval_file_path = os.path.join("evaluate_results", eval_file)
uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
"./Evaluation/data/activity_net_1_3_new.json",
eval_file_path,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation')
print("AR@1; AR@5; AR@10; AR@100")
print("%.02f %.02f %.02f %.02f" %
(100 * np.mean(uniform_recall_valid[:, 0]),
100 * np.mean(uniform_recall_valid[:, 4]),
100 * np.mean(uniform_recall_valid[:, 9]),
100 * np.mean(uniform_recall_valid[:, -1])))
{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import sys
import os
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN inference.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device', type=str, default='GPU', help='default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--save_dir',
type=str,
default="predict_results/",
help='output dir path, default to use ./predict_results/')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Prediction
def infer_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
infer_cfg = merge_configs(config, 'infer', vars(args))
if not os.path.isdir(config.INFER.output_path):
os.makedirs(config.INFER.output_path)
if not os.path.isdir(config.INFER.result_path):
os.makedirs(config.INFER.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
labels = [Input([None, 1], 'int64', name='video_idx')]
#data
infer_dataset = BmnDataset(infer_cfg, 'infer')
model = BMN(config, args.dynamic)
model.prepare(
metrics=BmnMetric(
config, mode='infer'),
inputs=inputs,
labels=labels,
device=device)
# load checkpoint
if args.weights:
assert os.path.exists(
args.weights +
".pdparams"), "Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
# here use model.eval instead of model.test, as post process is required in our case
model.evaluate(
eval_data=infer_dataset,
batch_size=infer_cfg.TEST.batch_size,
num_workers=infer_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[INFER] infer finished")
if __name__ == '__main__':
args = parse_args()
infer_bmn(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import numpy as np
import json
import logging
import os
import sys
sys.path.append('../')
from distributed import DistributedBatchSampler
from paddle.fluid.io import Dataset, DataLoader
logger = logging.getLogger(__name__)
from config_utils import *
from bmn_utils import iou_with_anchors, ioa_with_anchors
DATATYPE = "float32"
class BmnDataset(Dataset):
def __init__(self, cfg, mode):
self.mode = mode
self.tscale = cfg.MODEL.tscale # 100
self.dscale = cfg.MODEL.dscale # 100
self.anno_file = cfg.MODEL.anno_file
self.feat_path = cfg.MODEL.feat_path
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
self.tgap = 1. / self.tscale
self.get_dataset_dict()
self.get_match_map()
def __getitem__(self, index):
video_name = self.video_list[index]
video_idx = self.video_list.index(video_name)
video_feat = self.load_file(video_name)
if self.mode == 'infer':
return video_feat, video_idx
else:
gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
return video_feat, gt_iou_map, gt_start, gt_end
elif self.mode == 'test':
return video_feat, gt_iou_map, gt_start, gt_end, video_idx
def __len__(self):
return len(self.video_list)
def get_dataset_dict(self):
assert (
os.path.exists(self.feat_path)), "Input feature path not exists"
assert (os.listdir(self.feat_path)), "No feature file in feature path"
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
video_name_set = set(
[video_name + '.npy' for video_name in self.video_list])
assert (video_name_set.intersection(set(os.listdir(self.feat_path))) ==
video_name_set), "Input feature not exists in feature path"
def get_match_map(self):
match_map = []
for idx in range(self.tscale):
tmp_match_window = []
xmin = self.tgap * idx
for jdx in range(1, self.tscale + 1):
xmax = xmin + self.tgap * jdx
tmp_match_window.append([xmin, xmax])
match_map.append(tmp_match_window)
match_map = np.array(match_map)
match_map = np.transpose(match_map, [1, 0, 2])
match_map = np.reshape(match_map, [-1, 2])
self.match_map = match_map
self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
def get_video_label(self, video_name):
video_info = self.video_dict[video_name]
video_second = video_info['duration_second']
video_labels = video_info['annotations']
gt_bbox = []
gt_iou_map = []
for gt in video_labels:
tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
gt_bbox.append([tmp_start, tmp_end])
tmp_gt_iou_map = iou_with_anchors(
self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
[self.dscale, self.tscale])
gt_iou_map.append(tmp_gt_iou_map)
gt_iou_map = np.array(gt_iou_map)
gt_iou_map = np.max(gt_iou_map, axis=0)
gt_bbox = np.array(gt_bbox)
gt_xmins = gt_bbox[:, 0]
gt_xmaxs = gt_bbox[:, 1]
gt_len_small = 3 * self.tgap
gt_start_bboxs = np.stack(
(gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
gt_end_bboxs = np.stack(
(gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
match_score_start = []
for jdx in range(len(self.anchor_xmin)):
match_score_start.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
match_score_end = []
for jdx in range(len(self.anchor_xmin)):
match_score_end.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
gt_start = np.array(match_score_start)
gt_end = np.array(match_score_end)
return gt_iou_map.astype(DATATYPE), gt_start.astype(
DATATYPE), gt_end.astype(DATATYPE)
def load_file(self, video_name):
file_name = video_name + ".npy"
file_path = os.path.join(self.feat_path, file_name)
video_feat = np.load(file_path)
video_feat = video_feat.T
video_feat = video_feat.astype("float32")
return video_feat
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch train.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
import argparse
import logging
import sys
import os
sys.path.append('../')
from model import set_device, Input
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("Paddle high level api of BMN.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='training batch size. None to use config file setting.')
parser.add_argument(
'--learning_rate',
type=float,
default=0.001,
help='learning rate use for training. None to use config file setting.')
parser.add_argument(
'--resume',
type=str,
default=None,
help='filename to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--epoch',
type=int,
default=9,
help='epoch number, 0 for read from config file')
parser.add_argument(
'--valid_interval',
type=int,
default=1,
help='validation epoch interval, 0 for no validation.')
parser.add_argument(
'--save_dir',
type=str,
default="checkpoint",
help='path to save train snapshoot')
parser.add_argument(
'--log_interval',
type=int,
default=10,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Optimizer
def optimizer(cfg, parameter_list):
bd = [cfg.TRAIN.lr_decay_iter]
base_lr = cfg.TRAIN.learning_rate
lr_decay = cfg.TRAIN.learning_rate_decay
l2_weight_decay = cfg.TRAIN.l2_weight_decay
lr = [base_lr, base_lr * lr_decay]
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
parameter_list=parameter_list,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
# TRAIN
def train_bmn(args):
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
config = parse_config(args.config_file)
train_cfg = merge_configs(config, 'train', vars(args))
val_cfg = merge_configs(config, 'valid', vars(args))
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
labels = [gt_iou_map, gt_start, gt_end]
# data
train_dataset = BmnDataset(train_cfg, 'train')
val_dataset = BmnDataset(val_cfg, 'valid')
# model
model = BMN(config, args.dynamic)
optim = optimizer(config, parameter_list=model.parameters())
model.prepare(
optimizer=optim,
loss_function=BmnLoss(config),
inputs=inputs,
labels=labels,
device=device)
# if resume weights is given, load resume weights directly
if args.resume is not None:
model.load(args.resume)
model.fit(train_data=train_dataset,
eval_data=val_dataset,
batch_size=train_cfg.TRAIN.batch_size,
epochs=args.epoch,
eval_freq=args.valid_interval,
log_freq=args.log_interval,
save_dir=args.save_dir,
shuffle=train_cfg.TRAIN.use_shuffle,
num_workers=train_cfg.TRAIN.num_workers,
drop_last=True)
if __name__ == "__main__":
args = parse_args()
train_bmn(args)
# Cycle GAN
---
## 内容
- [安装](#安装)
- [简介](#简介)
- [代码结构](#代码结构)
- [数据准备](#数据准备)
- [模型训练与预测](#模型训练与预测)
## 安装
运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
## 简介
Cycle GAN 是一种image to image 的图像生成网络,实现了非对称图像数据集的生成和风格迁移。模型结构如下图所示,我们的模型包含两个生成网络 G: X → Y 和 F: Y → X,以及相关的判别器 DY 和 DX 。通过训练DY,使G将X图尽量转换为Y图,反之亦然。同时引入两个“周期一致性损失”,它们保证:如果我们从一个领域转换到另一个领域,它还可以被转换回去:(b)正向循环一致性损失:x→G(x)→F(G(x))≈x, (c)反向循环一致性损失:y→F(y)→G(F(y))≈y
<p align="center">
<img src="image/net.png" hspace='10'/> <br />
图1.网络结构
</p>
## 代码结构
```
├── data.py # 读取、处理数据。
├── layers.py # 封装定义基础的layers。
├── cyclegan.py # 定义基础生成网络和判别网络。
├── train.py # 训练脚本。
└── infer.py # 预测脚本。
```
## 数据准备
CycleGAN 支持的数据集可以参考download.py中的`cycle_pix_dataset`,可以通过指定`python download.py --dataset xxx` 下载得到。
由于版权问题,cityscapes 数据集无法通过脚本直接获得,需要从[官方](https://www.cityscapes-dataset.com/)下载数据,
下载完之后执行`python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./data/cityscapes/`处理,
将数据存放在`data/cityscapes`
数据下载处理完毕后,需要您将数据组织为以下路径结构:
```
data
|-- cityscapes
| |-- testA
| |-- testB
| |-- trainA
| |-- trainB
```
然后运行txt生成脚本:`python generate_txt.py`,最终数据组织如下所示:
```
data
|-- cityscapes
| |-- testA
| |-- testA.txt
| |-- testB
| |-- testB.txt
| |-- trainA
| |-- trainA.txt
| |-- trainB
| `-- trainB.txt
```
以上数据文件中,`data`文件夹需要放在训练脚本`train.py`同级目录下。`testA`为存放真实街景图片的文件夹,`testB`为存放语义分割图片的文件夹,`testA.txt``testB.txt`分别为测试图片路径列表文件,格式如下:
```
data/cityscapes/testA/234_A.jpg
data/cityscapes/testA/292_A.jpg
data/cityscapes/testA/412_A.jpg
```
训练数据组织方式与测试数据相同。
## 模型训练与预测
### 训练
在GPU单卡上训练:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
```
执行`python train.py --help`可查看更多使用方式和参数详细说明。
图1为训练152轮的训练损失示意图,其中横坐标轴为训练轮数,纵轴为在训练集上的损失。其中,'g_loss','da_loss'和'db_loss'分别为生成器、判别器A和判别器B的训练损失。
### 测试
执行以下命令可以选择已保存的训练权重,对测试集进行测试,通过 `--epoch` 制定权重轮次:
```
env CUDA_VISIBLE_DEVICES=0 python test.py --init_model=checkpoint/199
```
生成结果在 `output/eval`
### 预测
执行以下命令读取单张或多张图片进行预测:
真实街景生成分割图像:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--init_model="./checkpoints/199" --input="./image/testA/123_A.jpg" \
--input_style=A
```
分割图像生成真实街景:
```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--init_model="checkpoints/199" --input="./image/testB/78_B.jpg" \
--input_style=B
```
生成结果在 `output/single`
训练180轮的模型预测效果如fakeA和fakeB所示:
<p align="center">
<img src="image/A2B.png" width="620" hspace='10'/> <br/>
<strong>A2B</strong>
</p>
<p align="center">
<img src="image/B2A.png" width="620" hspace='10'/> <br/>
<strong>B2A</strong>
</p>
>在本文示例中,均可通过修改`CUDA_VISIBLE_DEVICES`改变使用的显卡号。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
print(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.7.0')
except Exception as e:
print(err)
sys.exit(1)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from layers import ConvBN, DeConvBN
import paddle.fluid as fluid
from model import Model, Loss
class ResnetBlock(fluid.dygraph.Layer):
def __init__(self, dim, dropout=False):
super(ResnetBlock, self).__init__()
self.dropout = dropout
self.conv0 = ConvBN(dim, dim, 3, 1)
self.conv1 = ConvBN(dim, dim, 3, 1, act=None)
def forward(self, inputs):
out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
out_res = self.conv0(out_res)
if self.dropout:
out_res = fluid.layers.dropout(out_res, dropout_prob=0.5)
out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
out_res = self.conv1(out_res)
return out_res + inputs
class ResnetGenerator(fluid.dygraph.Layer):
def __init__(self, input_channel, n_blocks=9, dropout=False):
super(ResnetGenerator, self).__init__()
self.conv0 = ConvBN(input_channel, 32, 7, 1)
self.conv1 = ConvBN(32, 64, 3, 2, padding=1)
self.conv2 = ConvBN(64, 128, 3, 2, padding=1)
dim = 128
self.resnet_blocks = []
for i in range(n_blocks):
block = self.add_sublayer("generator_%d" % (i + 1),
ResnetBlock(dim, dropout))
self.resnet_blocks.append(block)
self.deconv0 = DeConvBN(
dim, 32 * 2, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
self.deconv1 = DeConvBN(
32 * 2, 32, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
self.conv3 = ConvBN(
32, input_channel, 7, 1, norm=False, act=False, use_bias=True)
def forward(self, inputs):
pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
y = self.conv0(pad_input)
y = self.conv1(y)
y = self.conv2(y)
for resnet_block in self.resnet_blocks:
y = resnet_block(y)
y = self.deconv0(y)
y = self.deconv1(y)
y = fluid.layers.pad2d(y, [3, 3, 3, 3], mode="reflect")
y = self.conv3(y)
y = fluid.layers.tanh(y)
return y
class NLayerDiscriminator(fluid.dygraph.Layer):
def __init__(self, input_channel, d_dims=64, d_nlayers=3):
super(NLayerDiscriminator, self).__init__()
self.conv0 = ConvBN(
input_channel,
d_dims,
4,
2,
1,
norm=False,
use_bias=True,
relufactor=0.2)
nf_mult, nf_mult_prev = 1, 1
self.conv_layers = []
for n in range(1, d_nlayers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
conv = self.add_sublayer(
'discriminator_%d' % (n),
ConvBN(
d_dims * nf_mult_prev,
d_dims * nf_mult,
4,
2,
1,
relufactor=0.2))
self.conv_layers.append(conv)
nf_mult_prev = nf_mult
nf_mult = min(2**d_nlayers, 8)
self.conv4 = ConvBN(
d_dims * nf_mult_prev, d_dims * nf_mult, 4, 1, 1, relufactor=0.2)
self.conv5 = ConvBN(
d_dims * nf_mult,
1,
4,
1,
1,
norm=False,
act=None,
use_bias=True,
relufactor=0.2)
def forward(self, inputs):
y = self.conv0(inputs)
for conv in self.conv_layers:
y = conv(y)
y = self.conv4(y)
y = self.conv5(y)
return y
class Generator(Model):
def __init__(self, input_channel=3):
super(Generator, self).__init__()
self.g = ResnetGenerator(input_channel)
def forward(self, input):
fake = self.g(input)
return fake
class GeneratorCombine(Model):
def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None,
is_train=True):
super(GeneratorCombine, self).__init__()
self.g_AB = g_AB
self.g_BA = g_BA
self.is_train = is_train
if self.is_train:
self.d_A = d_A
self.d_B = d_B
def forward(self, input_A, input_B):
# Translate images to the other domain
fake_B = self.g_AB(input_A)
fake_A = self.g_BA(input_B)
# Translate images back to original domain
cyc_A = self.g_BA(fake_B)
cyc_B = self.g_AB(fake_A)
if not self.is_train:
return fake_A, fake_B, cyc_A, cyc_B
# Identity mapping of images
idt_A = self.g_AB(input_B)
idt_B = self.g_BA(input_A)
# Discriminators determines validity of translated images
# d_A(g_AB(A))
valid_A = self.d_A.d(fake_B)
# d_B(g_BA(A))
valid_B = self.d_B.d(fake_A)
return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B
class GLoss(Loss):
def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5):
super(GLoss, self).__init__()
self.lambda_A = lambda_A
self.lambda_B = lambda_B
self.lambda_identity = lambda_identity
def forward(self, outputs, labels=None):
input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B = outputs
def mse(a, b):
return fluid.layers.reduce_mean(fluid.layers.square(a - b))
def mae(a, b): # L1Loss
return fluid.layers.reduce_mean(fluid.layers.abs(a - b))
g_A_loss = mse(valid_A, 1.)
g_B_loss = mse(valid_B, 1.)
g_loss = g_A_loss + g_B_loss
cyc_A_loss = mae(input_A, cyc_A) * self.lambda_A
cyc_B_loss = mae(input_B, cyc_B) * self.lambda_B
cyc_loss = cyc_A_loss + cyc_B_loss
idt_loss_A = mae(input_B, idt_A) * (self.lambda_B *
self.lambda_identity)
idt_loss_B = mae(input_A, idt_B) * (self.lambda_A *
self.lambda_identity)
idt_loss = idt_loss_A + idt_loss_B
loss = cyc_loss + g_loss + idt_loss
return loss
class Discriminator(Model):
def __init__(self, input_channel=3):
super(Discriminator, self).__init__()
self.d = NLayerDiscriminator(input_channel)
def forward(self, real, fake):
pred_real = self.d(real)
pred_fake = self.d(fake)
return pred_real, pred_fake
class DLoss(Loss):
def __init__(self):
super(DLoss, self).__init__()
def forward(self, inputs, labels=None):
pred_real, pred_fake = inputs
loss = fluid.layers.square(pred_fake) + fluid.layers.square(pred_real -
1.)
loss = fluid.layers.reduce_mean(loss / 2.0)
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import numpy as np
from PIL import Image, ImageOps
DATASET = "cityscapes"
A_LIST_FILE = "./data/" + DATASET + "/trainA.txt"
B_LIST_FILE = "./data/" + DATASET + "/trainB.txt"
A_TEST_LIST_FILE = "./data/" + DATASET + "/testA.txt"
B_TEST_LIST_FILE = "./data/" + DATASET + "/testB.txt"
IMAGES_ROOT = "./data/" + DATASET + "/"
import paddle.fluid as fluid
class Cityscapes(fluid.io.Dataset):
def __init__(self, root_path, file_path, mode='train', return_name=False):
self.root_path = root_path
self.file_path = file_path
self.mode = mode
self.return_name = return_name
self.images = [root_path + l for l in open(file_path, 'r').readlines()]
def _train(self, image):
## Resize
image = image.resize((286, 286), Image.BICUBIC)
## RandomCrop
i = np.random.randint(0, 30)
j = np.random.randint(0, 30)
image = image.crop((i, j, i + 256, j + 256))
# RandomHorizontalFlip
if np.random.rand() > 0.5:
image = ImageOps.mirror(image)
return image
def __getitem__(self, idx):
f = self.images[idx].strip("\n\r\t ")
image = Image.open(f)
if self.mode == 'train':
image = self._train(image)
else:
image = image.resize((256, 256), Image.BICUBIC)
# ToTensor
image = np.array(image).transpose([2, 0, 1]).astype('float32')
image = image / 255.0
# Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
image = (image - 0.5) / 0.5
if self.return_name:
return [image], os.path.basename(f)
else:
return [image]
def __len__(self):
return len(self.images)
def DataA(root=IMAGES_ROOT, fpath=A_LIST_FILE):
"""
Reader of images with A style for training.
"""
return Cityscapes(root, fpath)
def DataB(root=IMAGES_ROOT, fpath=B_LIST_FILE):
"""
Reader of images with B style for training.
"""
return Cityscapes(root, fpath)
def TestDataA(root=IMAGES_ROOT, fpath=A_TEST_LIST_FILE):
"""
Reader of images with A style for training.
"""
return Cityscapes(root, fpath, mode='test', return_name=True)
def TestDataB(root=IMAGES_ROOT, fpath=B_TEST_LIST_FILE):
"""
Reader of images with B style for training.
"""
return Cityscapes(root, fpath, mode='test', return_name=True)
class ImagePool(object):
def __init__(self, pool_size=50):
self.pool = []
self.count = 0
self.pool_size = pool_size
def get(self, image):
if self.count < self.pool_size:
self.pool.append(image)
self.count += 1
return image
else:
p = random.random()
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
temp = self.pool[random_id]
self.pool[random_id] = image
return temp
else:
return image
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import numpy as np
import argparse
from PIL import Image
from scipy.misc import imsave
import paddle.fluid as fluid
from check import check_gpu, check_version
from model import Model, Input, set_device
from cyclegan import Generator, GeneratorCombine
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False)
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
g.prepare(inputs=[input_A, input_B])
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
out_path = FLAGS.output + "/single"
if not os.path.exists(out_path):
os.makedirs(out_path)
for f in glob.glob(FLAGS.input):
image_name = os.path.basename(f)
image = Image.open(f).convert('RGB')
image = image.resize((256, 256), Image.BICUBIC)
image = np.array(image) / 127.5 - 1
image = image[:, :, 0:3].astype("float32")
data = image.transpose([2, 0, 1])[np.newaxis, :]
if FLAGS.input_style == "A":
_, fake, _, _ = g.test([data, data])
if FLAGS.input_style == "B":
fake, _, _, _ = g.test([data, data])
fake = np.squeeze(fake[0]).transpose([1, 2, 0])
opath = "{}/fake{}{}".format(out_path, FLAGS.input_style, image_name)
imsave(opath, ((fake + 1) * 127.5).astype(np.uint8))
print("transfer {} to {}".format(f, opath))
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN inference")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-i",
"--input",
type=str,
default='./image/testA/123_A.jpg',
help="input image")
parser.add_argument(
"-o",
'--output',
type=str,
default='output',
help="The test result to be saved to.")
parser.add_argument(
"-m",
"--init_model",
type=str,
default='checkpoint/199',
help="The init model file of directory.")
parser.add_argument(
"-s", "--input_style", type=str, default='A', help="A or B")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm
# cudnn is not better when batch size is 1.
use_cudnn = False
import numpy as np
class ConvBN(fluid.dygraph.Layer):
"""docstring for Conv2D"""
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
stddev=0.02,
norm=True,
is_test=False,
act='leaky_relu',
relufactor=0.0,
use_bias=False):
super(ConvBN, self).__init__()
pattr = fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev))
self.conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=use_cudnn,
param_attr=pattr,
bias_attr=use_bias)
if norm:
self.bn = BatchNorm(
num_filters,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(1.0,
0.02)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0)),
is_test=False,
trainable_statistics=True)
self.relufactor = relufactor
self.norm = norm
self.act = act
def forward(self, inputs):
conv = self.conv(inputs)
if self.norm:
conv = self.bn(conv)
if self.act == 'leaky_relu':
conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
elif self.act == 'relu':
conv = fluid.layers.relu(conv)
else:
conv = conv
return conv
class DeConvBN(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=[0, 0],
outpadding=[0, 0, 0, 0],
stddev=0.02,
act='leaky_relu',
norm=True,
is_test=False,
relufactor=0.0,
use_bias=False):
super(DeConvBN, self).__init__()
pattr = fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev))
self._deconv = Conv2DTranspose(
num_channels,
num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=pattr,
bias_attr=use_bias)
if norm:
self.bn = BatchNorm(
num_filters,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(1.0,
0.02)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0)),
is_test=False,
trainable_statistics=True)
self.outpadding = outpadding
self.relufactor = relufactor
self.use_bias = use_bias
self.norm = norm
self.act = act
def forward(self, inputs):
conv = self._deconv(inputs)
conv = fluid.layers.pad2d(
conv, paddings=self.outpadding, mode='constant', pad_value=0.0)
if self.norm:
conv = self.bn(conv)
if self.act == 'leaky_relu':
conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
elif self.act == 'relu':
conv = fluid.layers.relu(conv)
else:
conv = conv
return conv
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from scipy.misc import imsave
import paddle.fluid as fluid
from check import check_gpu, check_version
from model import Model, Input, set_device
from cyclegan import Generator, GeneratorCombine
import data as data
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
g = GeneratorCombine(g_AB, g_BA, is_train=False)
im_shape = [-1, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
g.prepare(inputs=[input_A, input_B])
g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
if not os.path.exists(FLAGS.output):
os.makedirs(FLAGS.output)
test_data_A = data.TestDataA()
test_data_B = data.TestDataB()
for i in range(len(test_data_A)):
data_A, A_name = test_data_A[i]
data_B, B_name = test_data_B[i]
data_A = np.array(data_A).astype("float32")
data_B = np.array(data_B).astype("float32")
fake_A, fake_B, cyc_A, cyc_B = g.test([data_A, data_B])
datas = [fake_A, fake_B, cyc_A, cyc_B, data_A, data_B]
odatas = []
for o in datas:
d = np.squeeze(o[0]).transpose([1, 2, 0])
im = ((d + 1) * 127.5).astype(np.uint8)
odatas.append(im)
imsave(FLAGS.output + "/fakeA_" + B_name, odatas[0])
imsave(FLAGS.output + "/fakeB_" + A_name, odatas[1])
imsave(FLAGS.output + "/cycA_" + A_name, odatas[2])
imsave(FLAGS.output + "/cycB_" + B_name, odatas[3])
imsave(FLAGS.output + "/inputA_" + A_name, odatas[4])
imsave(FLAGS.output + "/inputB_" + B_name, odatas[5])
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN test")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batch size")
parser.add_argument(
"-o",
'--output',
type=str,
default='output/eval',
help="The test result to be saved to.")
parser.add_argument(
"-m",
"--init_model",
type=str,
default='checkpoint/199',
help="The init model file of directory.")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import argparse
import contextlib
import time
import paddle
import paddle.fluid as fluid
from check import check_gpu, check_version
from model import Model, Input, set_device
import data as data
from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
step_per_epoch = 2974
def opt(parameters):
lr_base = 0.0002
bounds = [100, 120, 140, 160, 180]
lr = [1., 0.8, 0.6, 0.4, 0.2, 0.1]
bounds = [i * step_per_epoch for i in bounds]
lr = [i * lr_base for i in lr]
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bounds, values=lr),
parameter_list=parameters,
beta1=0.5)
return optimizer
def main():
place = set_device(FLAGS.device)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
# Generators
g_AB = Generator()
g_BA = Generator()
# Discriminators
d_A = Discriminator()
d_B = Discriminator()
g = GeneratorCombine(g_AB, g_BA, d_A, d_B)
da_params = d_A.parameters()
db_params = d_B.parameters()
g_params = g_AB.parameters() + g_BA.parameters()
da_optimizer = opt(da_params)
db_optimizer = opt(db_params)
g_optimizer = opt(g_params)
im_shape = [None, 3, 256, 256]
input_A = Input(im_shape, 'float32', 'input_A')
input_B = Input(im_shape, 'float32', 'input_B')
fake_A = Input(im_shape, 'float32', 'fake_A')
fake_B = Input(im_shape, 'float32', 'fake_B')
g_AB.prepare(inputs=[input_A])
g_BA.prepare(inputs=[input_B])
g.prepare(g_optimizer, GLoss(), inputs=[input_A, input_B])
d_A.prepare(da_optimizer, DLoss(), inputs=[input_B, fake_B])
d_B.prepare(db_optimizer, DLoss(), inputs=[input_A, fake_A])
if FLAGS.resume:
g.load(FLAGS.resume)
loader_A = fluid.io.DataLoader(
data.DataA(),
places=place,
shuffle=True,
return_list=True,
batch_size=FLAGS.batch_size)
loader_B = fluid.io.DataLoader(
data.DataB(),
places=place,
shuffle=True,
return_list=True,
batch_size=FLAGS.batch_size)
A_pool = data.ImagePool()
B_pool = data.ImagePool()
for epoch in range(FLAGS.epoch):
for i, (data_A, data_B) in enumerate(zip(loader_A, loader_B)):
data_A = data_A[0][0] if not FLAGS.dynamic else data_A[0]
data_B = data_B[0][0] if not FLAGS.dynamic else data_B[0]
start = time.time()
fake_B = g_AB.test(data_A)[0]
fake_A = g_BA.test(data_B)[0]
g_loss = g.train([data_A, data_B])[0]
fake_pb = B_pool.get(fake_B)
da_loss = d_A.train([data_B, fake_pb])[0]
fake_pa = A_pool.get(fake_A)
db_loss = d_B.train([data_A, fake_pa])[0]
t = time.time() - start
if i % 20 == 0:
print("epoch: {} | step: {:3d} | g_loss: {:.4f} | " \
"da_loss: {:.4f} | db_loss: {:.4f} | s/step {:.4f}".
format(epoch, i, g_loss[0], da_loss[0], db_loss[0], t))
g.save('{}/{}'.format(FLAGS.checkpoint_path, epoch))
if __name__ == "__main__":
parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-e", "--epoch", default=200, type=int, help="Epoch number")
parser.add_argument(
"-b", "--batch_size", default=1, type=int, help="batch size")
parser.add_argument(
"-o",
"--checkpoint_path",
type=str,
default='checkpoint',
help="path to save checkpoint")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import cv2
from paddle.fluid.io import Dataset
def has_valid_extension(filename, extensions):
"""Checks if a file is a vilid extension.
Args:
filename (str): path to a file
extensions (tuple of str): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
return has_valid_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
class DatasetFolder(Dataset):
"""A generic data loader where the samples are arranged in this way:
root/class_a/1.ext
root/class_a/2.ext
root/class_a/3.ext
root/class_b/123.ext
root/class_b/456.ext
root/class_b/789.ext
Args:
root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self,
root,
loader=None,
extensions=None,
transform=None,
target_transform=None,
is_valid_file=None):
self.root = root
if extensions is None:
extensions = IMG_EXTENSIONS
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions,
is_valid_file)
if len(samples) == 0:
raise (RuntimeError(
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir),
and class_to_idx is a dictionary.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [
d for d in os.listdir(dir)
if os.path.isdir(os.path.join(dir, d))
]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def cv2_loader(path):
return cv2.imread(path)
# 高级api图像分类
## 数据集准备
在开始训练前,请确保已经下载解压好[ImageNet数据集](http://image-net.org/download),并放在合适的目录下,准备好的数据集的目录结构如下所示:
```bash
/path/to/imagenet
train
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
val
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
```
## 训练
### 单卡训练
执行如下命令进行训练
```bash
python -u main.py --arch resnet50 /path/to/imagenet -d
```
-d 是使用动态模式训练,默认为静态图模式。
### 多卡训练
执行如下命令进行训练
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
```
## 预测
### 单卡预测
执行如下命令进行预测
```bash
python -u main.py --arch resnet50 -d --evaly-only /path/to/imagenet
```
### 多卡预测
执行如下命令进行多卡预测
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --evaly-only /path/to/imagenet
```
## 参数说明
* **arch**: 要训练或预测的模型名称
* **device**: 训练使用的设备,'gpu'或'cpu',默认值:'gpu'
* **dynamic**: 是否使用动态图模式训练
* **epoch**: 训练的轮数,默认值:120
* **learning-rate**: 学习率,默认值:0.1
* **batch-size**: 每张卡的batch size,默认值:64
* **output-dir**: 模型文件保存的文件夹,默认值:'output'
* **num-workers**: dataloader的进程数,默认值:4
* **resume**: 恢复训练的模型路径,默认值:None
* **eval-only**: 是否仅仅进行预测
* **lr-scheduler**: 学习率衰减策略,默认值:piecewise
* **milestones**: piecewise学习率衰减策略的边界,默认值:[30, 60, 80]
* **weight-decay**: 模型权重正则化系数,默认值:1e-4
* **momentum**: SGD优化器的动量,默认值:0.9
## 模型
| 模型 | top1 acc | top5 acc |
| --- | --- | --- |
| [ResNet50](https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams) | 76.28 | 93.04 |
| [vgg16](https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams) | 71.84 | 90.71 |
| [mobilenet_v1](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams) | 71.25 | 89.92 |
| [mobilenet_v2](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams) | 72.27 | 90.66 |
上述模型的复现参数请参考scripts下的脚本。
## 参考文献
- ResNet: [Deep Residual Learning for Image Recognitio](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
- MobileNetV1: [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861), Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
- MobileNetV2: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/pdf/1801.04381v4.pdf), Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
- VGG: [Very Deep Convolutional Networks for Large-scale Image Recognition](https://arxiv.org/pdf/1409.1556), Karen Simonyan, Andrew Zisserman
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import math
import random
import numpy as np
from datasets.folder import DatasetFolder
def center_crop_resize(img):
h, w = img.shape[:2]
c = int(224 / 256 * min((h, w)))
i = (h + 1 - c) // 2
j = (w + 1 - c) // 2
img = img[i:i + c, j:j + c, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i:i + h, j:j + w, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
return img
def normalize_permute(img):
# transpose and convert to RGB from BGR
img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
class ImageNetDataset(DatasetFolder):
def __init__(self, path, mode='train'):
super(ImageNetDataset, self).__init__(path)
self.mode = mode
if self.mode == 'train':
self.transform = compose([
cv2.imread, random_crop_resize, random_flip, normalize_permute
])
else:
self.transform = compose(
[cv2.imread, center_crop_resize, normalize_permute])
def __getitem__(self, idx):
img, label = self.samples[idx]
return self.transform((img, [label]))
def __len__(self):
return len(self.samples)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import sys
sys.path.append('../')
import time
import math
import numpy as np
import models
import paddle.fluid as fluid
from model import CrossEntropy, Input, set_device
from imagenet_dataset import ImageNetDataset
from distributed import DistributedBatchSampler
from paddle.fluid.dygraph.parallel import ParallelEnv
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr
lr_scheduler = FLAGS.lr_scheduler
momentum = FLAGS.momentum
weight_decay = FLAGS.weight_decay
if lr_scheduler == 'piecewise':
milestones = FLAGS.milestones
boundaries = [step_per_epoch * e for e in milestones]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
elif lr_scheduler == 'cosine':
learning_rate = fluid.layers.cosine_decay(base_lr, step_per_epoch,
FLAGS.epoch)
else:
raise ValueError(
"Expected lr_scheduler in ['piecewise', 'cosine'], but got {}".
format(lr_scheduler))
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
not FLAGS.resume)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
train_dataset = ImageNetDataset(
os.path.join(FLAGS.data, 'train'), mode='train')
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
optim = make_optimizer(
np.ceil(
len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
if FLAGS.eval_only:
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
time.strftime('%Y-%m-%d-%H-%M',
time.localtime()))
if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
os.makedirs(output_dir)
model.fit(train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument(
'data',
metavar='DIR',
help='path to dataset '
'(should have subdirectories named "train" and "val"')
parser.add_argument(
"--arch", type=str, default='resnet50', help="model name")
parser.add_argument(
"--device", type=str, default='gpu', help="device to run, cpu or gpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=90, type=int, help="number of epoch")
parser.add_argument(
'--lr',
'--learning-rate',
default=0.1,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch-size", default=64, type=int, help="batch size")
parser.add_argument(
"-n", "--num-workers", default=4, type=int, help="dataloader workers")
parser.add_argument(
"--output-dir", type=str, default='output', help="save dir")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
parser.add_argument(
"--eval-only", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--lr-scheduler",
default='piecewise',
type=str,
help="learning rate scheduler")
parser.add_argument(
"--milestones",
nargs='+',
type=int,
default=[30, 60, 80],
help="piecewise decay milestones")
parser.add_argument(
"--weight-decay", default=1e-4, type=float, help="weight decay")
parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
此差异已折叠。
......@@ -42,6 +42,14 @@ __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device']
def set_device(device):
"""
Args:
device (str): specify device type, 'cpu' or 'gpu'.
Returns:
fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place.
"""
assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
"Expected device in ['cpu', 'gpu'], but got {}".format(device)
......@@ -114,9 +122,9 @@ class Loss(object):
def forward(self, outputs, labels):
raise NotImplementedError()
def __call__(self, outputs, labels):
def __call__(self, outputs, labels=None):
labels = to_list(labels)
if in_dygraph_mode():
if in_dygraph_mode() and labels:
labels = [to_variable(l) for l in labels]
losses = to_list(self.forward(to_list(outputs), labels))
if self.average:
......@@ -410,7 +418,8 @@ class StaticGraphAdapter(object):
and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
new_lr_var = prog.global_block().vars[lr_var.name]
self.model._optimizer._learning_rate_map[prog] = new_lr_var
losses = []
metrics = []
......@@ -852,8 +861,6 @@ class Model(fluid.dygraph.Layer):
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
if loss_function and not isinstance(labels, (list, Input)):
raise TypeError("'labels' must be list in static graph mode")
metrics = metrics or []
for metric in to_list(metrics):
......@@ -1083,7 +1090,11 @@ class Model(fluid.dygraph.Layer):
return eval_result
def predict(self, test_data, batch_size=1, num_workers=0):
def predict(self,
test_data,
batch_size=1,
num_workers=0,
stack_outputs=True):
"""
FIXME: add more comments and usage
Args:
......@@ -1096,6 +1107,12 @@ class Model(fluid.dygraph.Layer):
num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
stack_output (bool): whether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
is False. stack_outputs as False is used for LoDTensor output situation,
it is recommended set as True if outputs contains no LoDTensor. Default False
"""
if fluid.in_dygraph_mode():
......@@ -1122,19 +1139,16 @@ class Model(fluid.dygraph.Layer):
if not isinstance(test_loader, Iterable):
loader = test_loader()
outputs = None
outputs = []
for data in tqdm.tqdm(loader):
if not fluid.in_dygraph_mode():
data = data[0]
outs = self.test(*data)
data = flatten(data)
outputs.append(self.test(data[:len(self._inputs)]))
if outputs is None:
outputs = outs
else:
outputs = [
np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
]
# NOTE: for lod tensor output, we should not stack outputs
# for stacking may loss its detail info
outputs = list(zip(*outputs))
if stack_outputs:
outputs = [np.stack(outs, axis=0) for outs in outputs]
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from . import resnet
from . import vgg
from . import mobilenetv1
from . import mobilenetv2
from . import darknet
from . import yolov3
from . import tsm
from .resnet import *
from .mobilenetv1 import *
from .mobilenetv2 import *
from .vgg import *
from .darknet import *
from .yolov3 import *
from .tsm import *
__all__ = resnet.__all__ \
+ vgg.__all__ \
+ mobilenetv1.__all__ \
+ mobilenetv2.__all__ \
+ darknet.__all__ \
+ yolov3.__all__ \
+ tsm.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from model import Model
from .download import get_weights_path
__all__ = ['DarkNet53', 'ConvBNLayer', 'darknet53']
# {num_layers: (url, md5)}
pretrain_infos = {
53: ('https://paddlemodels.bj.bcebos.com/hapi/darknet53.pdparams',
'2506357a5c31e865785112fc614a487d')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act="leaky"):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(
num_channels=ch_in,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=False,
act=None)
self.batch_norm = BatchNorm(
num_channels=ch_out,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.)))
self.act = act
def forward(self, inputs):
out = self.conv(inputs)
out = self.batch_norm(out)
if self.act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
class DownSample(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=2,
padding=1):
super(DownSample, self).__init__()
self.conv_bn_layer = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding)
self.ch_out = ch_out
def forward(self, inputs):
out = self.conv_bn_layer(inputs)
return out
class BasicBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0)
self.conv2 = ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out*2,
filter_size=3,
stride=1,
padding=1)
def forward(self, inputs):
conv1 = self.conv1(inputs)
conv2 = self.conv2(conv1)
out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None)
return out
class LayerWarp(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out, count):
super(LayerWarp,self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out)
self.res_out_list = []
for i in range(1,count):
res_out = self.add_sublayer("basic_block_%d" % (i),
BasicBlock(
ch_out*2,
ch_out))
self.res_out_list.append(res_out)
self.ch_out = ch_out
def forward(self,inputs):
y = self.basicblock0(inputs)
for basic_block_i in self.res_out_list:
y = basic_block_i(y)
return y
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
class DarkNet53(Model):
def __init__(self, num_layers=53, ch_in=3):
super(DarkNet53, self).__init__()
assert num_layers in DarkNet_cfg.keys(), \
"only support num_layers in {} currently" \
.format(DarkNet_cfg.keys())
self.stages = DarkNet_cfg[num_layers]
self.stages = self.stages[0:5]
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=32,
filter_size=3,
stride=1,
padding=1)
self.downsample0 = DownSample(
ch_in=32,
ch_out=32 * 2)
self.darknet53_conv_block_list = []
self.downsample_list = []
ch_in = [64,128,256,512,1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer(
"stage_%d" % (i),
LayerWarp(
int(ch_in[i]),
32*(2**i),
stage))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
downsample = self.add_sublayer(
"stage_%d_downsample" % i,
DownSample(
ch_in = 32*(2**(i+1)),
ch_out = 32*(2**(i+2))))
self.downsample_list.append(downsample)
def forward(self,inputs):
out = self.conv0(inputs)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
out = self.downsample_list[i](out)
return blocks[-1:-4:-1]
def _darknet(num_layers=53, input_channels=3, pretrained=True):
model = DarkNet53(num_layers, input_channels)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"DarkNet{} do not have pretrained weights now, " \
"pretrained should be set as False".format(num_layers)
weight_path = get_weights_path(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def darknet53(input_channels=3, pretrained=True):
return _darknet(53, input_channels, pretrained)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import shutil
import requests
import tqdm
import hashlib
import time
from paddle.fluid.dygraph.parallel import ParallelEnv
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def get_weights_path(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
"""
path, _ = get_path(url, WEIGHTS_HOME, md5sum)
return path
def map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
exist_flag = True
if ParallelEnv().local_rank == 0:
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if ParallelEnv().local_rank == 0:
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from model import Model
from .download import get_weights_path
__all__ = ['MobileNetV1', 'mobilenet_v1']
model_urls = {
'mobilenetv1_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams',
'bf0d25cb0bed1114d9dac9384ce2b4a6')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
class MobileNetV1(Model):
"""MobileNetV1 model from
`"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" <https://arxiv.org/abs/1704.04861>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
"""
def __init__(self, scale=1.0, class_dim=1000):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.out = Linear(
int(1024 * scale),
class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV1(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilenet_v1(pretrained=False, scale=1.0):
model = _mobilenet('mobilenetv1_' + str(scale), pretrained, scale=scale)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from model import Model
from .download import get_weights_path
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenetv2_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
'8ff74f291f72533f2a7956a4efff9d88')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
use_cudnn=True):
super(ConvBNLayer, self).__init__()
tmp_param = ParamAttr(name=self.full_name() + "_weights")
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=tmp_param,
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs, if_act=True):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
return y
class InvertedResidualUnit(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor, ):
super(InvertedResidualUnit, self).__init__()
num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
self._bottleneck_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
use_cudnn=False)
self._linear_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True)
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
return y
class InvresiBlocks(fluid.dygraph.Layer):
def __init__(self, in_c, t, c, n, s):
super(InvresiBlocks, self).__init__()
self._first_block = InvertedResidualUnit(
num_channels=in_c,
num_in_filter=in_c,
num_filters=c,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
self._inv_blocks = []
for i in range(1, n):
tmp = self.add_sublayer(
sublayer=InvertedResidualUnit(
num_channels=c,
num_in_filter=c,
num_filters=c,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t),
name=self.full_name() + "_" + str(i + 1))
self._inv_blocks.append(tmp)
def forward(self, inputs):
y = self._first_block(inputs, ifshortcut=False)
for inv_block in self._inv_blocks:
y = inv_block(y, ifshortcut=True)
return y
class MobileNetV2(Model):
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
"""
def __init__(self, scale=1.0, class_dim=1000):
super(MobileNetV2, self).__init__()
self.scale = scale
self.class_dim = class_dim
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
#1. conv1
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1)
#2. bottleneck sequences
self._invl = []
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
tmp = self.add_sublayer(
sublayer=InvresiBlocks(
in_c=in_c, t=t, c=int(c * scale), n=n, s=s),
name='conv' + str(i))
self._invl.append(tmp)
in_c = int(c * scale)
#3. last_conv
self._out_c = int(1280 * scale) if scale > 1.0 else 1280
self._conv9 = ConvBNLayer(
num_channels=in_c,
num_filters=self._out_c,
filter_size=1,
stride=1,
padding=0)
#4. pool
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
#5. fc
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
class_dim,
act='softmax',
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
y = self._pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilenet_v2(pretrained=False, scale=1.0):
"""MobileNetV2
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = _mobilenet('mobilenetv2_' + str(scale), pretrained, scale=scale)
return model
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册