提交 44d49573 编写于 作者: Q qingqing01

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into cyclegan

# BMN 视频动作定位模型高层API实现
---
## 内容
- [模型简介](#模型简介)
- [代码结构](#代码结构)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [模型评估](#模型评估)
- [模型推断](#模型推断)
- [参考论文](#参考论文)
## 模型简介
BMN模型是百度自研,2019年ActivityNet夺冠方案,为视频动作定位问题中proposal的生成提供高效的解决方案,在PaddlePaddle上首次开源。此模型引入边界匹配(Boundary-Matching, BM)机制来评估proposal的置信度,按照proposal开始边界的位置及其长度将所有可能存在的proposal组合成一个二维的BM置信度图,图中每个点的数值代表其所对应的proposal的置信度分数。网络由三个模块组成,基础模块作为主干网络处理输入的特征序列,TEM模块预测每一个时序位置属于动作开始、动作结束的概率,PEM模块生成BM置信度图。
<p align="center">
<img src="./BMN.png" height=300 width=500 hspace='10'/> <br />
BMN Overview
</p>
## 代码结构
```
├── bmn.yaml # 网络配置文件,快速配置参数
├── run.sh # 快速运行脚本,可直接开始多卡训练
├── train.py # 训练代码,训练网络
├── eval.py # 评估代码,评估网络性能
├── predict.py # 预测代码,针对任意输入预测结果
├── bmn_model.py # 网络结构与损失函数定义
├── bmn_metric.py # 精度评估方法定义
├── reader.py # 数据reader,构造Dataset和Dataloader
├── bmn_utils.py # 模型细节相关代码
├── config_utils.py # 配置细节相关代码
├── eval_anet_prop.py # 计算精度评估指标
└── infer.list # 推断文件列表
```
## 数据准备
BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理好的视频特征,请下载[bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz)数据后解压,同时相应的修改bmn.yaml中的特征路径feat\_path。对应的标签文件请下载[label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json)并修改bmn.yaml中的标签文件路径anno\_file。
## 模型训练
数据准备完成后,可通过如下两种方式启动训练:
默认使用4卡训练,启动方式如下:
bash run.sh
若使用单卡训练,启动方式如下:
export CUDA_VISIBLE_DEVICES=0
python train.py
- 代码运行需要先安装pandas
- 从头开始训练,使用上述启动命令行或者脚本程序即可启动训练,不需要用到预训练模型
- 单卡训练时,请将配置文件中的batch_size调整为16
**训练策略:**
* 采用Adam优化器,初始learning\_rate=0.001
* 权重衰减系数为1e-4
* 学习率在迭代次数达到4200的时候做一次衰减,衰减系数为0.1
## 模型评估
训练完成后,可通过如下方式进行模型评估:
python eval.py --weights=$PATH_TO_WEIGHTS
- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/EVAL/BMN\_results文件夹下,测试结果保存在evaluate\_results/bmn\_results\_validation.json文件中。
- 注:评估时可能会出现loss为nan的情况。这是由于评估时用的是单个样本,可能存在没有iou>0.6的样本,所以为nan,对最终的评估结果没有影响。
使用ActivityNet官方提供的测试脚本,即可计算AR@AN和AUC。具体计算过程如下:
- ActivityNet数据集的具体使用说明可以参考其[官方网站](http://activity-net.org)
- 下载指标评估代码,请从[ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)下载,将Evaluation文件夹拷贝至models/dygraph/bmn目录下。(注:由于第三方评估代码不支持python3,此处建议使用python2进行评估;若使用python3,print函数需要添加括号,请对Evaluation目录下的.py文件做相应修改。)
- 请下载[activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json)文件,并将其放置在models/dygraph/bmn/Evaluation/data目录下,相较于原始的activity\_net.v1-3.min.json文件,我们过滤了其中一些失效的视频条目。
- 计算精度指标
```python eval_anet_prop.py```
在ActivityNet1.3数据集下评估精度如下:
| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
| :---: | :---: | :---: | :---: | :---: |
| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
## 模型推断
可通过如下方式启动模型推断:
python predict.py --weights=$PATH_TO_WEIGHTS \
--filelist=$FILELIST
- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,如果不设置,将使用默认参数文件checkpoint/final.pdparams。
- 上述程序会将运行结果保存在output/INFER/BMN\_results文件夹下,测试结果保存在predict\_results/bmn\_results\_test.json文件中。
## 参考论文
- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
MODEL:
name: "BMN"
tscale: 100
dscale: 100
feat_dim: 400
prop_boundary_ratio: 0.5
num_sample: 32
num_sample_perbin: 3
anno_file: "./activitynet_1.3_annotations.json"
feat_path: './fix_feat_100'
TRAIN:
subset: "train"
epoch: 9
batch_size: 4
num_workers: 4
use_shuffle: True
device: "gpu"
num_gpus: 4
learning_rate: 0.001
learning_rate_decay: 0.1
lr_decay_iter: 4200
l2_weight_decay: 1e-4
VALID:
subset: "validation"
TEST:
subset: "validation"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.001
snms_t1: 0.5
snms_t2: 0.9
output_path: "output/EVAL/BMN_results"
result_path: "evaluate_results"
INFER:
subset: "test"
batch_size: 1
num_workers: 1
use_buffer: False
snms_alpha: 0.4
snms_t1: 0.5
snms_t2: 0.9
filelist: './infer.list'
output_path: "output/INFER/BMN_results"
result_path: "predict_results"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import os
import sys
import json
sys.path.append('../')
from metrics import Metric
from bmn_utils import boundary_choose, bmn_post_processing
class BmnMetric(Metric):
"""
only support update with batch_size=1
"""
def __init__(self, cfg, mode):
super(BmnMetric, self).__init__()
self.cfg = cfg
self.mode = mode
#get video_dict and video_list
if self.mode == 'test':
self.get_test_dataset_dict()
elif self.mode == 'infer':
self.get_infer_dataset_dict()
def add_metric_op(self, preds, label):
pred_bm, pred_start, pred_en = preds
video_index = label[-1]
return [pred_bm, pred_start, pred_en, video_index] #return list
def update(self, pred_bm, pred_start, pred_end, fid):
# generate proposals
pred_start = pred_start[0]
pred_end = pred_end[0]
fid = fid[0]
if self.mode == 'infer':
output_path = self.cfg.INFER.output_path
else:
output_path = self.cfg.TEST.output_path
tscale = self.cfg.MODEL.tscale
dscale = self.cfg.MODEL.dscale
snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
cols = ["xmin", "xmax", "score"]
video_name = self.video_list[fid]
pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
start_mask = boundary_choose(pred_start)
start_mask[0] = 1.
end_mask = boundary_choose(pred_end)
end_mask[-1] = 1.
score_vector_list = []
for idx in range(dscale):
for jdx in range(tscale):
start_index = jdx
end_index = start_index + idx
if end_index < tscale and start_mask[
start_index] == 1 and end_mask[end_index] == 1:
xmin = snippet_xmins[start_index]
xmax = snippet_xmaxs[end_index]
xmin_score = pred_start[start_index]
xmax_score = pred_end[end_index]
bm_score = pred_bm[idx, jdx]
conf_score = xmin_score * xmax_score * bm_score
score_vector_list.append([xmin, xmax, conf_score])
score_vector_list = np.stack(score_vector_list)
video_df = pd.DataFrame(score_vector_list, columns=cols)
video_df.to_csv(
os.path.join(output_path, "%s.csv" % video_name), index=False)
return 0 # result has saved in output path
def accumulate(self):
return 'post_processing is required...' # required method
def reset(self):
print("Post_processing....This may take a while")
if self.mode == 'test':
bmn_post_processing(self.video_dict, self.cfg.TEST.subset,
self.cfg.TEST.output_path,
self.cfg.TEST.result_path)
elif self.mode == 'infer':
bmn_post_processing(self.video_dict, self.cfg.INFER.subset,
self.cfg.INFER.output_path,
self.cfg.INFER.result_path)
def name(self):
return 'bmn_metric'
def get_test_dataset_dict(self):
anno_file = self.cfg.MODEL.anno_file
annos = json.load(open(anno_file))
subset = self.cfg.TEST.subset
self.video_dict = {}
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
def get_infer_dataset_dict(self):
file_list = self.cfg.INFER.filelist
annos = json.load(open(file_list))
self.video_dict = {}
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
import math
from bmn_utils import get_interp1d_mask
from model import Model, Loss
DATATYPE = 'float32'
# Net
class Conv1D(fluid.dygraph.Layer):
def __init__(self,
prefix,
num_channels=256,
num_filters=256,
size_k=3,
padding=1,
groups=1,
act="relu"):
super(Conv1D, self).__init__()
fan_in = num_channels * size_k * 1
k = 1. / math.sqrt(fan_in)
param_attr = ParamAttr(
name=prefix + "_w",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
bias_attr = ParamAttr(
name=prefix + "_b",
initializer=fluid.initializer.Uniform(
low=-k, high=k))
self._conv2d = fluid.dygraph.Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=(1, size_k),
stride=1,
padding=(0, padding),
groups=groups,
act=act,
param_attr=param_attr,
bias_attr=bias_attr)
def forward(self, x):
x = fluid.layers.unsqueeze(input=x, axes=[2])
x = self._conv2d(x)
x = fluid.layers.squeeze(input=x, axes=[2])
return x
class BMN(Model):
def __init__(self, cfg, is_dygraph=True):
super(BMN, self).__init__()
#init config
self.tscale = cfg.MODEL.tscale
self.dscale = cfg.MODEL.dscale
self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
self.num_sample = cfg.MODEL.num_sample
self.num_sample_perbin = cfg.MODEL.num_sample_perbin
self.is_dygraph = is_dygraph
self.hidden_dim_1d = 256
self.hidden_dim_2d = 128
self.hidden_dim_3d = 512
# Base Module
self.b_conv1 = Conv1D(
prefix="Base_1",
num_channels=400,
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.b_conv2 = Conv1D(
prefix="Base_2",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
# Temporal Evaluation Module
self.ts_conv1 = Conv1D(
prefix="TEM_s1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.ts_conv2 = Conv1D(
prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
self.te_conv1 = Conv1D(
prefix="TEM_e1",
num_filters=self.hidden_dim_1d,
size_k=3,
padding=1,
groups=4,
act="relu")
self.te_conv2 = Conv1D(
prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
#Proposal Evaluation Module
self.p_conv1 = Conv1D(
prefix="PEM_1d",
num_filters=self.hidden_dim_2d,
size_k=3,
padding=1,
act="relu")
# init to speed up
sample_mask_array = get_interp1d_mask(
self.tscale, self.dscale, self.prop_boundary_ratio,
self.num_sample, self.num_sample_perbin)
if self.is_dygraph:
self.sample_mask = fluid.dygraph.base.to_variable(
sample_mask_array)
else: # static
self.sample_mask = fluid.layers.create_parameter(
shape=[
self.tscale, self.num_sample * self.dscale * self.tscale
],
dtype=DATATYPE,
attr=fluid.ParamAttr(
name="sample_mask", trainable=False),
default_initializer=fluid.initializer.NumpyArrayInitializer(
sample_mask_array))
self.sample_mask.stop_gradient = True
self.p_conv3d1 = fluid.dygraph.Conv3D(
num_channels=128,
num_filters=self.hidden_dim_3d,
filter_size=(self.num_sample, 1, 1),
stride=(self.num_sample, 1, 1),
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_3d1_w"),
bias_attr=ParamAttr(name="PEM_3d1_b"))
self.p_conv2d1 = fluid.dygraph.Conv2D(
num_channels=512,
num_filters=self.hidden_dim_2d,
filter_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name="PEM_2d1_w"),
bias_attr=ParamAttr(name="PEM_2d1_b"))
self.p_conv2d2 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d2_w"),
bias_attr=ParamAttr(name="PEM_2d2_b"))
self.p_conv2d3 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=self.hidden_dim_2d,
filter_size=3,
stride=1,
padding=1,
act="relu",
param_attr=ParamAttr(name="PEM_2d3_w"),
bias_attr=ParamAttr(name="PEM_2d3_b"))
self.p_conv2d4 = fluid.dygraph.Conv2D(
num_channels=128,
num_filters=2,
filter_size=1,
stride=1,
padding=0,
act="sigmoid",
param_attr=ParamAttr(name="PEM_2d4_w"),
bias_attr=ParamAttr(name="PEM_2d4_b"))
def forward(self, x):
#Base Module
x = self.b_conv1(x)
x = self.b_conv2(x)
#TEM
xs = self.ts_conv1(x)
xs = self.ts_conv2(xs)
xs = fluid.layers.squeeze(xs, axes=[1])
xe = self.te_conv1(x)
xe = self.te_conv2(xe)
xe = fluid.layers.squeeze(xe, axes=[1])
#PEM
xp = self.p_conv1(x)
#BM layer
xp = fluid.layers.matmul(xp, self.sample_mask)
xp = fluid.layers.reshape(
xp, shape=[0, 0, -1, self.dscale, self.tscale])
xp = self.p_conv3d1(xp)
xp = fluid.layers.squeeze(xp, axes=[2])
xp = self.p_conv2d1(xp)
xp = self.p_conv2d2(xp)
xp = self.p_conv2d3(xp)
xp = self.p_conv2d4(xp)
return xp, xs, xe
class BmnLoss(Loss):
def __init__(self, cfg):
super(BmnLoss, self).__init__()
self.cfg = cfg
def _get_mask(self):
dscale = self.cfg.MODEL.dscale
tscale = self.cfg.MODEL.tscale
bm_mask = []
for idx in range(dscale):
mask_vector = [1 for i in range(tscale - idx)
] + [0 for i in range(idx)]
bm_mask.append(mask_vector)
bm_mask = np.array(bm_mask, dtype=np.float32)
self_bm_mask = fluid.layers.create_global_var(
shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
fluid.layers.assign(bm_mask, self_bm_mask)
self_bm_mask.stop_gradient = True
return self_bm_mask
def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
def bi_loss(pred_score, gt_label):
pred_score = fluid.layers.reshape(
x=pred_score, shape=[-1], inplace=False)
gt_label = fluid.layers.reshape(
x=gt_label, shape=[-1], inplace=False)
gt_label.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
num_entries = fluid.layers.cast(
fluid.layers.shape(pmask), dtype=DATATYPE)
num_positive = fluid.layers.cast(
fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
temp = fluid.layers.log(pred_score + epsilon)
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
loss = -1 * (loss_pos + loss_neg)
return loss
loss_start = bi_loss(pred_start, gt_start)
loss_end = bi_loss(pred_end, gt_end)
loss = loss_start + loss_end
return loss
def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
num_h = fluid.layers.cast(
fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
num_m = fluid.layers.cast(
fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
num_l = fluid.layers.cast(
fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
r_m = num_h / num_m
u_smmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
r_l = num_h / num_l
u_slmask = fluid.layers.uniform_random(
shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
dtype=DATATYPE,
min=0.0,
max=1.0)
u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
weights = u_hmask + u_smmask + u_slmask
weights.stop_gradient = True
loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
loss = fluid.layers.elementwise_mul(loss, weights)
loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
weights)
return loss
def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
gt_iou_map.stop_gradient = True
pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
nmask = fluid.layers.elementwise_mul(nmask, mask)
num_positive = fluid.layers.reduce_sum(pmask)
num_entries = num_positive + fluid.layers.reduce_sum(nmask)
ratio = num_entries / num_positive
coef_0 = 0.5 * ratio / (ratio - 1)
coef_1 = 0.5 * ratio
epsilon = 0.000001
loss_pos = fluid.layers.elementwise_mul(
fluid.layers.log(pred_score + epsilon), pmask)
loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
loss_neg = fluid.layers.elementwise_mul(
fluid.layers.log(1.0 - pred_score + epsilon), nmask)
loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
loss = -1 * (loss_pos + loss_neg) / num_entries
return loss
def forward(self, outputs, labels):
pred_bm, pred_start, pred_end = outputs
if len(labels) == 3:
gt_iou_map, gt_start, gt_end = labels
elif len(labels) == 4: # video_index used in eval mode
gt_iou_map, gt_start, gt_end, video_index = labels
pred_bm_reg = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[0], ends=[1]),
axes=[1])
pred_bm_cls = fluid.layers.squeeze(
fluid.layers.slice(
pred_bm, axes=[1], starts=[1], ends=[2]),
axes=[1])
bm_mask = self._get_mask()
pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
return loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import pandas as pd
import multiprocessing as mp
import json
import os
import math
def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute jaccard score between a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
union_len = len_anchors - inter_len + box_max - box_min
jaccard = np.divide(inter_len, union_len)
return jaccard
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
"""Compute intersection between score a box and the anchors.
"""
len_anchors = anchors_max - anchors_min
int_xmin = np.maximum(anchors_min, box_min)
int_xmax = np.minimum(anchors_max, box_max)
inter_len = np.maximum(int_xmax - int_xmin, 0.)
scores = np.divide(inter_len, len_anchors)
return scores
def boundary_choose(score_list):
max_score = max(score_list)
mask_high = (score_list > max_score * 0.5)
score_list = list(score_list)
score_middle = np.array([0.0] + score_list + [0.0])
score_front = np.array([0.0, 0.0] + score_list)
score_back = np.array(score_list + [0.0, 0.0])
mask_peak = ((score_middle > score_front) & (score_middle > score_back))
mask_peak = mask_peak[1:-1]
mask = (mask_high | mask_peak).astype('float32')
return mask
def soft_nms(df, alpha, t1, t2):
'''
df: proposals generated by network;
alpha: alpha value of Gaussian decaying function;
t1, t2: threshold for soft nms.
'''
df = df.sort_values(by="score", ascending=False)
tstart = list(df.xmin.values[:])
tend = list(df.xmax.values[:])
tscore = list(df.score.values[:])
rstart = []
rend = []
rscore = []
while len(tscore) > 1 and len(rscore) < 101:
max_index = tscore.index(max(tscore))
tmp_iou_list = iou_with_anchors(
np.array(tstart),
np.array(tend), tstart[max_index], tend[max_index])
for idx in range(0, len(tscore)):
if idx != max_index:
tmp_iou = tmp_iou_list[idx]
tmp_width = tend[max_index] - tstart[max_index]
if tmp_iou > t1 + (t2 - t1) * tmp_width:
tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
alpha)
rstart.append(tstart[max_index])
rend.append(tend[max_index])
rscore.append(tscore[max_index])
tstart.pop(max_index)
tend.pop(max_index)
tscore.pop(max_index)
newDf = pd.DataFrame()
newDf['score'] = rscore
newDf['xmin'] = rstart
newDf['xmax'] = rend
return newDf
def video_process(video_list,
video_dict,
output_path,
result_dict,
snms_alpha=0.4,
snms_t1=0.55,
snms_t2=0.9):
for video_name in video_list:
print("Processing video........" + video_name)
df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
if len(df) > 1:
df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
video_duration = video_dict[video_name]["duration_second"]
proposal_list = []
for idx in range(min(100, len(df))):
tmp_prop={"score":df.score.values[idx], \
"segment":[max(0,df.xmin.values[idx])*video_duration, \
min(1,df.xmax.values[idx])*video_duration]}
proposal_list.append(tmp_prop)
result_dict[video_name[2:]] = proposal_list
def bmn_post_processing(video_dict, subset, output_path, result_path):
video_list = video_dict.keys()
video_list = list(video_dict.keys())
global result_dict
result_dict = mp.Manager().dict()
pp_num = 12
num_videos = len(video_list)
num_videos_per_thread = int(num_videos / pp_num)
processes = []
for tid in range(pp_num - 1):
tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
num_videos_per_thread]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
p = mp.Process(
target=video_process,
args=(tmp_video_list, video_dict, output_path, result_dict))
p.start()
processes.append(p)
for p in processes:
p.join()
result_dict = dict(result_dict)
output_dict = {
"version": "VERSION 1.3",
"results": result_dict,
"external_data": {}
}
outfile = open(
os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
json.dump(output_dict, outfile)
outfile.close()
def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
num_sample_perbin):
""" generate sample mask for a boundary-matching pair """
plen = float(seg_xmax - seg_xmin)
plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
total_samples = [
seg_xmin + plen_sample * ii
for ii in range(num_sample * num_sample_perbin)
]
p_mask = []
for idx in range(num_sample):
bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
num_sample_perbin]
bin_vector = np.zeros([tscale])
for sample in bin_samples:
sample_upper = math.ceil(sample)
sample_decimal, sample_down = math.modf(sample)
if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
bin_vector[int(sample_down)] += 1 - sample_decimal
if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
bin_vector[int(sample_upper)] += sample_decimal
bin_vector = 1.0 / num_sample_perbin * bin_vector
p_mask.append(bin_vector)
p_mask = np.stack(p_mask, axis=1)
return p_mask
def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
num_sample_perbin):
""" generate sample mask for each point in Boundary-Matching Map """
mask_mat = []
for start_index in range(tscale):
mask_mat_vector = []
for duration_index in range(dscale):
if start_index + duration_index < tscale:
p_xmin = start_index
p_xmax = start_index + duration_index
center_len = float(p_xmax - p_xmin) + 1
sample_xmin = p_xmin - center_len * prop_boundary_ratio
sample_xmax = p_xmax + center_len * prop_boundary_ratio
p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
tscale, num_sample,
num_sample_perbin)
else:
p_mask = np.zeros([tscale, num_sample])
mask_mat_vector.append(p_mask)
mask_mat_vector = np.stack(mask_mat_vector, axis=2)
mask_mat.append(mask_mat_vector)
mask_mat = np.stack(mask_mat, axis=3)
mask_mat = mask_mat.astype(np.float32)
sample_mask = np.reshape(mask_mat, [tscale, -1])
return sample_mask
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import os
import sys
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN test for performance evaluation.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Performance Evaluation
def test_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
eval_cfg = merge_configs(config, 'test', vars(args))
if not os.path.isdir(config.TEST.output_path):
os.makedirs(config.TEST.output_path)
if not os.path.isdir(config.TEST.result_path):
os.makedirs(config.TEST.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
video_idx = Input([None, 1], 'int64', name='video_idx')
labels = [gt_iou_map, gt_start, gt_end, video_idx]
#data
eval_dataset = BmnDataset(eval_cfg, 'test')
#model
model = BMN(config, args.dynamic)
model.prepare(
loss_function=BmnLoss(config),
metrics=BmnMetric(
config, mode='test'),
inputs=inputs,
labels=labels,
device=device)
#load checkpoint
if args.weights:
assert os.path.exists(args.weights + '.pdparams'), \
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
model.evaluate(
eval_data=eval_dataset,
batch_size=eval_cfg.TEST.batch_size,
num_workers=eval_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[EVAL] eval finished")
if __name__ == '__main__':
args = parse_args()
test_bmn(args)
'''
Calculate AR@N and AUC;
Modefied from ActivityNet Gitub repository](https://github.com/activitynet/ActivityNet.git)
'''
import sys
sys.path.append('./Evaluation')
from eval_proposal import ANETproposal
import numpy as np
import argparse
import os
parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
parser.add_argument(
'--eval_file',
type=str,
default='bmn_results_validation.json',
help='name of results file to eval')
def run_evaluation(ground_truth_filename,
proposal_filename,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation'):
anet_proposal = ANETproposal(
ground_truth_filename,
proposal_filename,
tiou_thresholds=tiou_thresholds,
max_avg_nr_proposals=max_avg_nr_proposals,
subset=subset,
verbose=True,
check_status=False)
anet_proposal.evaluate()
recall = anet_proposal.recall
average_recall = anet_proposal.avg_recall
average_nr_proposals = anet_proposal.proposals_per_video
return (average_nr_proposals, average_recall, recall)
def plot_metric(average_nr_proposals,
average_recall,
recall,
tiou_thresholds=np.linspace(0.5, 0.95, 10)):
fn_size = 14
plt.figure(num=None, figsize=(12, 8))
ax = plt.subplot(1, 1, 1)
colors = [
'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
]
area_under_curve = np.zeros_like(tiou_thresholds)
for i in range(recall.shape[0]):
area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
for idx, tiou in enumerate(tiou_thresholds[::2]):
ax.plot(
average_nr_proposals,
recall[2 * idx, :],
color=colors[idx + 1],
label="tiou=[" + str(tiou) + "], area=" + str(
int(area_under_curve[2 * idx] * 100) / 100.),
linewidth=4,
linestyle='--',
marker=None)
# Plots Average Recall vs Average number of proposals.
ax.plot(
average_nr_proposals,
average_recall,
color=colors[0],
label="tiou = 0.5:0.05:0.95," + " area=" + str(
int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
linewidth=4,
linestyle='-',
marker=None)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
[handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
plt.ylabel('Average Recall', fontsize=fn_size)
plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
plt.grid(b=True, which="both")
plt.ylim([0, 1.0])
plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
plt.show()
if __name__ == "__main__":
args = parser.parse_args()
eval_file = args.eval_file
eval_file_path = os.path.join("evaluate_results", eval_file)
uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
"./Evaluation/data/activity_net_1_3_new.json",
eval_file_path,
max_avg_nr_proposals=100,
tiou_thresholds=np.linspace(0.5, 0.95, 10),
subset='validation')
print("AR@1; AR@5; AR@10; AR@100")
print("%.02f %.02f %.02f %.02f" %
(100 * np.mean(uniform_recall_valid[:, 0]),
100 * np.mean(uniform_recall_valid[:, 4]),
100 * np.mean(uniform_recall_valid[:, 9]),
100 * np.mean(uniform_recall_valid[:, -1])))
{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import sys
import os
import logging
import paddle.fluid as fluid
sys.path.append('../')
from model import set_device, Input
from bmn_metric import BmnMetric
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("BMN inference.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode, only support dynamic mode at present time")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--device', type=str, default='GPU', help='default use gpu.')
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--save_dir',
type=str,
default="predict_results/",
help='output dir path, default to use ./predict_results/')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Prediction
def infer_bmn(args):
# only support dynamic mode at present time
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
config = parse_config(args.config_file)
infer_cfg = merge_configs(config, 'infer', vars(args))
if not os.path.isdir(config.INFER.output_path):
os.makedirs(config.INFER.output_path)
if not os.path.isdir(config.INFER.result_path):
os.makedirs(config.INFER.result_path)
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
labels = [Input([None, 1], 'int64', name='video_idx')]
#data
infer_dataset = BmnDataset(infer_cfg, 'infer')
model = BMN(config, args.dynamic)
model.prepare(
metrics=BmnMetric(
config, mode='infer'),
inputs=inputs,
labels=labels,
device=device)
# load checkpoint
if args.weights:
assert os.path.exists(
args.weights +
".pdparams"), "Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
# here use model.eval instead of model.test, as post process is required in our case
model.evaluate(
eval_data=infer_dataset,
batch_size=infer_cfg.TEST.batch_size,
num_workers=infer_cfg.TEST.num_workers,
log_freq=args.log_interval)
logger.info("[INFER] infer finished")
if __name__ == '__main__':
args = parse_args()
infer_bmn(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import numpy as np
import json
import logging
import os
import sys
sys.path.append('../')
from distributed import DistributedBatchSampler
from paddle.fluid.io import Dataset, DataLoader
logger = logging.getLogger(__name__)
from config_utils import *
from bmn_utils import iou_with_anchors, ioa_with_anchors
DATATYPE = "float32"
class BmnDataset(Dataset):
def __init__(self, cfg, mode):
self.mode = mode
self.tscale = cfg.MODEL.tscale # 100
self.dscale = cfg.MODEL.dscale # 100
self.anno_file = cfg.MODEL.anno_file
self.feat_path = cfg.MODEL.feat_path
self.file_list = cfg.INFER.filelist
self.subset = cfg[mode.upper()]['subset']
self.tgap = 1. / self.tscale
self.get_dataset_dict()
self.get_match_map()
def __getitem__(self, index):
video_name = self.video_list[index]
video_idx = self.video_list.index(video_name)
video_feat = self.load_file(video_name)
if self.mode == 'infer':
return video_feat, video_idx
else:
gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
if self.mode == 'train' or self.mode == 'valid':
return video_feat, gt_iou_map, gt_start, gt_end
elif self.mode == 'test':
return video_feat, gt_iou_map, gt_start, gt_end, video_idx
def __len__(self):
return len(self.video_list)
def get_dataset_dict(self):
assert (
os.path.exists(self.feat_path)), "Input feature path not exists"
assert (os.listdir(self.feat_path)), "No feature file in feature path"
self.video_dict = {}
if self.mode == "infer":
annos = json.load(open(self.file_list))
for video_name in annos.keys():
self.video_dict[video_name] = annos[video_name]
else:
annos = json.load(open(self.anno_file))
for video_name in annos.keys():
video_subset = annos[video_name]["subset"]
if self.subset in video_subset:
self.video_dict[video_name] = annos[video_name]
self.video_list = list(self.video_dict.keys())
self.video_list.sort()
print("%s subset video numbers: %d" %
(self.subset, len(self.video_list)))
video_name_set = set(
[video_name + '.npy' for video_name in self.video_list])
assert (video_name_set.intersection(set(os.listdir(self.feat_path))) ==
video_name_set), "Input feature not exists in feature path"
def get_match_map(self):
match_map = []
for idx in range(self.tscale):
tmp_match_window = []
xmin = self.tgap * idx
for jdx in range(1, self.tscale + 1):
xmax = xmin + self.tgap * jdx
tmp_match_window.append([xmin, xmax])
match_map.append(tmp_match_window)
match_map = np.array(match_map)
match_map = np.transpose(match_map, [1, 0, 2])
match_map = np.reshape(match_map, [-1, 2])
self.match_map = match_map
self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
def get_video_label(self, video_name):
video_info = self.video_dict[video_name]
video_second = video_info['duration_second']
video_labels = video_info['annotations']
gt_bbox = []
gt_iou_map = []
for gt in video_labels:
tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
gt_bbox.append([tmp_start, tmp_end])
tmp_gt_iou_map = iou_with_anchors(
self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
[self.dscale, self.tscale])
gt_iou_map.append(tmp_gt_iou_map)
gt_iou_map = np.array(gt_iou_map)
gt_iou_map = np.max(gt_iou_map, axis=0)
gt_bbox = np.array(gt_bbox)
gt_xmins = gt_bbox[:, 0]
gt_xmaxs = gt_bbox[:, 1]
gt_len_small = 3 * self.tgap
gt_start_bboxs = np.stack(
(gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
gt_end_bboxs = np.stack(
(gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
match_score_start = []
for jdx in range(len(self.anchor_xmin)):
match_score_start.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
match_score_end = []
for jdx in range(len(self.anchor_xmin)):
match_score_end.append(
np.max(
ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
gt_start = np.array(match_score_start)
gt_end = np.array(match_score_end)
return gt_iou_map.astype(DATATYPE), gt_start.astype(
DATATYPE), gt_end.astype(DATATYPE)
def load_file(self, video_name):
file_name = video_name + ".npy"
file_path = os.path.join(self.feat_path, file_name)
video_feat = np.load(file_path)
video_feat = video_feat.T
video_feat = video_feat.astype("float32")
return video_feat
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch train.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
import argparse
import logging
import sys
import os
sys.path.append('../')
from model import set_device, Input
from bmn_model import BMN, BmnLoss
from reader import BmnDataset
from config_utils import *
DATATYPE = 'float32'
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser("Paddle high level api of BMN.")
parser.add_argument(
"-d",
"--dynamic",
default=True,
action='store_true',
help="enable dygraph mode")
parser.add_argument(
'--config_file',
type=str,
default='bmn.yaml',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='training batch size. None to use config file setting.')
parser.add_argument(
'--learning_rate',
type=float,
default=0.001,
help='learning rate use for training. None to use config file setting.')
parser.add_argument(
'--resume',
type=str,
default=None,
help='filename to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.')
parser.add_argument(
'--device',
type=str,
default='gpu',
help='gpu or cpu, default use gpu.')
parser.add_argument(
'--epoch',
type=int,
default=9,
help='epoch number, 0 for read from config file')
parser.add_argument(
'--valid_interval',
type=int,
default=1,
help='validation epoch interval, 0 for no validation.')
parser.add_argument(
'--save_dir',
type=str,
default="checkpoint",
help='path to save train snapshoot')
parser.add_argument(
'--log_interval',
type=int,
default=10,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
# Optimizer
def optimizer(cfg, parameter_list):
bd = [cfg.TRAIN.lr_decay_iter]
base_lr = cfg.TRAIN.learning_rate
lr_decay = cfg.TRAIN.learning_rate_decay
l2_weight_decay = cfg.TRAIN.l2_weight_decay
lr = [base_lr, base_lr * lr_decay]
optimizer = fluid.optimizer.Adam(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
parameter_list=parameter_list,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=l2_weight_decay))
return optimizer
# TRAIN
def train_bmn(args):
device = set_device(args.device)
fluid.enable_dygraph(device) if args.dynamic else None
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
config = parse_config(args.config_file)
train_cfg = merge_configs(config, 'train', vars(args))
val_cfg = merge_configs(config, 'valid', vars(args))
inputs = [
Input(
[None, config.MODEL.feat_dim, config.MODEL.tscale],
'float32',
name='feat_input')
]
gt_iou_map = Input(
[None, config.MODEL.dscale, config.MODEL.tscale],
'float32',
name='gt_iou_map')
gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
labels = [gt_iou_map, gt_start, gt_end]
# data
train_dataset = BmnDataset(train_cfg, 'train')
val_dataset = BmnDataset(val_cfg, 'valid')
# model
model = BMN(config, args.dynamic)
optim = optimizer(config, parameter_list=model.parameters())
model.prepare(
optimizer=optim,
loss_function=BmnLoss(config),
inputs=inputs,
labels=labels,
device=device)
# if resume weights is given, load resume weights directly
if args.resume is not None:
model.load(args.resume)
model.fit(train_data=train_dataset,
eval_data=val_dataset,
batch_size=train_cfg.TRAIN.batch_size,
epochs=args.epoch,
eval_freq=args.valid_interval,
log_freq=args.log_interval,
save_dir=args.save_dir,
shuffle=train_cfg.TRAIN.use_shuffle,
num_workers=train_cfg.TRAIN.num_workers,
drop_last=True)
if __name__ == "__main__":
args = parse_args()
train_bmn(args)
import os
import sys
import cv2
from paddle.fluid.io import Dataset
def has_valid_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x):
return has_valid_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
class DatasetFolder(Dataset):
"""A generic data loader where the samples are arranged in this way:
root/class_a/1.ext
root/class_a/2.ext
root/class_a/3.ext
root/class_b/123.ext
root/class_b/456.ext
root/class_b/789.ext
Args:
root (string): Root directory path.
loader (callable, optional): A function to load a sample given its path.
extensions (tuple[string], optional): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self,
root,
loader=None,
extensions=None,
transform=None,
target_transform=None,
is_valid_file=None):
self.root = root
if extensions is None:
extensions = IMG_EXTENSIONS
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions,
is_valid_file)
if len(samples) == 0:
raise (RuntimeError(
"Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = cv2_loader if loader is None else loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir),
and class_to_idx is a dictionary.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [
d for d in os.listdir(dir)
if os.path.isdir(os.path.join(dir, d))
]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def cv2_loader(path):
return cv2.imread(path)
# 高级api图像分类
## 数据集准备
在开始训练前,请确保已经下载解压好[ImageNet数据集](http://image-net.org/download),并放在合适的目录下,准备好的数据集的目录结构如下所示:
```bash
/path/to/imagenet
train
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
val
n01440764
xxx.jpg
...
n01443537
xxx.jpg
...
...
```
## 训练
### 单卡训练
执行如下命令进行训练
```bash
python -u main.py --arch resnet50 /path/to/imagenet -d
```
### 多卡训练
执行如下命令进行训练
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
```
## 预测
### 单卡预测
执行如下命令进行预测
```bash
python -u main.py --arch resnet50 -d --evaly-only /path/to/imagenet
```
### 多卡预测
执行如下命令进行多卡预测
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --evaly-only /path/to/imagenet
```
## 参数说明
* **arch**: 要训练或预测的模型名称
* **device**: 训练使用的设备,'gpu'或'cpu',默认值:'gpu'
* **dynamic**: 是否使用动态图模式训练
* **epoch**: 训练的轮数,默认值:120
* **learning-rate**: 学习率,默认值:0.1
* **batch-size**: 每张卡的batch size,默认值:64
* **output-dir**: 模型文件保存的文件夹,默认值:'output'
* **num-workers**: dataloader的进程数,默认值:4
* **resume**: 恢复训练的模型路径,默认值:None
* **eval-only**: 仅仅进行预测,默认值:False
## 模型
| 模型 | top1 acc | top5 acc |
| --- | --- | --- |
| ResNet50 | 76.28 | 93.04 |
import os
import cv2
import math
import random
import numpy as np
from datasets.folder import DatasetFolder
def center_crop_resize(img):
h, w = img.shape[:2]
c = int(224 / 256 * min((h, w)))
i = (h + 1 - c) // 2
j = (w + 1 - c) // 2
img = img[i:i + c, j:j + c, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i:i + h, j:j + w, :]
return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
return img
def normalize_permute(img):
# transpose and convert to RGB from BGR
img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
class ImageNetDataset(DatasetFolder):
def __init__(self, path, mode='train'):
super(ImageNetDataset, self).__init__(path)
self.mode = mode
if self.mode == 'train':
self.transform = compose([
cv2.imread, random_crop_resize, random_flip, normalize_permute
])
else:
self.transform = compose(
[cv2.imread, center_crop_resize, normalize_permute])
def __getitem__(self, idx):
img, label = self.samples[idx]
return self.transform((img, [label]))
def __len__(self):
return len(self.samples)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import sys
sys.path.append('../')
import time
import math
import numpy as np
import models
import paddle.fluid as fluid
from model import CrossEntropy, Input, set_device
from imagenet_dataset import ImageNetDataset
from distributed import DistributedBatchSampler
from paddle.fluid.dygraph.parallel import ParallelEnv
from metrics import Accuracy
from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr
momentum = 0.9
weight_decay = 1e-4
boundaries = [step_per_epoch * e for e in [30, 60, 80]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
not FLAGS.resume)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
train_dataset = ImageNetDataset(
os.path.join(FLAGS.data, 'train'), mode='train')
val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
optim = make_optimizer(
np.ceil(
len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
if FLAGS.eval_only:
model.evaluate(
val_dataset,
batch_size=FLAGS.batch_size,
num_workers=FLAGS.num_workers)
return
output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
time.strftime('%Y-%m-%d-%H-%M',
time.localtime()))
if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
os.makedirs(output_dir)
model.fit(train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument(
'data',
metavar='DIR',
help='path to dataset '
'(should have subdirectories named "train" and "val"')
parser.add_argument(
"--arch", type=str, default='resnet50', help="model name")
parser.add_argument(
"--device", type=str, default='gpu', help="device to run, cpu or gpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"-e", "--epoch", default=90, type=int, help="number of epoch")
parser.add_argument(
'--lr',
'--learning-rate',
default=0.1,
type=float,
metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch-size", default=64, type=int, help="batch size")
parser.add_argument(
"-n", "--num-workers", default=4, type=int, help="dataloader workers")
parser.add_argument(
"--output-dir", type=str, default='output', help="save dir")
parser.add_argument(
"-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume")
parser.add_argument(
"--eval-only", action='store_true', help="enable dygraph mode")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
......@@ -410,7 +410,8 @@ class StaticGraphAdapter(object):
and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
new_lr_var = prog.global_block().vars[lr_var.name]
self.model._optimizer._learning_rate_map[prog] = new_lr_var
losses = []
metrics = []
......
from .resnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import shutil
import requests
import tqdm
import hashlib
import time
from paddle.fluid.dygraph.parallel import ParallelEnv
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def get_weights_path(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
"""
path, _ = get_path(url, WEIGHTS_HOME, md5sum)
return path
def map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
"""
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
exist_flag = True
if ParallelEnv().local_rank == 0:
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if ParallelEnv().local_rank == 0:
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model
from .download import get_weights_path
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
model_urls = {
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
'0884c9087266496c41c60d14a96f8530')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
# return fluid.layers.relu(x)
class ResNet(Model):
def __init__(self, Block, depth=50, num_classes=1000):
super(ResNet, self).__init__()
layer_config = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
assert depth in layer_config.keys(), \
"supported depth are {} but input layer is {}".format(
layer_config.keys(), depth)
layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
block = Block(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
num_filters=num_out[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer("layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(
self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def _resnet(arch, Block, depth, pretrained):
model = ResNet(Block, depth)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def resnet50(pretrained=False):
return _resnet('resnet50', BottleneckBlock, 50, pretrained)
def resnet101(pretrained=False):
return _resnet('resnet101', BottleneckBlock, 101, pretrained)
def resnet152(pretrained=False):
return _resnet('resnet152', BottleneckBlock, 152, pretrained)
此差异已折叠。
## Transformer
以下是本例的简要目录结构及说明:
```text
.
├── images # README 文档中的图片
├── utils # 工具包
├── gen_data.sh # 数据生成脚本
├── predict.py # 预测脚本
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
├── model.py # 模型定义文件
└── transformer.yaml # 配置文件
```
## 模型简介
机器翻译(machine translation, MT)是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程,输入为源语言句子,输出为相应的目标语言的句子。
本项目是机器翻译领域主流模型 Transformer 的 PaddlePaddle 实现, 包含模型训练,预测以及使用自定义数据等内容。用户可以基于发布的内容搭建自己的翻译模型。
## 快速开始
### 安装说明
1. paddle安装
本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 下载代码
克隆代码库到本地
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/dygraph/transformer
```
3. 环境依赖
请参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.6/beginners_guide/install/index_cn.html)部分的内容
### 数据准备
公开数据集:WMT 翻译大赛是机器翻译领域最具权威的国际评测大赛,其中英德翻译任务提供了一个中等规模的数据集,这个数据集是较多论文中使用的数据集,也是 Transformer 论文中用到的一个数据集。我们也将[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)作为示例提供。运行 `gen_data.sh` 脚本进行 WMT'16 EN-DE 数据集的下载和预处理(时间较长,建议后台运行)。数据处理过程主要包括 Tokenize 和 [BPE 编码(byte-pair encoding)](https://arxiv.org/pdf/1508.07909)。运行成功后,将会生成文件夹 `gen_data`,其目录结构如下:
```text
.
├── wmt16_ende_data # WMT16 英德翻译数据
├── wmt16_ende_data_bpe # BPE 编码的 WMT16 英德翻译数据
├── mosesdecoder # Moses 机器翻译工具集,包含了 Tokenize、BLEU 评估等脚本
└── subword-nmt # BPE 编码的代码
```
另外我们也整理提供了一份处理好的 WMT'16 EN-DE 数据以供[下载](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz)使用,其中包含词典(`vocab_all.bpe.32000`文件)、训练所需的 BPE 数据(`train.tok.clean.bpe.32000.en-de`文件)、预测所需的 BPE 数据(`newstest2016.tok.bpe.32000.en-de`等文件)和相应的评估预测结果所需的 tokenize 数据(`newstest2016.tok.de`等文件)。
自定义数据:如果需要使用自定义数据,本项目程序中可直接支持的数据格式为制表符 \t 分隔的源语言和目标语言句子对,句子中的 token 之间使用空格分隔。提供以上格式的数据文件(可以分多个part,数据读取支持文件通配符)和相应的词典文件即可直接运行。
### 单机训练
### 单机单卡
以提供的英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096
```
以上命令中传入了训练轮数(`epoch`)和训练数据文件路径(注意请正确设置,支持通配符)等参数,更多参数的使用以及支持的模型超参数可以参见 `transformer.yaml` 配置文件,其中默认提供了 Transformer base model 的配置,如需调整可以在配置文件中更改或通过命令行传入(命令行传入内容将覆盖配置文件中的设置)。可以通过以下命令来训练 Transformer 论文中的 big model:
```sh
# setting visible devices for training
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
另外,如果在执行训练时若提供了 `save_model`(默认为 trained_models),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将保存当前训练的到相应目录(会保存分别记录了模型参数和优化器状态的 `transformer.pdparams``transformer.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出:
```txt
[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s
[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s
[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s
[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s
[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
```
也可以使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度较慢。
#### 单机多卡
Paddle动态图支持多进程多卡进行模型训练,启动训练的方式如下:
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--use_cuda True \
--save_step 10000
```
此时,程序会将每个进程的输出log导入到`./mylog`路径下,只有第一个工作进程会保存模型。
```
.
├── mylog
│   ├── workerlog.0
│   ├── workerlog.1
│   ├── workerlog.2
│   ├── workerlog.3
│   ├── workerlog.4
│   ├── workerlog.5
│   ├── workerlog.6
│   └── workerlog.7
```
### 模型推断
以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt
```
`predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params trained_params/step_100000 \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--n_head 16 \
--d_model 1024 \
--d_inner_hid 4096 \
--prepostprocess_dropout 0.3
```
### 模型评估
预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标):
```sh
# 还原 predict.txt 中的预测结果为 tokenize 后的数据
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# 若无 BLEU 评估工具,需先进行下载
# git clone https://github.com/moses-smt/mosesdecoder.git
# 以英德翻译 newstest2014 测试数据为例
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
```
可以看到类似如下的结果:
```
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
```
使用本项目中提供的内容,英德翻译 base model 和 big model 八卡训练 100K 个 iteration 后测试有大约如下的 BLEU 值:
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| Base | 26.35 | 29.07 | 33.30 |
| Big | 27.07 | 30.09 | 34.38 |
### 预训练模型
我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_dygraph.tar.gz)[big model](https://transformer-res.bj.bcebos.com/big_model_dygraph.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。
## 进阶使用
### 背景介绍
Transformer 是论文 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 中提出的用以完成机器翻译(machine translation, MT)等序列到序列(sequence to sequence, Seq2Seq)学习任务的一种全新网络结构,其完全使用注意力(Attention)机制来实现序列到序列的建模[1]。
相较于此前 Seq2Seq 模型中广泛使用的循环神经网络(Recurrent Neural Network, RNN),使用(Self)Attention 进行输入序列到输出序列的变换主要具有以下优势:
- 计算复杂度小
- 特征维度为 d 、长度为 n 的序列,在 RNN 中计算复杂度为 `O(n * d * d)` (n 个时间步,每个时间步计算 d 维的矩阵向量乘法),在 Self-Attention 中计算复杂度为 `O(n * n * d)` (n 个时间步两两计算 d 维的向量点积或其他相关度函数),n 通常要小于 d 。
- 计算并行度高
- RNN 中当前时间步的计算要依赖前一个时间步的计算结果;Self-Attention 中各时间步的计算只依赖输入不依赖之前时间步输出,各时间步可以完全并行。
- 容易学习长程依赖(long-range dependencies)
- RNN 中相距为 n 的两个位置间的关联需要 n 步才能建立;Self-Attention 中任何两个位置都直接相连;路径越短信号传播越容易。
Transformer 中引入使用的基于 Self-Attention 的序列建模模块结构,已被广泛应用在 Bert [2]等语义表示模型中,取得了显著效果。
### 模型概览
Transformer 同样使用了 Seq2Seq 模型中典型的编码器-解码器(Encoder-Decoder)的框架结构,整体网络结构如图1所示。
<p align="center">
<img src="images/transformer_network.png" height=400 hspace='10'/> <br />
图 1. Transformer 网络结构图
</p>
可以看到,和以往 Seq2Seq 模型不同,Transformer 的 Encoder 和 Decoder 中不再使用 RNN 的结构。
### 模型特点
Transformer 中的 Encoder 由若干相同的 layer 堆叠组成,每个 layer 主要由多头注意力(Multi-Head Attention)和全连接的前馈(Feed-Forward)网络这两个 sub-layer 构成。
- Multi-Head Attention 在这里用于实现 Self-Attention,相比于简单的 Attention 机制,其将输入进行多路线性变换后分别计算 Attention 的结果,并将所有结果拼接后再次进行线性变换作为输出。参见图2,其中 Attention 使用的是点积(Dot-Product),并在点积后进行了 scale 的处理以避免因点积结果过大进入 softmax 的饱和区域。
- Feed-Forward 网络会对序列中的每个位置进行相同的计算(Position-wise),其采用的是两次线性变换中间加以 ReLU 激活的结构。
此外,每个 sub-layer 后还施以 Residual Connection [3]和 Layer Normalization [4]来促进梯度传播和模型收敛。
<p align="center">
<img src="images/multi_head_attention.png" height=300 hspace='10'/> <br />
图 2. Multi-Head Attention
</p>
Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 layer ,在组成 Decoder 的 layer 中还多了一个 Multi-Head Attention 的 sub-layer 来实现对 Encoder 输出的 Attention,这个 Encoder-Decoder Attention 在其他 Seq2Seq 模型中也是存在的。
## FAQ
**Q:** 预测结果中样本数少于输入的样本数是什么原因
**A:** 若样本中最大长度超过 `transformer.yaml``max_length` 的默认设置,请注意运行时增大 `--max_length` 的设置,否则超长样本将被过滤。
**Q:** 预测时最大长度超过了训练时的最大长度怎么办
**A:** 由于训练时 `max_length` 的设置决定了保存模型 position encoding 的大小,若预测时长度超过 `max_length`,请调大该值,会重新生成更大的 position encoding 表。
## 参考文献
1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
2. Devlin J, Chang M W, Lee K, et al. [Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805)[J]. arXiv preprint arXiv:1810.04805, 2018.
3. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
4. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
5. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
## 作者
- [guochengCS](https://github.com/guoshengCS)
## 如何贡献代码
如果你可以修复某个issue或者增加一个新功能,欢迎给我们提交PR。如果对应的PR被接受了,我们将根据贡献的质量和难度进行打分(0-5分,越高越好)。如果你累计获得了10分,可以联系我们获得面试机会或者为你写推荐信。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from paddle.fluid.layers.utils import flatten
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from model import Input, set_device
from reader import prepare_infer_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import InferTransformer, position_encoding_init
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
def do_predict(args):
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
# define data
dataset = Seq2SeqDataset(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = Seq2SeqDataset.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.batch_size,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None
if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
collate_fn=partial(
prepare_infer_input, src_pad_idx=args.eos_idx, n_head=args.n_head),
num_workers=0,
return_list=True)
# define model
transformer = InferTransformer(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
transformer.load(os.path.join(args.init_from_params, "transformer"))
# TODO: use model.predict when support variant length
f = open(args.output_file, "wb")
for data in data_loader():
finished_seq = transformer.test(inputs=flatten(data))[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_predict(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
import tarfile
import itertools
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Put all padded data needed by training into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
return data_inputs
def prepare_infer_input(insts, src_pad_idx, n_head):
"""
Put all padded data needed by beam search decoder into a list.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
return data_inputs
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones(
(inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + [self._end]
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, parallel_sentence):
return [
self._converters[i](parallel_sentence[i])
for i in range(len(self._converters))
]
class SentenceBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self._batch_size = batch_size
def append(self, info):
self.batch.append(info)
if len(self.batch) == self._batch_size:
tmp = self.batch
self.batch = []
return tmp
class TokenBatchCreator(object):
def __init__(self, batch_size):
self.batch = []
self.max_len = -1
self._batch_size = batch_size
def append(self, info):
cur_len = info.max_len
max_len = max(self.max_len, cur_len)
if max_len * (len(self.batch) + 1) > self._batch_size:
result = self.batch
self.batch = [info]
self.max_len = cur_len
return result
else:
self.max_len = max_len
self.batch.append(info)
class SampleInfo(object):
def __init__(self, i, max_len, min_len):
self.i = i
self.min_len = min_len
self.max_len = max_len
class MinMaxFilter(object):
def __init__(self, max_len, min_len, underlying_creator):
self._min_len = min_len
self._max_len = max_len
self._creator = underlying_creator
def append(self, info):
if info.max_len > self._max_len or info.min_len < self._min_len:
return
else:
return self._creator.append(info)
@property
def batch(self):
return self._creator.batch
class Seq2SeqDataset(Dataset):
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
fpattern,
tar_fname=None,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
only_src=False):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter.encode("utf8")
token_delimiter = token_delimiter.encode("utf8")
start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname)
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else []
self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
src_trg_ids = converters(line)
self._src_seq_ids.append(src_trg_ids[0])
lens = [len(src_trg_ids[0])]
if not self._only_src:
self._trg_seq_ids.append(src_trg_ids[1])
lens.append(len(src_trg_ids[1]))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname):
fpaths = glob.glob(fpattern)
assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip(b"\n")] = idx
return word_dict
def get_vocab_summary(self):
return len(self._src_vocab), len(
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self):
return len(self._sample_infos)
class Seq2SeqBatchSampler(BatchSampler):
def __init__(self,
dataset,
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
min_length=0,
max_length=100,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
seed=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
def __iter__(self):
# global sort or global shuffle
if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos,
key=lambda x: x.max_len)
else:
if self._shuffle:
infos = self._dataset._sample_infos
self._random.shuffle(infos)
else:
infos = self._dataset._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
batches = []
batch_creator = TokenBatchCreator(
self._batch_size
) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator)
for info in infos:
batch = batch_creator.append(info)
if batch is not None:
batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append(batch_creator.batch)
if self._shuffle_batch:
self._random.shuffle(batches)
if not self._use_token_batch:
# when producing batches according to sequence number, to confirm
# neighbor batches which would be feed and run parallel have similar
# length (thus similar computational cost) after shuffle, we as take
# them as a whole when shuffling and split here
batches = [[
batch[self._batch_size * i:self._batch_size * (i + 1)]
for i in range(self._nranks)
] for batch in batches]
batches = list(itertools.chain.from_iterable(batches))
# for multi-device
for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank:
batch_indices = [info.i for info in batch]
yield batch_indices
if self._local_rank > len(batches) % self._nranks:
yield batch_indices
def __len__(self):
return 100
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from model import Input, set_device
from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
class LoggerCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
super(LoggerCallback, self).__init__(log_freq, verbose)
# TODO: wrap these override function to simplify
self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None):
super(LoggerCallback, self).on_train_begin(logs)
self.train_metrics += ["normalized loss", "ppl"]
def on_train_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None):
super(LoggerCallback, self).on_eval_begin(logs)
self.eval_metrics += ["normalized loss", "ppl"]
def on_eval_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_eval_batch_end(step, logs)
def do_train(args):
device = set_device("gpu" if args.use_cuda else "cpu")
fluid.enable_dygraph(device) if args.eager_run else None
# set seed for CE
random_seed = eval(str(args.random_seed))
if random_seed is not None:
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
# define inputs
inputs = [
Input(
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, None], "int64", name="trg_word"),
Input(
[None, None], "int64", name="trg_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
labels = [
Input(
[None, 1], "int64", name="label"),
Input(
[None, 1], "float32", name="weight"),
]
# def dataloader
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None if fluid.in_dygraph_mode() else
[x.forward() for x in inputs + labels],
collate_fn=partial(
prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.bos_idx, args.eos_idx)
transformer.prepare(
fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(args.d_model,
args.warmup_steps),
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
transformer.load(
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train
transformer.fit(train_data=train_loader,
eval_data=eval_loader,
epochs=1,
eval_freq=1,
save_freq=1,
verbose=2,
callbacks=[
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_train(args)
此差异已折叠。
# used for continuous evaluation
enable_ce: False
eager_run: False
# The frequency to save trained models when training.
save_step: 10000
# The frequency to fetch and print output when training.
print_step: 100
# path of the checkpoint, to resume the previous training
init_from_checkpoint: ""
# path of the pretrain model, to better solve the current task
init_from_pretrain_model: ""
# path of trained parameter, to make prediction
init_from_params: "trained_params/step_100000/"
# the directory for saving model
save_model: "trained_models"
# the directory for saving inference model.
inference_model_dir: "infer_model"
# Set seed for CE or debug
random_seed: None
# The pattern to match training data files.
training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match validation data files.
validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
# The pattern to match test data files.
predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file: "predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token: ["<s>", "<e>", "<unk>"]
# whether to use cuda
use_cuda: True
# args for reader, see reader.py for details
token_delimiter: " "
use_token_batch: True
pool_size: 200000
sort_type: "pool"
shuffle: True
shuffle_batch: True
batch_size: 4096
# Hyparams for training:
# the number of epoches for training
epoch: 30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate: 2.0
beta1: 0.9
beta2: 0.997
eps: 1e-9
# the parameters for learning rate scheduling.
warmup_steps: 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps: 0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size: 5
max_out_len: 256
# the number of decoded sentences to output.
n_best: 1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size: 10000
# size of target word dictionay
trg_vocab_size: 10000
# index for <bos> token
bos_idx: 0
# index for <eos> token
eos_idx: 1
# index for <unk> token
unk_idx: 2
# max length of sequences deciding the size of position encoding table.
max_length: 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model: 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid: 2048
# the dimension that keys are projected to for dot-product attention.
d_key: 64
# the dimension that values are projected to for dot-product attention.
d_value: 64
# number of head used in multi-head attention.
n_head: 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer: 6
# dropout rates of different modules.
prepostprocess_dropout: 0.1
attention_dropout: 0.1
relu_dropout: 0.1
# to process before each sub-layer
preprocess_cmd: "n" # layer normalization
# to process after each sub-layer
postprocess_cmd: "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing: True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册