diff --git a/bmn/BMN.png b/bmn/BMN.png
new file mode 100644
index 0000000000000000000000000000000000000000..46b9743cf5b5e79ada93ebb6b61da96967f350a4
Binary files /dev/null and b/bmn/BMN.png differ
diff --git a/bmn/README.md b/bmn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dbf593512c07da9645418b92b2a6819c9be13fa5
--- /dev/null
+++ b/bmn/README.md
@@ -0,0 +1,120 @@
+# BMN Video Action Localization Model: High-Level API Implementation
+
+---
+## Contents
+
+- [Model Introduction](#model-introduction)
+- [Code Structure](#code-structure)
+- [Data Preparation](#data-preparation)
+- [Model Training](#model-training)
+- [Model Evaluation](#model-evaluation)
+- [Model Inference](#model-inference)
+- [Reference Paper](#reference-paper)
+
+
+## Model Introduction
+
+BMN is a model developed by Baidu and the winning solution of the 2019 ActivityNet challenge. It provides an efficient way to generate proposals for temporal action localization and is open-sourced on PaddlePaddle for the first time. The model introduces a Boundary-Matching (BM) mechanism to evaluate proposal confidence: all possible proposals are arranged, by start position and duration, into a two-dimensional BM confidence map, where the value at each point is the confidence score of the corresponding proposal. The network consists of three modules: the base module serves as the backbone and processes the input feature sequence, the TEM module predicts the probability that each temporal position is an action start or end, and the PEM module generates the BM confidence map.
+
+
+
+BMN Overview
+
+
+
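+Under the default configuration in bmn.yaml (tscale = dscale = 100), every cell of the BM confidence map indexes one candidate proposal by its (duration, start) position. The following minimal sketch is illustrative only; the variable names are placeholders and not part of the repository code:
+
+```
+import numpy as np
+
+tscale, dscale = 100, 100                # temporal scale and maximum duration
+bm_map = np.random.rand(dscale, tscale)  # stand-in for the predicted BM confidence map
+
+start_idx, duration_idx = 20, 15         # proposal with start index 20 and duration index 15
+xmin = 1.0 / tscale * start_idx                       # normalized start time
+xmax = 1.0 / tscale * (start_idx + duration_idx + 1)  # normalized end time
+confidence = bm_map[duration_idx, start_idx]          # confidence score of this proposal
+```
+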
+## Code Structure
+```
+├── bmn.yaml          # network configuration file for quick parameter setup
+├── run.sh            # quick-start script that launches multi-GPU training
+├── train.py          # training script
+├── eval.py           # evaluation script for measuring model performance
+├── predict.py        # inference script for arbitrary inputs
+├── bmn_model.py      # network architecture and loss function definitions
+├── bmn_metric.py     # accuracy metric definition
+├── reader.py         # data reader that builds the Dataset and DataLoader
+├── bmn_utils.py      # model utility functions
+├── config_utils.py   # configuration utilities
+├── eval_anet_prop.py # computes the accuracy evaluation metrics
+└── infer.list        # inference file list
+```
+
+
+## Data Preparation
+
+BMN is trained on the ActivityNet 1.3 dataset. We provide pre-extracted video features: please download [bmn\_feat](https://paddlemodels.bj.bcebos.com/video_detection/bmn_feat.tar.gz), extract it, and set the feature path feat\_path in bmn.yaml accordingly. Also download the label file [label](https://paddlemodels.bj.bcebos.com/video_detection/activitynet_1.3_annotations.json) and set the label file path anno\_file in bmn.yaml.
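+
+For example, the two entries in bmn.yaml would look as follows (these are the default paths shipped in this configuration; adjust them to wherever you placed the downloaded data):
+
+```
+MODEL:
+  anno_file: "./activitynet_1.3_annotations.json"
+  feat_path: "./fix_feat_100"
+```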
+
+
+## Model Training
+
+After the data is prepared, training can be started in either of the following two ways:
+
+By default training uses 4 GPUs and is launched as follows:
+
+ bash run.sh
+
+For single-GPU training, launch with:
+
+ export CUDA_VISIBLE_DEVICES=0
+ python train.py
+
+- Running the code requires pandas to be installed first.
+
+- Training starts from scratch with the command line or script above; no pretrained model is needed.
+
+- For single-GPU training, please set batch_size in the configuration file to 16.
+
+**Training strategy:**
+
+* Adam optimizer with an initial learning\_rate of 0.001
+* L2 weight decay coefficient of 1e-4
+* The learning rate is decayed once, by a factor of 0.1, after 4200 iterations (see the sketch after this list)
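+
+This schedule matches the optimizer function defined in train.py; a minimal illustrative sketch (`model` stands for the constructed BMN instance):
+
+```
+import paddle.fluid as fluid
+
+boundaries = [4200]             # decay once after 4200 iterations
+values = [0.001, 0.001 * 0.1]   # initial and decayed learning rates
+optim = fluid.optimizer.Adam(
+    fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
+    parameter_list=model.parameters(),
+    regularization=fluid.regularizer.L2DecayRegularizer(
+        regularization_coeff=1e-4))
+```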
+
+
+## Model Evaluation
+
+After training, the model can be evaluated as follows:
+
+ python eval.py --weights=$PATH_TO_WEIGHTS
+
+- The `weights` argument on the command line specifies the weights to evaluate; if it is not set, the default parameter file checkpoint/final.pdparams is used.
+
+- The program saves per-video results under the output/EVAL/BMN\_results folder and the aggregated test results in evaluate\_results/bmn\_results\_validation.json.
+
+- Note: the loss may become nan during evaluation. Evaluation runs on a single sample at a time, and some samples contain no proposal with IoU > 0.6, which yields nan; this does not affect the final evaluation results.
+
+
+AR@AN and AUC can then be computed with the official ActivityNet evaluation scripts as follows:
+
+- Detailed usage instructions for the ActivityNet dataset are available on its [official website](http://activity-net.org).
+
+- Download the evaluation code from the [ActivityNet GitHub repository](https://github.com/activitynet/ActivityNet.git) and copy the Evaluation folder into the models/dygraph/bmn directory. (Note: the third-party evaluation code does not support Python 3, so Python 2 is recommended for this step; if you use Python 3, add parentheses to the print statements in the .py files under the Evaluation directory.)
+
+- Download the [activity\_net\_1\_3\_new.json](https://paddlemodels.bj.bcebos.com/video_detection/activity_net_1_3_new.json) file and place it in the models/dygraph/bmn/Evaluation/data directory. Compared with the original activity\_net.v1-3.min.json, it filters out some invalid video entries.
+
+- Compute the accuracy metrics:
+
+ ```python eval_anet_prop.py```
+
+
+Evaluation accuracy on the ActivityNet 1.3 dataset:
+
+| AR@1 | AR@5 | AR@10 | AR@100 | AUC |
+| :---: | :---: | :---: | :---: | :---: |
+| 33.46 | 49.25 | 56.25 | 75.40 | 67.16% |
+
+
+## Model Inference
+
+Model inference can be launched as follows:
+
+ python predict.py --weights=$PATH_TO_WEIGHTS \
+ --filelist=$FILELIST
+
+- The `--filelist` argument specifies the list of files to run inference on; if not set, it defaults to ./infer.list. The `--weights` argument points to the trained weights; if not set, the default parameter file checkpoint/final.pdparams is used.
+
+- The program saves per-video results under the output/INFER/BMN\_results folder and the aggregated results in predict\_results/bmn\_results\_test.json.
+
+
+## Reference Paper
+
+- [BMN: Boundary-Matching Network for Temporal Action Proposal Generation](https://arxiv.org/abs/1907.09702), Tianwei Lin, Xiao Liu, Xin Li, Errui Ding, Shilei Wen.
diff --git a/bmn/bmn.yaml b/bmn/bmn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..da50ea4f7c654d40fbf2498863cb7e87664fe55a
--- /dev/null
+++ b/bmn/bmn.yaml
@@ -0,0 +1,50 @@
+MODEL:
+ name: "BMN"
+ tscale: 100
+ dscale: 100
+ feat_dim: 400
+ prop_boundary_ratio: 0.5
+ num_sample: 32
+ num_sample_perbin: 3
+ anno_file: "./activitynet_1.3_annotations.json"
+ feat_path: './fix_feat_100'
+
+TRAIN:
+ subset: "train"
+ epoch: 9
+ batch_size: 4
+ num_workers: 4
+ use_shuffle: True
+ device: "gpu"
+ num_gpus: 4
+ learning_rate: 0.001
+ learning_rate_decay: 0.1
+ lr_decay_iter: 4200
+ l2_weight_decay: 1e-4
+
+VALID:
+ subset: "validation"
+
+TEST:
+ subset: "validation"
+ batch_size: 1
+ num_workers: 1
+ use_buffer: False
+ snms_alpha: 0.001
+ snms_t1: 0.5
+ snms_t2: 0.9
+ output_path: "output/EVAL/BMN_results"
+ result_path: "evaluate_results"
+
+INFER:
+ subset: "test"
+ batch_size: 1
+ num_workers: 1
+ use_buffer: False
+ snms_alpha: 0.4
+ snms_t1: 0.5
+ snms_t2: 0.9
+ filelist: './infer.list'
+ output_path: "output/INFER/BMN_results"
+ result_path: "predict_results"
+
diff --git a/bmn/bmn_metric.py b/bmn/bmn_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..a19f87c6b42b8737ddeb52c3a330f59dcc932004
--- /dev/null
+++ b/bmn/bmn_metric.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import numpy as np
+import pandas as pd
+import os
+import sys
+import json
+
+sys.path.append('../')
+
+from metrics import Metric
+from bmn_utils import boundary_choose, bmn_post_processing
+
+
+class BmnMetric(Metric):
+ """
+ only support update with batch_size=1
+ """
+
+ def __init__(self, cfg, mode):
+ super(BmnMetric, self).__init__()
+ self.cfg = cfg
+ self.mode = mode
+ #get video_dict and video_list
+ if self.mode == 'test':
+ self.get_test_dataset_dict()
+ elif self.mode == 'infer':
+ self.get_infer_dataset_dict()
+
+ def add_metric_op(self, preds, label):
+ pred_bm, pred_start, pred_en = preds
+ video_index = label[-1]
+ return [pred_bm, pred_start, pred_en, video_index] #return list
+
+ def update(self, pred_bm, pred_start, pred_end, fid):
+ # generate proposals
+ pred_start = pred_start[0]
+ pred_end = pred_end[0]
+ fid = fid[0]
+
+ if self.mode == 'infer':
+ output_path = self.cfg.INFER.output_path
+ else:
+ output_path = self.cfg.TEST.output_path
+ tscale = self.cfg.MODEL.tscale
+ dscale = self.cfg.MODEL.dscale
+ snippet_xmins = [1.0 / tscale * i for i in range(tscale)]
+ snippet_xmaxs = [1.0 / tscale * i for i in range(1, tscale + 1)]
+ cols = ["xmin", "xmax", "score"]
+
+ video_name = self.video_list[fid]
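+        # fuse the two channels of the BM map (regression and classification
+        # branches) into a single confidence map by taking their product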
+ pred_bm = pred_bm[0, 0, :, :] * pred_bm[0, 1, :, :]
+ start_mask = boundary_choose(pred_start)
+ start_mask[0] = 1.
+ end_mask = boundary_choose(pred_end)
+ end_mask[-1] = 1.
+ score_vector_list = []
+ for idx in range(dscale):
+ for jdx in range(tscale):
+ start_index = jdx
+ end_index = start_index + idx
+ if end_index < tscale and start_mask[
+ start_index] == 1 and end_mask[end_index] == 1:
+ xmin = snippet_xmins[start_index]
+ xmax = snippet_xmaxs[end_index]
+ xmin_score = pred_start[start_index]
+ xmax_score = pred_end[end_index]
+ bm_score = pred_bm[idx, jdx]
+ conf_score = xmin_score * xmax_score * bm_score
+ score_vector_list.append([xmin, xmax, conf_score])
+
+ score_vector_list = np.stack(score_vector_list)
+ video_df = pd.DataFrame(score_vector_list, columns=cols)
+ video_df.to_csv(
+ os.path.join(output_path, "%s.csv" % video_name), index=False)
+
+ return 0 # result has saved in output path
+
+ def accumulate(self):
+ return 'post_processing is required...' # required method
+
+ def reset(self):
+ print("Post_processing....This may take a while")
+ if self.mode == 'test':
+ bmn_post_processing(self.video_dict, self.cfg.TEST.subset,
+ self.cfg.TEST.output_path,
+ self.cfg.TEST.result_path)
+ elif self.mode == 'infer':
+ bmn_post_processing(self.video_dict, self.cfg.INFER.subset,
+ self.cfg.INFER.output_path,
+ self.cfg.INFER.result_path)
+
+ def name(self):
+ return 'bmn_metric'
+
+ def get_test_dataset_dict(self):
+ anno_file = self.cfg.MODEL.anno_file
+ annos = json.load(open(anno_file))
+ subset = self.cfg.TEST.subset
+ self.video_dict = {}
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if subset in video_subset:
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
+
+ def get_infer_dataset_dict(self):
+ file_list = self.cfg.INFER.filelist
+ annos = json.load(open(file_list))
+ self.video_dict = {}
+ for video_name in annos.keys():
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
diff --git a/bmn/bmn_model.py b/bmn/bmn_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfde7bcc5cdaac8aa5ea3c069f580308b49ec01f
--- /dev/null
+++ b/bmn/bmn_model.py
@@ -0,0 +1,364 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+import numpy as np
+import math
+
+from bmn_utils import get_interp1d_mask
+from model import Model, Loss
+
+DATATYPE = 'float32'
+
+
+# Net
+class Conv1D(fluid.dygraph.Layer):
+ def __init__(self,
+ prefix,
+ num_channels=256,
+ num_filters=256,
+ size_k=3,
+ padding=1,
+ groups=1,
+ act="relu"):
+ super(Conv1D, self).__init__()
+ fan_in = num_channels * size_k * 1
+ k = 1. / math.sqrt(fan_in)
+ param_attr = ParamAttr(
+ name=prefix + "_w",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+ bias_attr = ParamAttr(
+ name=prefix + "_b",
+ initializer=fluid.initializer.Uniform(
+ low=-k, high=k))
+
+ self._conv2d = fluid.dygraph.Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=(1, size_k),
+ stride=1,
+ padding=(0, padding),
+ groups=groups,
+ act=act,
+ param_attr=param_attr,
+ bias_attr=bias_attr)
+
+ def forward(self, x):
+ x = fluid.layers.unsqueeze(input=x, axes=[2])
+ x = self._conv2d(x)
+ x = fluid.layers.squeeze(input=x, axes=[2])
+ return x
+
+
+class BMN(Model):
+ def __init__(self, cfg, is_dygraph=True):
+ super(BMN, self).__init__()
+
+ #init config
+ self.tscale = cfg.MODEL.tscale
+ self.dscale = cfg.MODEL.dscale
+ self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
+ self.num_sample = cfg.MODEL.num_sample
+ self.num_sample_perbin = cfg.MODEL.num_sample_perbin
+ self.is_dygraph = is_dygraph
+
+ self.hidden_dim_1d = 256
+ self.hidden_dim_2d = 128
+ self.hidden_dim_3d = 512
+
+ # Base Module
+ self.b_conv1 = Conv1D(
+ prefix="Base_1",
+ num_channels=400,
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.b_conv2 = Conv1D(
+ prefix="Base_2",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+
+ # Temporal Evaluation Module
+ self.ts_conv1 = Conv1D(
+ prefix="TEM_s1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.ts_conv2 = Conv1D(
+ prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+ self.te_conv1 = Conv1D(
+ prefix="TEM_e1",
+ num_filters=self.hidden_dim_1d,
+ size_k=3,
+ padding=1,
+ groups=4,
+ act="relu")
+ self.te_conv2 = Conv1D(
+ prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid")
+
+ #Proposal Evaluation Module
+ self.p_conv1 = Conv1D(
+ prefix="PEM_1d",
+ num_filters=self.hidden_dim_2d,
+ size_k=3,
+ padding=1,
+ act="relu")
+
+ # init to speed up
+ sample_mask_array = get_interp1d_mask(
+ self.tscale, self.dscale, self.prop_boundary_ratio,
+ self.num_sample, self.num_sample_perbin)
+ if self.is_dygraph:
+ self.sample_mask = fluid.dygraph.base.to_variable(
+ sample_mask_array)
+ else: # static
+ self.sample_mask = fluid.layers.create_parameter(
+ shape=[
+ self.tscale, self.num_sample * self.dscale * self.tscale
+ ],
+ dtype=DATATYPE,
+ attr=fluid.ParamAttr(
+ name="sample_mask", trainable=False),
+ default_initializer=fluid.initializer.NumpyArrayInitializer(
+ sample_mask_array))
+
+ self.sample_mask.stop_gradient = True
+
+ self.p_conv3d1 = fluid.dygraph.Conv3D(
+ num_channels=128,
+ num_filters=self.hidden_dim_3d,
+ filter_size=(self.num_sample, 1, 1),
+ stride=(self.num_sample, 1, 1),
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_3d1_w"),
+ bias_attr=ParamAttr(name="PEM_3d1_b"))
+
+ self.p_conv2d1 = fluid.dygraph.Conv2D(
+ num_channels=512,
+ num_filters=self.hidden_dim_2d,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d1_w"),
+ bias_attr=ParamAttr(name="PEM_2d1_b"))
+ self.p_conv2d2 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d2_w"),
+ bias_attr=ParamAttr(name="PEM_2d2_b"))
+ self.p_conv2d3 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=self.hidden_dim_2d,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act="relu",
+ param_attr=ParamAttr(name="PEM_2d3_w"),
+ bias_attr=ParamAttr(name="PEM_2d3_b"))
+ self.p_conv2d4 = fluid.dygraph.Conv2D(
+ num_channels=128,
+ num_filters=2,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ act="sigmoid",
+ param_attr=ParamAttr(name="PEM_2d4_w"),
+ bias_attr=ParamAttr(name="PEM_2d4_b"))
+
+ def forward(self, x):
+ #Base Module
+ x = self.b_conv1(x)
+ x = self.b_conv2(x)
+
+ #TEM
+ xs = self.ts_conv1(x)
+ xs = self.ts_conv2(xs)
+ xs = fluid.layers.squeeze(xs, axes=[1])
+ xe = self.te_conv1(x)
+ xe = self.te_conv2(xe)
+ xe = fluid.layers.squeeze(xe, axes=[1])
+
+ #PEM
+ xp = self.p_conv1(x)
+ #BM layer
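+        # multiply the (N, C, T) features by the precomputed sampling mask so that
+        # num_sample interpolated points are gathered for every (duration, start) pair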
+ xp = fluid.layers.matmul(xp, self.sample_mask)
+ xp = fluid.layers.reshape(
+ xp, shape=[0, 0, -1, self.dscale, self.tscale])
+
+ xp = self.p_conv3d1(xp)
+ xp = fluid.layers.squeeze(xp, axes=[2])
+ xp = self.p_conv2d1(xp)
+ xp = self.p_conv2d2(xp)
+ xp = self.p_conv2d3(xp)
+ xp = self.p_conv2d4(xp)
+ return xp, xs, xe
+
+
+class BmnLoss(Loss):
+ def __init__(self, cfg):
+ super(BmnLoss, self).__init__()
+ self.cfg = cfg
+
+ def _get_mask(self):
+ dscale = self.cfg.MODEL.dscale
+ tscale = self.cfg.MODEL.tscale
+ bm_mask = []
+ for idx in range(dscale):
+ mask_vector = [1 for i in range(tscale - idx)
+ ] + [0 for i in range(idx)]
+ bm_mask.append(mask_vector)
+ bm_mask = np.array(bm_mask, dtype=np.float32)
+ self_bm_mask = fluid.layers.create_global_var(
+ shape=[dscale, tscale], value=0, dtype=DATATYPE, persistable=True)
+ fluid.layers.assign(bm_mask, self_bm_mask)
+ self_bm_mask.stop_gradient = True
+ return self_bm_mask
+
+ def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end):
+ def bi_loss(pred_score, gt_label):
+ pred_score = fluid.layers.reshape(
+ x=pred_score, shape=[-1], inplace=False)
+ gt_label = fluid.layers.reshape(
+ x=gt_label, shape=[-1], inplace=False)
+ gt_label.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE)
+ num_entries = fluid.layers.cast(
+ fluid.layers.shape(pmask), dtype=DATATYPE)
+ num_positive = fluid.layers.cast(
+ fluid.layers.reduce_sum(pmask), dtype=DATATYPE)
+ ratio = num_entries / num_positive
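+            # re-balance the rare positive positions against the negatives so that
+            # each side contributes half of the binary cross-entropy loss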
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask))
+ loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg)
+ loss = -1 * (loss_pos + loss_neg)
+ return loss
+
+ loss_start = bi_loss(pred_start, gt_start)
+ loss_end = bi_loss(pred_end, gt_end)
+ loss = loss_start + loss_end
+ return loss
+
+ def pem_reg_loss_func(self, pred_score, gt_iou_map, mask):
+
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+
+ u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE)
+ u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3)
+ u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.)
+ u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE)
+ u_lmask = fluid.layers.elementwise_mul(u_lmask, mask)
+
+ num_h = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE)
+ num_m = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE)
+ num_l = fluid.layers.cast(
+ fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE)
+
+ r_m = num_h / num_m
+ u_smmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask)
+ u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE)
+
+ r_l = num_h / num_l
+ u_slmask = fluid.layers.uniform_random(
+ shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]],
+ dtype=DATATYPE,
+ min=0.0,
+ max=1.0)
+ u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask)
+ u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE)
+
+ weights = u_hmask + u_smmask + u_slmask
+ weights.stop_gradient = True
+ loss = fluid.layers.square_error_cost(pred_score, gt_iou_map)
+ loss = fluid.layers.elementwise_mul(loss, weights)
+ loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum(
+ weights)
+
+ return loss
+
+ def pem_cls_loss_func(self, pred_score, gt_iou_map, mask):
+ gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask)
+ gt_iou_map.stop_gradient = True
+ pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE)
+ nmask = fluid.layers.elementwise_mul(nmask, mask)
+
+ num_positive = fluid.layers.reduce_sum(pmask)
+ num_entries = num_positive + fluid.layers.reduce_sum(nmask)
+ ratio = num_entries / num_positive
+ coef_0 = 0.5 * ratio / (ratio - 1)
+ coef_1 = 0.5 * ratio
+ epsilon = 0.000001
+ loss_pos = fluid.layers.elementwise_mul(
+ fluid.layers.log(pred_score + epsilon), pmask)
+ loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos)
+ loss_neg = fluid.layers.elementwise_mul(
+ fluid.layers.log(1.0 - pred_score + epsilon), nmask)
+ loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg)
+ loss = -1 * (loss_pos + loss_neg) / num_entries
+ return loss
+
+ def forward(self, outputs, labels):
+ pred_bm, pred_start, pred_end = outputs
+ if len(labels) == 3:
+ gt_iou_map, gt_start, gt_end = labels
+ elif len(labels) == 4: # video_index used in eval mode
+ gt_iou_map, gt_start, gt_end, video_index = labels
+ pred_bm_reg = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[0], ends=[1]),
+ axes=[1])
+ pred_bm_cls = fluid.layers.squeeze(
+ fluid.layers.slice(
+ pred_bm, axes=[1], starts=[1], ends=[2]),
+ axes=[1])
+
+ bm_mask = self._get_mask()
+
+ pem_reg_loss = self.pem_reg_loss_func(pred_bm_reg, gt_iou_map, bm_mask)
+ pem_cls_loss = self.pem_cls_loss_func(pred_bm_cls, gt_iou_map, bm_mask)
+
+ tem_loss = self.tem_loss_func(pred_start, pred_end, gt_start, gt_end)
+
+ loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
+ return loss
diff --git a/bmn/bmn_utils.py b/bmn/bmn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06812e636fdaf6ccc419ca58151402ab50082112
--- /dev/null
+++ b/bmn/bmn_utils.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import numpy as np
+import pandas as pd
+import multiprocessing as mp
+import json
+import os
+import math
+
+
+def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute jaccard score between a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ union_len = len_anchors - inter_len + box_max - box_min
+ jaccard = np.divide(inter_len, union_len)
+ return jaccard
+
+
+def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
+ """Compute intersection between score a box and the anchors.
+ """
+ len_anchors = anchors_max - anchors_min
+ int_xmin = np.maximum(anchors_min, box_min)
+ int_xmax = np.minimum(anchors_max, box_max)
+ inter_len = np.maximum(int_xmax - int_xmin, 0.)
+ scores = np.divide(inter_len, len_anchors)
+ return scores
+
+
+def boundary_choose(score_list):
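+    # keep a position as a candidate boundary if its score exceeds half of the
+    # maximum score, or if it is a local peak (larger than both neighbours)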
+ max_score = max(score_list)
+ mask_high = (score_list > max_score * 0.5)
+ score_list = list(score_list)
+ score_middle = np.array([0.0] + score_list + [0.0])
+ score_front = np.array([0.0, 0.0] + score_list)
+ score_back = np.array(score_list + [0.0, 0.0])
+ mask_peak = ((score_middle > score_front) & (score_middle > score_back))
+ mask_peak = mask_peak[1:-1]
+ mask = (mask_high | mask_peak).astype('float32')
+ return mask
+
+
+def soft_nms(df, alpha, t1, t2):
+ '''
+ df: proposals generated by network;
+ alpha: alpha value of Gaussian decaying function;
+ t1, t2: threshold for soft nms.
+ '''
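+    # greedily keep the highest-scoring proposal and decay the scores of the
+    # remaining proposals with a Gaussian of their IoU; the effective overlap
+    # threshold t1 + (t2 - t1) * width grows with the kept proposal's width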
+ df = df.sort_values(by="score", ascending=False)
+ tstart = list(df.xmin.values[:])
+ tend = list(df.xmax.values[:])
+ tscore = list(df.score.values[:])
+
+ rstart = []
+ rend = []
+ rscore = []
+
+ while len(tscore) > 1 and len(rscore) < 101:
+ max_index = tscore.index(max(tscore))
+ tmp_iou_list = iou_with_anchors(
+ np.array(tstart),
+ np.array(tend), tstart[max_index], tend[max_index])
+ for idx in range(0, len(tscore)):
+ if idx != max_index:
+ tmp_iou = tmp_iou_list[idx]
+ tmp_width = tend[max_index] - tstart[max_index]
+ if tmp_iou > t1 + (t2 - t1) * tmp_width:
+ tscore[idx] = tscore[idx] * np.exp(-np.square(tmp_iou) /
+ alpha)
+
+ rstart.append(tstart[max_index])
+ rend.append(tend[max_index])
+ rscore.append(tscore[max_index])
+ tstart.pop(max_index)
+ tend.pop(max_index)
+ tscore.pop(max_index)
+
+ newDf = pd.DataFrame()
+ newDf['score'] = rscore
+ newDf['xmin'] = rstart
+ newDf['xmax'] = rend
+ return newDf
+
+
+def video_process(video_list,
+ video_dict,
+ output_path,
+ result_dict,
+ snms_alpha=0.4,
+ snms_t1=0.55,
+ snms_t2=0.9):
+
+ for video_name in video_list:
+ print("Processing video........" + video_name)
+ df = pd.read_csv(os.path.join(output_path, video_name + ".csv"))
+ if len(df) > 1:
+ df = soft_nms(df, snms_alpha, snms_t1, snms_t2)
+
+ video_duration = video_dict[video_name]["duration_second"]
+ proposal_list = []
+ for idx in range(min(100, len(df))):
+ tmp_prop={"score":df.score.values[idx], \
+ "segment":[max(0,df.xmin.values[idx])*video_duration, \
+ min(1,df.xmax.values[idx])*video_duration]}
+ proposal_list.append(tmp_prop)
+ result_dict[video_name[2:]] = proposal_list
+
+
+def bmn_post_processing(video_dict, subset, output_path, result_path):
+    video_list = list(video_dict.keys())
+ global result_dict
+ result_dict = mp.Manager().dict()
+ pp_num = 12
+
+ num_videos = len(video_list)
+ num_videos_per_thread = int(num_videos / pp_num)
+ processes = []
+ for tid in range(pp_num - 1):
+ tmp_video_list = video_list[tid * num_videos_per_thread:(tid + 1) *
+ num_videos_per_thread]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ tmp_video_list = video_list[(pp_num - 1) * num_videos_per_thread:]
+ p = mp.Process(
+ target=video_process,
+ args=(tmp_video_list, video_dict, output_path, result_dict))
+ p.start()
+ processes.append(p)
+ for p in processes:
+ p.join()
+
+ result_dict = dict(result_dict)
+ output_dict = {
+ "version": "VERSION 1.3",
+ "results": result_dict,
+ "external_data": {}
+ }
+ outfile = open(
+ os.path.join(result_path, "bmn_results_%s.json" % subset), "w")
+
+ json.dump(output_dict, outfile)
+ outfile.close()
+
+
+def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for a boundary-matching pair """
+ plen = float(seg_xmax - seg_xmin)
+ plen_sample = plen / (num_sample * num_sample_perbin - 1.0)
+ total_samples = [
+ seg_xmin + plen_sample * ii
+ for ii in range(num_sample * num_sample_perbin)
+ ]
+ p_mask = []
+ for idx in range(num_sample):
+ bin_samples = total_samples[idx * num_sample_perbin:(idx + 1) *
+ num_sample_perbin]
+ bin_vector = np.zeros([tscale])
+ for sample in bin_samples:
+ sample_upper = math.ceil(sample)
+ sample_decimal, sample_down = math.modf(sample)
+ if int(sample_down) <= (tscale - 1) and int(sample_down) >= 0:
+ bin_vector[int(sample_down)] += 1 - sample_decimal
+ if int(sample_upper) <= (tscale - 1) and int(sample_upper) >= 0:
+ bin_vector[int(sample_upper)] += sample_decimal
+ bin_vector = 1.0 / num_sample_perbin * bin_vector
+ p_mask.append(bin_vector)
+ p_mask = np.stack(p_mask, axis=1)
+ return p_mask
+
+
+def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample,
+ num_sample_perbin):
+ """ generate sample mask for each point in Boundary-Matching Map """
+ mask_mat = []
+ for start_index in range(tscale):
+ mask_mat_vector = []
+ for duration_index in range(dscale):
+ if start_index + duration_index < tscale:
+ p_xmin = start_index
+ p_xmax = start_index + duration_index
+ center_len = float(p_xmax - p_xmin) + 1
+ sample_xmin = p_xmin - center_len * prop_boundary_ratio
+ sample_xmax = p_xmax + center_len * prop_boundary_ratio
+ p_mask = _get_interp1d_bin_mask(sample_xmin, sample_xmax,
+ tscale, num_sample,
+ num_sample_perbin)
+ else:
+ p_mask = np.zeros([tscale, num_sample])
+ mask_mat_vector.append(p_mask)
+ mask_mat_vector = np.stack(mask_mat_vector, axis=2)
+ mask_mat.append(mask_mat_vector)
+ mask_mat = np.stack(mask_mat, axis=3)
+ mask_mat = mask_mat.astype(np.float32)
+
+ sample_mask = np.reshape(mask_mat, [tscale, -1])
+ return sample_mask
diff --git a/bmn/config_utils.py b/bmn/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbef865781061e3f716de01e76d93f782f55b3a6
--- /dev/null
+++ b/bmn/config_utils.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import yaml
+import logging
+
+logger = logging.getLogger(__name__)
+
+CONFIG_SECS = [
+ 'train',
+ 'valid',
+ 'test',
+ 'infer',
+]
+
+
+class AttrDict(dict):
+ def __getattr__(self, key):
+ return self[key]
+
+ def __setattr__(self, key, value):
+ if key in self.__dict__:
+ self.__dict__[key] = value
+ else:
+ self[key] = value
+
+
+def parse_config(cfg_file):
+ """Load a config file into AttrDict"""
+ with open(cfg_file, 'r') as fopen:
+ yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
+ create_attr_dict(yaml_config)
+ return yaml_config
+
+
+def create_attr_dict(yaml_config):
+ from ast import literal_eval
+ for key, value in yaml_config.items():
+ if type(value) is dict:
+ yaml_config[key] = value = AttrDict(value)
+ if isinstance(value, str):
+ try:
+ value = literal_eval(value)
+ except BaseException:
+ pass
+ if isinstance(value, AttrDict):
+ create_attr_dict(yaml_config[key])
+ else:
+ yaml_config[key] = value
+ return
+
+
+def merge_configs(cfg, sec, args_dict):
+ assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
+ sec_dict = getattr(cfg, sec.upper())
+ for k, v in args_dict.items():
+ if v is None:
+ continue
+ try:
+ if hasattr(sec_dict, k):
+ setattr(sec_dict, k, v)
+ except:
+ pass
+ return cfg
+
+
+def print_configs(cfg, mode):
+ logger.info("---------------- {:>5} Arguments ----------------".format(
+ mode))
+ for sec, sec_items in cfg.items():
+ logger.info("{}:".format(sec))
+ for k, v in sec_items.items():
+ logger.info(" {}:{}".format(k, v))
+ logger.info("-------------------------------------------------")
diff --git a/bmn/eval.py b/bmn/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25fc5c79d21fd55743def09445db5821e3e93af
--- /dev/null
+++ b/bmn/eval.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import argparse
+import os
+import sys
+import logging
+import paddle.fluid as fluid
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_metric import BmnMetric
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("BMN test for performance evaluation.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode, only support dynamic mode at present time")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='gpu',
+ help='gpu or cpu, default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/final",
+ help='weight path, None to automatically download weights provided by Paddle.'
+ )
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Performance Evaluation
+def test_bmn(args):
+ # only support dynamic mode at present time
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ config = parse_config(args.config_file)
+ eval_cfg = merge_configs(config, 'test', vars(args))
+ if not os.path.isdir(config.TEST.output_path):
+ os.makedirs(config.TEST.output_path)
+ if not os.path.isdir(config.TEST.result_path):
+ os.makedirs(config.TEST.result_path)
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ gt_iou_map = Input(
+ [None, config.MODEL.dscale, config.MODEL.tscale],
+ 'float32',
+ name='gt_iou_map')
+ gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
+ gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
+ video_idx = Input([None, 1], 'int64', name='video_idx')
+ labels = [gt_iou_map, gt_start, gt_end, video_idx]
+
+ #data
+ eval_dataset = BmnDataset(eval_cfg, 'test')
+
+ #model
+ model = BMN(config, args.dynamic)
+ model.prepare(
+ loss_function=BmnLoss(config),
+ metrics=BmnMetric(
+ config, mode='test'),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ #load checkpoint
+ if args.weights:
+ assert os.path.exists(args.weights + '.pdparams'), \
+ "Given weight dir {} not exist.".format(args.weights)
+ logger.info('load test weights from {}'.format(args.weights))
+ model.load(args.weights)
+
+ model.evaluate(
+ eval_data=eval_dataset,
+ batch_size=eval_cfg.TEST.batch_size,
+ num_workers=eval_cfg.TEST.num_workers,
+ log_freq=args.log_interval)
+
+ logger.info("[EVAL] eval finished")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ test_bmn(args)
diff --git a/bmn/eval_anet_prop.py b/bmn/eval_anet_prop.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a91c6effff8083ad54180c0931585731f2fa6d8
--- /dev/null
+++ b/bmn/eval_anet_prop.py
@@ -0,0 +1,110 @@
+'''
+Calculate AR@AN and AUC;
+Modified from the ActivityNet GitHub repository (https://github.com/activitynet/ActivityNet.git)
+'''
+
+import sys
+sys.path.append('./Evaluation')
+
+from eval_proposal import ANETproposal
+import numpy as np
+import argparse
+import os
+import matplotlib.pyplot as plt  # required by plot_metric below
+
+parser = argparse.ArgumentParser("Eval AR vs AN of proposal")
+parser.add_argument(
+ '--eval_file',
+ type=str,
+ default='bmn_results_validation.json',
+ help='name of results file to eval')
+
+
+def run_evaluation(ground_truth_filename,
+ proposal_filename,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation'):
+
+ anet_proposal = ANETproposal(
+ ground_truth_filename,
+ proposal_filename,
+ tiou_thresholds=tiou_thresholds,
+ max_avg_nr_proposals=max_avg_nr_proposals,
+ subset=subset,
+ verbose=True,
+ check_status=False)
+ anet_proposal.evaluate()
+ recall = anet_proposal.recall
+ average_recall = anet_proposal.avg_recall
+ average_nr_proposals = anet_proposal.proposals_per_video
+
+ return (average_nr_proposals, average_recall, recall)
+
+
+def plot_metric(average_nr_proposals,
+ average_recall,
+ recall,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10)):
+ fn_size = 14
+ plt.figure(num=None, figsize=(12, 8))
+ ax = plt.subplot(1, 1, 1)
+
+ colors = [
+ 'k', 'r', 'yellow', 'b', 'c', 'm', 'b', 'pink', 'lawngreen', 'indigo'
+ ]
+ area_under_curve = np.zeros_like(tiou_thresholds)
+ for i in range(recall.shape[0]):
+ area_under_curve[i] = np.trapz(recall[i], average_nr_proposals)
+
+ for idx, tiou in enumerate(tiou_thresholds[::2]):
+ ax.plot(
+ average_nr_proposals,
+ recall[2 * idx, :],
+ color=colors[idx + 1],
+ label="tiou=[" + str(tiou) + "], area=" + str(
+ int(area_under_curve[2 * idx] * 100) / 100.),
+ linewidth=4,
+ linestyle='--',
+ marker=None)
+
+ # Plots Average Recall vs Average number of proposals.
+ ax.plot(
+ average_nr_proposals,
+ average_recall,
+ color=colors[0],
+ label="tiou = 0.5:0.05:0.95," + " area=" + str(
+ int(np.trapz(average_recall, average_nr_proposals) * 100) / 100.),
+ linewidth=4,
+ linestyle='-',
+ marker=None)
+
+ handles, labels = ax.get_legend_handles_labels()
+ ax.legend(
+ [handles[-1]] + handles[:-1], [labels[-1]] + labels[:-1], loc='best')
+
+ plt.ylabel('Average Recall', fontsize=fn_size)
+ plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
+ plt.grid(b=True, which="both")
+ plt.ylim([0, 1.0])
+ plt.setp(plt.axes().get_xticklabels(), fontsize=fn_size)
+ plt.setp(plt.axes().get_yticklabels(), fontsize=fn_size)
+ plt.show()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ eval_file = args.eval_file
+ eval_file_path = os.path.join("evaluate_results", eval_file)
+ uniform_average_nr_proposals_valid, uniform_average_recall_valid, uniform_recall_valid = run_evaluation(
+ "./Evaluation/data/activity_net_1_3_new.json",
+ eval_file_path,
+ max_avg_nr_proposals=100,
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
+ subset='validation')
+
+ print("AR@1; AR@5; AR@10; AR@100")
+ print("%.02f %.02f %.02f %.02f" %
+ (100 * np.mean(uniform_recall_valid[:, 0]),
+ 100 * np.mean(uniform_recall_valid[:, 4]),
+ 100 * np.mean(uniform_recall_valid[:, 9]),
+ 100 * np.mean(uniform_recall_valid[:, -1])))
diff --git a/bmn/infer.list b/bmn/infer.list
new file mode 100644
index 0000000000000000000000000000000000000000..44768f089e70e40913d9787571ae0a7151232558
--- /dev/null
+++ b/bmn/infer.list
@@ -0,0 +1 @@
+{"v_4Lu8ECLHvK4": {"duration_second": 124.23, "subset": "validation", "duration_frame": 3718, "annotations": [{"segment": [0.01, 124.22675736961452], "label": "Playing kickball"}], "feature_frame": 3712}, "v_5qsXmDi8d74": {"duration_second": 186.59599999999998, "subset": "validation", "duration_frame": 5596, "annotations": [{"segment": [61.402645865834636, 173.44250858034323], "label": "Sumo"}], "feature_frame": 5600}, "v_2D22fVcAcyo": {"duration_second": 215.78400000000002, "subset": "validation", "duration_frame": 6473, "annotations": [{"segment": [10.433652106084244, 25.242706708268333], "label": "Slacklining"}, {"segment": [38.368914196567864, 66.30417628705149], "label": "Slacklining"}, {"segment": [74.71841185647428, 91.2103135725429], "label": "Slacklining"}, {"segment": [103.66338221528862, 126.8866723868955], "label": "Slacklining"}, {"segment": [132.27178315132608, 180.0855070202808], "label": "Slacklining"}], "feature_frame": 6464}, "v_wPYr19iFxhw": {"duration_second": 56.611000000000004, "subset": "validation", "duration_frame": 1693, "annotations": [{"segment": [0.01, 56.541], "label": "Welding"}], "feature_frame": 1696}, "v_K6Tm5xHkJ5c": {"duration_second": 114.64, "subset": "validation", "duration_frame": 2745, "annotations": [{"segment": [25.81087088455538, 50.817943021840875], "label": "Playing accordion"}, {"segment": [52.78278440405616, 110.6562942074883], "label": "Playing accordion"}], "feature_frame": 2736}}
\ No newline at end of file
diff --git a/bmn/predict.py b/bmn/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52927b60562425a1f03cfea12ab6cb21e76b3ef
--- /dev/null
+++ b/bmn/predict.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import argparse
+import sys
+import os
+import logging
+import paddle.fluid as fluid
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_metric import BmnMetric
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("BMN inference.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode, only support dynamic mode at present time")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+        '--device', type=str, default='gpu', help='gpu or cpu, default use gpu.')
+ parser.add_argument(
+ '--weights',
+ type=str,
+ default="checkpoint/final",
+ help='weight path, None to automatically download weights provided by Paddle.'
+ )
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="predict_results/",
+ help='output dir path, default to use ./predict_results/')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=1,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Prediction
+def infer_bmn(args):
+ # only support dynamic mode at present time
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ config = parse_config(args.config_file)
+ infer_cfg = merge_configs(config, 'infer', vars(args))
+
+ if not os.path.isdir(config.INFER.output_path):
+ os.makedirs(config.INFER.output_path)
+ if not os.path.isdir(config.INFER.result_path):
+ os.makedirs(config.INFER.result_path)
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ labels = [Input([None, 1], 'int64', name='video_idx')]
+
+ #data
+ infer_dataset = BmnDataset(infer_cfg, 'infer')
+
+ model = BMN(config, args.dynamic)
+ model.prepare(
+ metrics=BmnMetric(
+ config, mode='infer'),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ # load checkpoint
+ if args.weights:
+ assert os.path.exists(
+ args.weights +
+ ".pdparams"), "Given weight dir {} not exist.".format(args.weights)
+ logger.info('load test weights from {}'.format(args.weights))
+ model.load(args.weights)
+
+ # here use model.eval instead of model.test, as post process is required in our case
+ model.evaluate(
+ eval_data=infer_dataset,
+ batch_size=infer_cfg.TEST.batch_size,
+ num_workers=infer_cfg.TEST.num_workers,
+ log_freq=args.log_interval)
+
+ logger.info("[INFER] infer finished")
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ infer_bmn(args)
diff --git a/bmn/reader.py b/bmn/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c1da592e932b6ddbd19476e994f4267ae2f927
--- /dev/null
+++ b/bmn/reader.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import numpy as np
+import json
+import logging
+import os
+import sys
+
+sys.path.append('../')
+
+from distributed import DistributedBatchSampler
+from paddle.fluid.io import Dataset, DataLoader
+
+logger = logging.getLogger(__name__)
+
+from config_utils import *
+from bmn_utils import iou_with_anchors, ioa_with_anchors
+
+DATATYPE = "float32"
+
+
+class BmnDataset(Dataset):
+ def __init__(self, cfg, mode):
+ self.mode = mode
+ self.tscale = cfg.MODEL.tscale # 100
+ self.dscale = cfg.MODEL.dscale # 100
+ self.anno_file = cfg.MODEL.anno_file
+ self.feat_path = cfg.MODEL.feat_path
+ self.file_list = cfg.INFER.filelist
+ self.subset = cfg[mode.upper()]['subset']
+ self.tgap = 1. / self.tscale
+
+ self.get_dataset_dict()
+ self.get_match_map()
+
+ def __getitem__(self, index):
+ video_name = self.video_list[index]
+ video_idx = self.video_list.index(video_name)
+ video_feat = self.load_file(video_name)
+ if self.mode == 'infer':
+ return video_feat, video_idx
+ else:
+ gt_iou_map, gt_start, gt_end = self.get_video_label(video_name)
+ if self.mode == 'train' or self.mode == 'valid':
+ return video_feat, gt_iou_map, gt_start, gt_end
+ elif self.mode == 'test':
+ return video_feat, gt_iou_map, gt_start, gt_end, video_idx
+
+ def __len__(self):
+ return len(self.video_list)
+
+ def get_dataset_dict(self):
+ assert (
+ os.path.exists(self.feat_path)), "Input feature path not exists"
+ assert (os.listdir(self.feat_path)), "No feature file in feature path"
+ self.video_dict = {}
+ if self.mode == "infer":
+ annos = json.load(open(self.file_list))
+ for video_name in annos.keys():
+ self.video_dict[video_name] = annos[video_name]
+ else:
+ annos = json.load(open(self.anno_file))
+ for video_name in annos.keys():
+ video_subset = annos[video_name]["subset"]
+ if self.subset in video_subset:
+ self.video_dict[video_name] = annos[video_name]
+ self.video_list = list(self.video_dict.keys())
+ self.video_list.sort()
+ print("%s subset video numbers: %d" %
+ (self.subset, len(self.video_list)))
+ video_name_set = set(
+ [video_name + '.npy' for video_name in self.video_list])
+ assert (video_name_set.intersection(set(os.listdir(self.feat_path))) ==
+ video_name_set), "Input feature not exists in feature path"
+
+ def get_match_map(self):
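+        # enumerate the normalized [xmin, xmax] window of every (duration, start)
+        # proposal; after the transpose/reshape below, row (duration_idx * tscale +
+        # start_idx) of match_map holds the window of that BM map cell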
+ match_map = []
+ for idx in range(self.tscale):
+ tmp_match_window = []
+ xmin = self.tgap * idx
+ for jdx in range(1, self.tscale + 1):
+ xmax = xmin + self.tgap * jdx
+ tmp_match_window.append([xmin, xmax])
+ match_map.append(tmp_match_window)
+ match_map = np.array(match_map)
+ match_map = np.transpose(match_map, [1, 0, 2])
+ match_map = np.reshape(match_map, [-1, 2])
+ self.match_map = match_map
+ self.anchor_xmin = [self.tgap * i for i in range(self.tscale)]
+ self.anchor_xmax = [self.tgap * i for i in range(1, self.tscale + 1)]
+
+ def get_video_label(self, video_name):
+ video_info = self.video_dict[video_name]
+ video_second = video_info['duration_second']
+ video_labels = video_info['annotations']
+
+ gt_bbox = []
+ gt_iou_map = []
+ for gt in video_labels:
+ tmp_start = max(min(1, gt["segment"][0] / video_second), 0)
+ tmp_end = max(min(1, gt["segment"][1] / video_second), 0)
+ gt_bbox.append([tmp_start, tmp_end])
+ tmp_gt_iou_map = iou_with_anchors(
+ self.match_map[:, 0], self.match_map[:, 1], tmp_start, tmp_end)
+ tmp_gt_iou_map = np.reshape(tmp_gt_iou_map,
+ [self.dscale, self.tscale])
+ gt_iou_map.append(tmp_gt_iou_map)
+ gt_iou_map = np.array(gt_iou_map)
+ gt_iou_map = np.max(gt_iou_map, axis=0)
+
+ gt_bbox = np.array(gt_bbox)
+ gt_xmins = gt_bbox[:, 0]
+ gt_xmaxs = gt_bbox[:, 1]
+ gt_len_small = 3 * self.tgap
+ gt_start_bboxs = np.stack(
+ (gt_xmins - gt_len_small / 2, gt_xmins + gt_len_small / 2), axis=1)
+ gt_end_bboxs = np.stack(
+ (gt_xmaxs - gt_len_small / 2, gt_xmaxs + gt_len_small / 2), axis=1)
+
+ match_score_start = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_start.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_start_bboxs[:, 0], gt_start_bboxs[:, 1])))
+ match_score_end = []
+ for jdx in range(len(self.anchor_xmin)):
+ match_score_end.append(
+ np.max(
+ ioa_with_anchors(self.anchor_xmin[jdx], self.anchor_xmax[
+ jdx], gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
+
+ gt_start = np.array(match_score_start)
+ gt_end = np.array(match_score_end)
+ return gt_iou_map.astype(DATATYPE), gt_start.astype(
+ DATATYPE), gt_end.astype(DATATYPE)
+
+ def load_file(self, video_name):
+ file_name = video_name + ".npy"
+ file_path = os.path.join(self.feat_path, file_name)
+ video_feat = np.load(file_path)
+ video_feat = video_feat.T
+ video_feat = video_feat.astype("float32")
+ return video_feat
diff --git a/bmn/run.sh b/bmn/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..24fd8e3da991c74628f6d345badf7bfe2e67c35d
--- /dev/null
+++ b/bmn/run.sh
@@ -0,0 +1,3 @@
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+
+python -m paddle.distributed.launch train.py
diff --git a/bmn/train.py b/bmn/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe46f6a607c6ab8f93be45ffeee11478ef862eb6
--- /dev/null
+++ b/bmn/train.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+import argparse
+import logging
+import sys
+import os
+
+sys.path.append('../')
+
+from model import set_device, Input
+from bmn_model import BMN, BmnLoss
+from reader import BmnDataset
+from config_utils import *
+
+DATATYPE = 'float32'
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("Paddle high level api of BMN.")
+ parser.add_argument(
+ "-d",
+ "--dynamic",
+ default=True,
+ action='store_true',
+ help="enable dygraph mode")
+ parser.add_argument(
+ '--config_file',
+ type=str,
+ default='bmn.yaml',
+ help='path to config file of model')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=None,
+ help='training batch size. None to use config file setting.')
+ parser.add_argument(
+ '--learning_rate',
+ type=float,
+        default=None,
+        help='learning rate used for training. None to use config file setting.')
+ parser.add_argument(
+ '--resume',
+ type=str,
+ default=None,
+ help='filename to resume training based on previous checkpoints. '
+ 'None for not resuming any checkpoints.')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='gpu',
+ help='gpu or cpu, default use gpu.')
+ parser.add_argument(
+ '--epoch',
+ type=int,
+ default=9,
+ help='epoch number, 0 for read from config file')
+ parser.add_argument(
+ '--valid_interval',
+ type=int,
+ default=1,
+ help='validation epoch interval, 0 for no validation.')
+ parser.add_argument(
+ '--save_dir',
+ type=str,
+ default="checkpoint",
+ help='path to save train snapshoot')
+ parser.add_argument(
+ '--log_interval',
+ type=int,
+ default=10,
+ help='mini-batch interval to log.')
+ args = parser.parse_args()
+ return args
+
+
+# Optimizer
+def optimizer(cfg, parameter_list):
+ bd = [cfg.TRAIN.lr_decay_iter]
+ base_lr = cfg.TRAIN.learning_rate
+ lr_decay = cfg.TRAIN.learning_rate_decay
+ l2_weight_decay = cfg.TRAIN.l2_weight_decay
+ lr = [base_lr, base_lr * lr_decay]
+ optimizer = fluid.optimizer.Adam(
+ fluid.layers.piecewise_decay(
+ boundaries=bd, values=lr),
+ parameter_list=parameter_list,
+ regularization=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=l2_weight_decay))
+ return optimizer
+
+
+# TRAIN
+def train_bmn(args):
+ device = set_device(args.device)
+ fluid.enable_dygraph(device) if args.dynamic else None
+
+ if not os.path.isdir(args.save_dir):
+ os.makedirs(args.save_dir)
+
+ config = parse_config(args.config_file)
+ train_cfg = merge_configs(config, 'train', vars(args))
+ val_cfg = merge_configs(config, 'valid', vars(args))
+
+ inputs = [
+ Input(
+ [None, config.MODEL.feat_dim, config.MODEL.tscale],
+ 'float32',
+ name='feat_input')
+ ]
+ gt_iou_map = Input(
+ [None, config.MODEL.dscale, config.MODEL.tscale],
+ 'float32',
+ name='gt_iou_map')
+ gt_start = Input([None, config.MODEL.tscale], 'float32', name='gt_start')
+ gt_end = Input([None, config.MODEL.tscale], 'float32', name='gt_end')
+ labels = [gt_iou_map, gt_start, gt_end]
+
+ # data
+ train_dataset = BmnDataset(train_cfg, 'train')
+ val_dataset = BmnDataset(val_cfg, 'valid')
+
+ # model
+ model = BMN(config, args.dynamic)
+ optim = optimizer(config, parameter_list=model.parameters())
+ model.prepare(
+ optimizer=optim,
+ loss_function=BmnLoss(config),
+ inputs=inputs,
+ labels=labels,
+ device=device)
+
+ # if resume weights is given, load resume weights directly
+ if args.resume is not None:
+ model.load(args.resume)
+
+ model.fit(train_data=train_dataset,
+ eval_data=val_dataset,
+ batch_size=train_cfg.TRAIN.batch_size,
+ epochs=args.epoch,
+ eval_freq=args.valid_interval,
+ log_freq=args.log_interval,
+ save_dir=args.save_dir,
+ shuffle=train_cfg.TRAIN.use_shuffle,
+ num_workers=train_cfg.TRAIN.num_workers,
+ drop_last=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ train_bmn(args)
diff --git a/cyclegan/README.md b/cyclegan/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef35c3ab1b3ec53ca2f9b7d6ba28f210b6d36e91
--- /dev/null
+++ b/cyclegan/README.md
@@ -0,0 +1,139 @@
+# Cycle GAN
+---
+## Contents
+
+- [Installation](#installation)
+- [Introduction](#introduction)
+- [Code Structure](#code-structure)
+- [Data Preparation](#data-preparation)
+- [Training and Inference](#training-and-inference)
+
+## Installation
+
+Running the examples in this directory requires the latest PaddlePaddle develop version. If your installed version is older, please update PaddlePaddle following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
+
+## Introduction
+CycleGAN is an image-to-image generation network that performs image generation and style transfer between unpaired image datasets. As shown in the figure below, the model contains two generator networks, G: X → Y and F: Y → X, together with the corresponding discriminators DY and DX. Training DY drives G to translate X images into Y images, and vice versa. Two "cycle-consistency losses" are added to guarantee that an image translated into the other domain can be translated back: (b) the forward cycle-consistency loss x→G(x)→F(G(x))≈x and (c) the backward cycle-consistency loss y→F(y)→G(F(y))≈y.
+
+
+
+Figure 1. Network structure
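+
+A minimal sketch of the two cycle-consistency terms (L1 reconstruction errors), assuming `G` and `F` are the two generator networks defined in cyclegan.py and `real_A`, `real_B` are input batches (the variable names here are illustrative):
+
+```
+import paddle.fluid as fluid
+
+fake_B = G(real_A)   # X -> Y
+cyc_A = F(fake_B)    # Y -> X, should reconstruct real_A
+fake_A = F(real_B)   # Y -> X
+cyc_B = G(fake_A)    # X -> Y, should reconstruct real_B
+
+cycle_loss = fluid.layers.reduce_mean(fluid.layers.abs(cyc_A - real_A)) + \
+             fluid.layers.reduce_mean(fluid.layers.abs(cyc_B - real_B))
+```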
+
+
+
+## Code Structure
+```
+├── data.py     # reads and preprocesses the data
+├── layers.py   # wrappers for the basic layers
+├── cyclegan.py # generator and discriminator network definitions
+├── train.py    # training script
+└── infer.py    # inference script
+```
+
+
+## Data Preparation
+
+The datasets supported by CycleGAN are listed in `cycle_pix_dataset` in download.py and can be downloaded with `python download.py --dataset xxx`.
+
+Due to licensing restrictions, the cityscapes dataset cannot be fetched by the script; download it from the [official site](https://www.cityscapes-dataset.com/),
+then run `python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./data/cityscapes/` to process it,
+and store the result in `data/cityscapes`.
+
+After the data has been downloaded and processed, organize it into the following directory structure:
+```
+data
+|-- cityscapes
+| |-- testA
+| |-- testB
+| |-- trainA
+| |-- trainB
+
+```
+
+Then run the txt generation script `python generate_txt.py`; the final layout is as follows:
+```
+data
+|-- cityscapes
+| |-- testA
+| |-- testA.txt
+| |-- testB
+| |-- testB.txt
+| |-- trainA
+| |-- trainA.txt
+| |-- trainB
+| `-- trainB.txt
+
+```
+
+In the layout above, the `data` folder must sit in the same directory as the training script `train.py`. `testA` holds real street-scene images, `testB` holds semantic segmentation images, and `testA.txt` / `testB.txt` are the corresponding lists of test image paths, formatted as follows:
+
+```
+data/cityscapes/testA/234_A.jpg
+data/cityscapes/testA/292_A.jpg
+data/cityscapes/testA/412_A.jpg
+```
+
+The training data is organized in the same way as the test data.
+
+
+## Training and Inference
+
+### Training
+
+Train on a single GPU:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python train.py
+```
+
+Run `python train.py --help` for more usage options and detailed argument descriptions.
+
+The training-loss curves over 152 epochs plot the epoch number on the horizontal axis against the loss on the training set on the vertical axis; 'g_loss', 'da_loss' and 'db_loss' are the training losses of the generator, discriminator A, and discriminator B respectively.
+
+
+### Testing
+
+Run the following command to evaluate saved training weights on the test set (the example below selects the epoch-199 checkpoint via `--init_model`):
+
+```
+env CUDA_VISIBLE_DEVICES=0 python test.py --init_model=checkpoint/199
+```
+The generated results are saved in `output/eval`.
+
+
+### Inference
+
+Run the following commands to read one or more images and generate predictions:
+
+Real street scene to segmentation map:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python infer.py \
+ --init_model="./checkpoints/199" --input="./image/testA/123_A.jpg" \
+ --input_style=A
+```
+
+Segmentation map to real street scene:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python infer.py \
+ --init_model="checkpoints/199" --input="./image/testB/78_B.jpg" \
+ --input_style=B
+```
+The generated results are saved in `output/single`.
+
+Predictions of a model trained for 180 epochs are shown as fakeA and fakeB:
+
+
+
+
+A2B
+
+
+
+
+
+B2A
+
+
+>In all of the examples above, the GPU to use can be changed by modifying `CUDA_VISIBLE_DEVICES`.
diff --git a/cyclegan/__init__.py b/cyclegan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cyclegan/check.py b/cyclegan/check.py
new file mode 100644
index 0000000000000000000000000000000000000000..79ab4862d3c2082c36039b047be08d4a4b5dcedd
--- /dev/null
+++ b/cyclegan/check.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle.fluid as fluid
+
+__all__ = ['check_gpu', 'check_version']
+
+
+def check_gpu(use_gpu):
+ """
+ Log error and exit when set use_gpu=true in paddlepaddle
+ cpu version.
+ """
+ err = "Config use_gpu cannot be set as true while you are " \
+ "using paddlepaddle cpu version ! \nPlease try: \n" \
+ "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+ "\t2. Set use_gpu as false in config file to run " \
+ "model on CPU"
+
+ try:
+ if use_gpu and not fluid.is_compiled_with_cuda():
+ print(err)
+ sys.exit(1)
+ except Exception as e:
+ pass
+
+
+def check_version():
+ """
+ Log error and exit when the installed version of paddlepaddle is
+ not satisfied.
+ """
+ err = "PaddlePaddle version 1.6 or higher is required, " \
+ "or a suitable develop version is satisfied as well. \n" \
+ "Please make sure the version is good with your code." \
+
+ try:
+ fluid.require_version('1.7.0')
+ except Exception as e:
+ print(err)
+ sys.exit(1)
diff --git a/cyclegan/cyclegan.py b/cyclegan/cyclegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fdd21c1bdf41a8ed3b6743297b99ef239bd5543
--- /dev/null
+++ b/cyclegan/cyclegan.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from layers import ConvBN, DeConvBN
+import paddle.fluid as fluid
+from model import Model, Loss
+
+
+class ResnetBlock(fluid.dygraph.Layer):
+ def __init__(self, dim, dropout=False):
+ super(ResnetBlock, self).__init__()
+ self.dropout = dropout
+ self.conv0 = ConvBN(dim, dim, 3, 1)
+ self.conv1 = ConvBN(dim, dim, 3, 1, act=None)
+
+ def forward(self, inputs):
+ out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
+ out_res = self.conv0(out_res)
+ if self.dropout:
+ out_res = fluid.layers.dropout(out_res, dropout_prob=0.5)
+ out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
+ out_res = self.conv1(out_res)
+ return out_res + inputs
+
+
+class ResnetGenerator(fluid.dygraph.Layer):
+ def __init__(self, input_channel, n_blocks=9, dropout=False):
+ super(ResnetGenerator, self).__init__()
+
+ self.conv0 = ConvBN(input_channel, 32, 7, 1)
+ self.conv1 = ConvBN(32, 64, 3, 2, padding=1)
+ self.conv2 = ConvBN(64, 128, 3, 2, padding=1)
+
+ dim = 128
+ self.resnet_blocks = []
+ for i in range(n_blocks):
+ block = self.add_sublayer("generator_%d" % (i + 1),
+ ResnetBlock(dim, dropout))
+ self.resnet_blocks.append(block)
+
+ self.deconv0 = DeConvBN(
+ dim, 32 * 2, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
+ self.deconv1 = DeConvBN(
+ 32 * 2, 32, 3, 2, padding=[1, 1], outpadding=[0, 1, 0, 1])
+
+ self.conv3 = ConvBN(
+ 32, input_channel, 7, 1, norm=False, act=False, use_bias=True)
+
+ def forward(self, inputs):
+ pad_input = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
+ y = self.conv0(pad_input)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ for resnet_block in self.resnet_blocks:
+ y = resnet_block(y)
+ y = self.deconv0(y)
+ y = self.deconv1(y)
+ y = fluid.layers.pad2d(y, [3, 3, 3, 3], mode="reflect")
+ y = self.conv3(y)
+ y = fluid.layers.tanh(y)
+ return y
+
+
+class NLayerDiscriminator(fluid.dygraph.Layer):
+ def __init__(self, input_channel, d_dims=64, d_nlayers=3):
+ super(NLayerDiscriminator, self).__init__()
+ self.conv0 = ConvBN(
+ input_channel,
+ d_dims,
+ 4,
+ 2,
+ 1,
+ norm=False,
+ use_bias=True,
+ relufactor=0.2)
+
+ nf_mult, nf_mult_prev = 1, 1
+ self.conv_layers = []
+ for n in range(1, d_nlayers):
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ conv = self.add_sublayer(
+ 'discriminator_%d' % (n),
+ ConvBN(
+ d_dims * nf_mult_prev,
+ d_dims * nf_mult,
+ 4,
+ 2,
+ 1,
+ relufactor=0.2))
+ self.conv_layers.append(conv)
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**d_nlayers, 8)
+ self.conv4 = ConvBN(
+ d_dims * nf_mult_prev, d_dims * nf_mult, 4, 1, 1, relufactor=0.2)
+ self.conv5 = ConvBN(
+ d_dims * nf_mult,
+ 1,
+ 4,
+ 1,
+ 1,
+ norm=False,
+ act=None,
+ use_bias=True,
+ relufactor=0.2)
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ for conv in self.conv_layers:
+ y = conv(y)
+ y = self.conv4(y)
+ y = self.conv5(y)
+ return y
+
+
+class Generator(Model):
+ def __init__(self, input_channel=3):
+ super(Generator, self).__init__()
+ self.g = ResnetGenerator(input_channel)
+
+ def forward(self, input):
+ fake = self.g(input)
+ return fake
+
+
+class GeneratorCombine(Model):
+ def __init__(self, g_AB=None, g_BA=None, d_A=None, d_B=None,
+ is_train=True):
+ super(GeneratorCombine, self).__init__()
+ self.g_AB = g_AB
+ self.g_BA = g_BA
+ self.is_train = is_train
+ if self.is_train:
+ self.d_A = d_A
+ self.d_B = d_B
+
+ def forward(self, input_A, input_B):
+ # Translate images to the other domain
+ fake_B = self.g_AB(input_A)
+ fake_A = self.g_BA(input_B)
+
+ # Translate images back to original domain
+ cyc_A = self.g_BA(fake_B)
+ cyc_B = self.g_AB(fake_A)
+ if not self.is_train:
+ return fake_A, fake_B, cyc_A, cyc_B
+
+ # Identity mapping of images
+ idt_A = self.g_AB(input_B)
+ idt_B = self.g_BA(input_A)
+
+        # Discriminators determine the validity of the translated images
+ # d_A(g_AB(A))
+ valid_A = self.d_A.d(fake_B)
+        # d_B(g_BA(B))
+ valid_B = self.d_B.d(fake_A)
+ return input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B
+
+
+class GLoss(Loss):
+ def __init__(self, lambda_A=10., lambda_B=10., lambda_identity=0.5):
+ super(GLoss, self).__init__()
+ self.lambda_A = lambda_A
+ self.lambda_B = lambda_B
+ self.lambda_identity = lambda_identity
+
+ def forward(self, outputs, labels=None):
+ input_A, input_B, fake_A, fake_B, cyc_A, cyc_B, idt_A, idt_B, valid_A, valid_B = outputs
+
+ def mse(a, b):
+ return fluid.layers.reduce_mean(fluid.layers.square(a - b))
+
+ def mae(a, b): # L1Loss
+ return fluid.layers.reduce_mean(fluid.layers.abs(a - b))
+
+ g_A_loss = mse(valid_A, 1.)
+ g_B_loss = mse(valid_B, 1.)
+ g_loss = g_A_loss + g_B_loss
+
+ cyc_A_loss = mae(input_A, cyc_A) * self.lambda_A
+ cyc_B_loss = mae(input_B, cyc_B) * self.lambda_B
+ cyc_loss = cyc_A_loss + cyc_B_loss
+
+ idt_loss_A = mae(input_B, idt_A) * (self.lambda_B *
+ self.lambda_identity)
+ idt_loss_B = mae(input_A, idt_B) * (self.lambda_A *
+ self.lambda_identity)
+ idt_loss = idt_loss_A + idt_loss_B
+
+ loss = cyc_loss + g_loss + idt_loss
+ return loss
+
+
+class Discriminator(Model):
+ def __init__(self, input_channel=3):
+ super(Discriminator, self).__init__()
+ self.d = NLayerDiscriminator(input_channel)
+
+ def forward(self, real, fake):
+ pred_real = self.d(real)
+ pred_fake = self.d(fake)
+ return pred_real, pred_fake
+
+
+class DLoss(Loss):
+ def __init__(self):
+ super(DLoss, self).__init__()
+
+ def forward(self, inputs, labels=None):
+ pred_real, pred_fake = inputs
+ loss = fluid.layers.square(pred_fake) + fluid.layers.square(pred_real -
+ 1.)
+ loss = fluid.layers.reduce_mean(loss / 2.0)
+ return loss
diff --git a/cyclegan/data.py b/cyclegan/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..effa4eeee12a7a4905f3cc40687d8349601bc6c6
--- /dev/null
+++ b/cyclegan/data.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import random
+import numpy as np
+from PIL import Image, ImageOps
+
+DATASET = "cityscapes"
+A_LIST_FILE = "./data/" + DATASET + "/trainA.txt"
+B_LIST_FILE = "./data/" + DATASET + "/trainB.txt"
+A_TEST_LIST_FILE = "./data/" + DATASET + "/testA.txt"
+B_TEST_LIST_FILE = "./data/" + DATASET + "/testB.txt"
+IMAGES_ROOT = "./data/" + DATASET + "/"
+
+import paddle.fluid as fluid
+
+
+class Cityscapes(fluid.io.Dataset):
+ def __init__(self, root_path, file_path, mode='train', return_name=False):
+ self.root_path = root_path
+ self.file_path = file_path
+ self.mode = mode
+ self.return_name = return_name
+ self.images = [root_path + l for l in open(file_path, 'r').readlines()]
+
+ def _train(self, image):
+ ## Resize
+ image = image.resize((286, 286), Image.BICUBIC)
+ ## RandomCrop
+ i = np.random.randint(0, 30)
+ j = np.random.randint(0, 30)
+ image = image.crop((i, j, i + 256, j + 256))
+ # RandomHorizontalFlip
+ if np.random.rand() > 0.5:
+ image = ImageOps.mirror(image)
+ return image
+
+ def __getitem__(self, idx):
+ f = self.images[idx].strip("\n\r\t ")
+ image = Image.open(f)
+ if self.mode == 'train':
+ image = self._train(image)
+ else:
+ image = image.resize((256, 256), Image.BICUBIC)
+ # ToTensor
+ image = np.array(image).transpose([2, 0, 1]).astype('float32')
+ image = image / 255.0
+ # Normalize, mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]
+ image = (image - 0.5) / 0.5
+ if self.return_name:
+ return [image], os.path.basename(f)
+ else:
+ return [image]
+
+ def __len__(self):
+ return len(self.images)
+
+
+def DataA(root=IMAGES_ROOT, fpath=A_LIST_FILE):
+ """
+ Reader of images with A style for training.
+ """
+ return Cityscapes(root, fpath)
+
+
+def DataB(root=IMAGES_ROOT, fpath=B_LIST_FILE):
+ """
+ Reader of images with B style for training.
+ """
+ return Cityscapes(root, fpath)
+
+
+def TestDataA(root=IMAGES_ROOT, fpath=A_TEST_LIST_FILE):
+ """
+    Reader of images with A style for testing.
+ """
+ return Cityscapes(root, fpath, mode='test', return_name=True)
+
+
+def TestDataB(root=IMAGES_ROOT, fpath=B_TEST_LIST_FILE):
+ """
+    Reader of images with B style for testing.
+ """
+ return Cityscapes(root, fpath, mode='test', return_name=True)
+
+
+class ImagePool(object):
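+    """Buffer of previously generated images.
+
+    Following the CycleGAN training scheme, the discriminators are updated
+    with a history of generated images rather than only the most recent
+    ones: while the pool is filling up, every image is stored and returned
+    unchanged; afterwards each new image either swaps places with (and
+    returns) a randomly chosen stored image or passes through unchanged,
+    each with probability 0.5.
+    """
+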
+ def __init__(self, pool_size=50):
+ self.pool = []
+ self.count = 0
+ self.pool_size = pool_size
+
+ def get(self, image):
+ if self.count < self.pool_size:
+ self.pool.append(image)
+ self.count += 1
+ return image
+ else:
+ p = random.random()
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size - 1)
+ temp = self.pool[random_id]
+ self.pool[random_id] = image
+ return temp
+ else:
+ return image
diff --git a/cyclegan/image/A2B.png b/cyclegan/image/A2B.png
new file mode 100644
index 0000000000000000000000000000000000000000..b67466da9bdf04344ac6a8f417169414641be664
Binary files /dev/null and b/cyclegan/image/A2B.png differ
diff --git a/cyclegan/image/B2A.png b/cyclegan/image/B2A.png
new file mode 100644
index 0000000000000000000000000000000000000000..851dd7422144a12cbdc25c47229c4db3ed727120
Binary files /dev/null and b/cyclegan/image/B2A.png differ
diff --git a/cyclegan/image/net.png b/cyclegan/image/net.png
new file mode 100644
index 0000000000000000000000000000000000000000..46681f8eea98995deeb03fb90257451a6fdfcdf8
Binary files /dev/null and b/cyclegan/image/net.png differ
diff --git a/cyclegan/image/testA/123_A.jpg b/cyclegan/image/testA/123_A.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c78de45861aa3afaa33ecaeb6f72e444a8391987
Binary files /dev/null and b/cyclegan/image/testA/123_A.jpg differ
diff --git a/cyclegan/image/testB/78_B.jpg b/cyclegan/image/testB/78_B.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..849c3be3ce6bd94cf38b1e2e40725727949c2a75
Binary files /dev/null and b/cyclegan/image/testB/78_B.jpg differ
diff --git a/cyclegan/infer.py b/cyclegan/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b61a958d59e19b73fa01d3c484e1e3231fae71b
--- /dev/null
+++ b/cyclegan/infer.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import glob
+import numpy as np
+import argparse
+
+from PIL import Image
+from scipy.misc import imsave
+
+import paddle.fluid as fluid
+from check import check_gpu, check_version
+
+from model import Model, Input, set_device
+from cyclegan import Generator, GeneratorCombine
+
+
+def main():
+ place = set_device(FLAGS.device)
+ fluid.enable_dygraph(place) if FLAGS.dynamic else None
+
+ # Generators
+ g_AB = Generator()
+ g_BA = Generator()
+ g = GeneratorCombine(g_AB, g_BA, is_train=False)
+
+ im_shape = [-1, 3, 256, 256]
+ input_A = Input(im_shape, 'float32', 'input_A')
+ input_B = Input(im_shape, 'float32', 'input_B')
+ g.prepare(inputs=[input_A, input_B])
+ g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
+
+ out_path = FLAGS.output + "/single"
+ if not os.path.exists(out_path):
+ os.makedirs(out_path)
+ for f in glob.glob(FLAGS.input):
+ image_name = os.path.basename(f)
+ image = Image.open(f).convert('RGB')
+ image = image.resize((256, 256), Image.BICUBIC)
+ image = np.array(image) / 127.5 - 1
+
+ image = image[:, :, 0:3].astype("float32")
+ data = image.transpose([2, 0, 1])[np.newaxis, :]
+
+ if FLAGS.input_style == "A":
+ _, fake, _, _ = g.test([data, data])
+
+ if FLAGS.input_style == "B":
+ fake, _, _, _ = g.test([data, data])
+
+ fake = np.squeeze(fake[0]).transpose([1, 2, 0])
+
+ opath = "{}/fake{}{}".format(out_path, FLAGS.input_style, image_name)
+ imsave(opath, ((fake + 1) * 127.5).astype(np.uint8))
+ print("transfer {} to {}".format(f, opath))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("CycleGAN inference")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+ parser.add_argument(
+ "-p",
+ "--device",
+ type=str,
+ default='gpu',
+ help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-i",
+ "--input",
+ type=str,
+ default='./image/testA/123_A.jpg',
+ help="input image")
+ parser.add_argument(
+ "-o",
+ '--output',
+ type=str,
+ default='output',
+ help="The test result to be saved to.")
+ parser.add_argument(
+ "-m",
+ "--init_model",
+ type=str,
+ default='checkpoint/199',
+ help="The init model file of directory.")
+ parser.add_argument(
+ "-s", "--input_style", type=str, default='A', help="A or B")
+ FLAGS = parser.parse_args()
+ print(FLAGS)
+ check_gpu(str.lower(FLAGS.device) == 'gpu')
+ check_version()
+ main()
diff --git a/cyclegan/layers.py b/cyclegan/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c79ef5ff541646d98c3ce26d1d9a1888dc4421c
--- /dev/null
+++ b/cyclegan/layers.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm
+
+# cudnn is not better when batch size is 1.
+use_cudnn = False
+import numpy as np
+
+
+class ConvBN(fluid.dygraph.Layer):
+ """docstring for Conv2D"""
+
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ stride=1,
+ padding=0,
+ stddev=0.02,
+ norm=True,
+ is_test=False,
+ act='leaky_relu',
+ relufactor=0.0,
+ use_bias=False):
+ super(ConvBN, self).__init__()
+
+ pattr = fluid.ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=stddev))
+ self.conv = Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ use_cudnn=use_cudnn,
+ param_attr=pattr,
+ bias_attr=use_bias)
+ if norm:
+ self.bn = BatchNorm(
+ num_filters,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(1.0,
+ 0.02)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.0)),
+ is_test=False,
+ trainable_statistics=True)
+ self.relufactor = relufactor
+ self.norm = norm
+ self.act = act
+
+ def forward(self, inputs):
+ conv = self.conv(inputs)
+ if self.norm:
+ conv = self.bn(conv)
+
+ if self.act == 'leaky_relu':
+ conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
+ elif self.act == 'relu':
+ conv = fluid.layers.relu(conv)
+ else:
+ conv = conv
+
+ return conv
+
+
+class DeConvBN(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ stride=1,
+ padding=[0, 0],
+ outpadding=[0, 0, 0, 0],
+ stddev=0.02,
+ act='leaky_relu',
+ norm=True,
+ is_test=False,
+ relufactor=0.0,
+ use_bias=False):
+ super(DeConvBN, self).__init__()
+
+ pattr = fluid.ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=stddev))
+ self._deconv = Conv2DTranspose(
+ num_channels,
+ num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ param_attr=pattr,
+ bias_attr=use_bias)
+ if norm:
+ self.bn = BatchNorm(
+ num_filters,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(1.0,
+ 0.02)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.0)),
+ is_test=False,
+ trainable_statistics=True)
+ self.outpadding = outpadding
+ self.relufactor = relufactor
+ self.use_bias = use_bias
+ self.norm = norm
+ self.act = act
+
+ def forward(self, inputs):
+ conv = self._deconv(inputs)
+ conv = fluid.layers.pad2d(
+ conv, paddings=self.outpadding, mode='constant', pad_value=0.0)
+
+ if self.norm:
+ conv = self.bn(conv)
+
+ if self.act == 'leaky_relu':
+ conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor)
+ elif self.act == 'relu':
+ conv = fluid.layers.relu(conv)
+ else:
+ conv = conv
+
+ return conv
diff --git a/cyclegan/test.py b/cyclegan/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..995663090f07e345e54be47da26a8c0e7fd32a4a
--- /dev/null
+++ b/cyclegan/test.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import argparse
+import numpy as np
+from scipy.misc import imsave
+
+import paddle.fluid as fluid
+from check import check_gpu, check_version
+
+from model import Model, Input, set_device
+from cyclegan import Generator, GeneratorCombine
+import data as data
+
+
+def main():
+ place = set_device(FLAGS.device)
+ fluid.enable_dygraph(place) if FLAGS.dynamic else None
+
+ # Generators
+ g_AB = Generator()
+ g_BA = Generator()
+ g = GeneratorCombine(g_AB, g_BA, is_train=False)
+
+ im_shape = [-1, 3, 256, 256]
+ input_A = Input(im_shape, 'float32', 'input_A')
+ input_B = Input(im_shape, 'float32', 'input_B')
+ g.prepare(inputs=[input_A, input_B])
+ g.load(FLAGS.init_model, skip_mismatch=True, reset_optimizer=True)
+
+ if not os.path.exists(FLAGS.output):
+ os.makedirs(FLAGS.output)
+
+ test_data_A = data.TestDataA()
+ test_data_B = data.TestDataB()
+
+ for i in range(len(test_data_A)):
+ data_A, A_name = test_data_A[i]
+ data_B, B_name = test_data_B[i]
+ data_A = np.array(data_A).astype("float32")
+ data_B = np.array(data_B).astype("float32")
+
+ fake_A, fake_B, cyc_A, cyc_B = g.test([data_A, data_B])
+
+ datas = [fake_A, fake_B, cyc_A, cyc_B, data_A, data_B]
+ odatas = []
+ for o in datas:
+ d = np.squeeze(o[0]).transpose([1, 2, 0])
+ im = ((d + 1) * 127.5).astype(np.uint8)
+ odatas.append(im)
+ imsave(FLAGS.output + "/fakeA_" + B_name, odatas[0])
+ imsave(FLAGS.output + "/fakeB_" + A_name, odatas[1])
+ imsave(FLAGS.output + "/cycA_" + A_name, odatas[2])
+ imsave(FLAGS.output + "/cycB_" + B_name, odatas[3])
+ imsave(FLAGS.output + "/inputA_" + A_name, odatas[4])
+ imsave(FLAGS.output + "/inputB_" + B_name, odatas[5])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("CycleGAN test")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+ parser.add_argument(
+ "-p",
+ "--device",
+ type=str,
+ default='gpu',
+ help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-b", "--batch_size", default=1, type=int, help="batch size")
+ parser.add_argument(
+ "-o",
+ '--output',
+ type=str,
+ default='output/eval',
+ help="The test result to be saved to.")
+ parser.add_argument(
+ "-m",
+ "--init_model",
+ type=str,
+ default='checkpoint/199',
+ help="The init model file of directory.")
+ FLAGS = parser.parse_args()
+ print(FLAGS)
+ check_gpu(str.lower(FLAGS.device) == 'gpu')
+ check_version()
+ main()
diff --git a/cyclegan/train.py b/cyclegan/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2203fc19c8e0381fa27bde26a22a863130532e9
--- /dev/null
+++ b/cyclegan/train.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import random
+import argparse
+import contextlib
+import time
+
+import paddle
+import paddle.fluid as fluid
+from check import check_gpu, check_version
+
+from model import Model, Input, set_device
+
+import data as data
+from cyclegan import Generator, Discriminator, GeneratorCombine, GLoss, DLoss
+
+step_per_epoch = 2974
+
+
+def opt(parameters):
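+    # Adam (beta1=0.5) with a stepwise schedule: base lr 2e-4, multiplied by
+    # 0.8 / 0.6 / 0.4 / 0.2 / 0.1 at epochs 100 / 120 / 140 / 160 / 180
+    # (epoch boundaries are converted to iteration counts below).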
+ lr_base = 0.0002
+ bounds = [100, 120, 140, 160, 180]
+ lr = [1., 0.8, 0.6, 0.4, 0.2, 0.1]
+ bounds = [i * step_per_epoch for i in bounds]
+ lr = [i * lr_base for i in lr]
+ optimizer = fluid.optimizer.Adam(
+ learning_rate=fluid.layers.piecewise_decay(
+ boundaries=bounds, values=lr),
+ parameter_list=parameters,
+ beta1=0.5)
+ return optimizer
+
+
+def main():
+ place = set_device(FLAGS.device)
+ fluid.enable_dygraph(place) if FLAGS.dynamic else None
+
+ # Generators
+ g_AB = Generator()
+ g_BA = Generator()
+
+ # Discriminators
+ d_A = Discriminator()
+ d_B = Discriminator()
+
+ g = GeneratorCombine(g_AB, g_BA, d_A, d_B)
+
+ da_params = d_A.parameters()
+ db_params = d_B.parameters()
+ g_params = g_AB.parameters() + g_BA.parameters()
+
+ da_optimizer = opt(da_params)
+ db_optimizer = opt(db_params)
+ g_optimizer = opt(g_params)
+
+ im_shape = [None, 3, 256, 256]
+ input_A = Input(im_shape, 'float32', 'input_A')
+ input_B = Input(im_shape, 'float32', 'input_B')
+ fake_A = Input(im_shape, 'float32', 'fake_A')
+ fake_B = Input(im_shape, 'float32', 'fake_B')
+
+ g_AB.prepare(inputs=[input_A])
+ g_BA.prepare(inputs=[input_B])
+
+ g.prepare(g_optimizer, GLoss(), inputs=[input_A, input_B])
+ d_A.prepare(da_optimizer, DLoss(), inputs=[input_B, fake_B])
+ d_B.prepare(db_optimizer, DLoss(), inputs=[input_A, fake_A])
+
+ if FLAGS.resume:
+ g.load(FLAGS.resume)
+
+ loader_A = fluid.io.DataLoader(
+ data.DataA(),
+ places=place,
+ shuffle=True,
+ return_list=True,
+ batch_size=FLAGS.batch_size)
+ loader_B = fluid.io.DataLoader(
+ data.DataB(),
+ places=place,
+ shuffle=True,
+ return_list=True,
+ batch_size=FLAGS.batch_size)
+
+ A_pool = data.ImagePool()
+ B_pool = data.ImagePool()
+
+ for epoch in range(FLAGS.epoch):
+ for i, (data_A, data_B) in enumerate(zip(loader_A, loader_B)):
+ data_A = data_A[0][0] if not FLAGS.dynamic else data_A[0]
+ data_B = data_B[0][0] if not FLAGS.dynamic else data_B[0]
+ start = time.time()
+
+ fake_B = g_AB.test(data_A)[0]
+ fake_A = g_BA.test(data_B)[0]
+ g_loss = g.train([data_A, data_B])[0]
+ fake_pb = B_pool.get(fake_B)
+ da_loss = d_A.train([data_B, fake_pb])[0]
+
+ fake_pa = A_pool.get(fake_A)
+ db_loss = d_B.train([data_A, fake_pa])[0]
+
+ t = time.time() - start
+ if i % 20 == 0:
+ print("epoch: {} | step: {:3d} | g_loss: {:.4f} | " \
+ "da_loss: {:.4f} | db_loss: {:.4f} | s/step {:.4f}".
+ format(epoch, i, g_loss[0], da_loss[0], db_loss[0], t))
+ g.save('{}/{}'.format(FLAGS.checkpoint_path, epoch))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("CycleGAN Training on Cityscapes")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
+ parser.add_argument(
+ "-p",
+ "--device",
+ type=str,
+ default='gpu',
+ help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-e", "--epoch", default=200, type=int, help="Epoch number")
+ parser.add_argument(
+ "-b", "--batch_size", default=1, type=int, help="batch size")
+ parser.add_argument(
+ "-o",
+ "--checkpoint_path",
+ type=str,
+ default='checkpoint',
+ help="path to save checkpoint")
+ parser.add_argument(
+ "-r",
+ "--resume",
+ default=None,
+ type=str,
+ help="checkpoint path to resume")
+ FLAGS = parser.parse_args()
+ print(FLAGS)
+ check_gpu(str.lower(FLAGS.device) == 'gpu')
+ check_version()
+ main()
diff --git a/datasets/folder.py b/datasets/folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e853e7e106cf7a305c79ab900515be6f8febf3a0
--- /dev/null
+++ b/datasets/folder.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import cv2
+
+from paddle.fluid.io import Dataset
+
+
+def has_valid_extension(filename, extensions):
+ """Checks if a file is a vilid extension.
+
+ Args:
+ filename (str): path to a file
+ extensions (tuple of str): extensions to consider (lowercase)
+
+ Returns:
+ bool: True if the filename ends with one of given extensions
+ """
+ return filename.lower().endswith(extensions)
+
+
+def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
+ images = []
+ dir = os.path.expanduser(dir)
+ if not ((extensions is None) ^ (is_valid_file is None)):
+ raise ValueError(
+ "Both extensions and is_valid_file cannot be None or not None at the same time"
+ )
+ if extensions is not None:
+
+ def is_valid_file(x):
+ return has_valid_extension(x, extensions)
+
+ for target in sorted(class_to_idx.keys()):
+ d = os.path.join(dir, target)
+ if not os.path.isdir(d):
+ continue
+ for root, _, fnames in sorted(os.walk(d, followlinks=True)):
+ for fname in sorted(fnames):
+ path = os.path.join(root, fname)
+ if is_valid_file(path):
+ item = (path, class_to_idx[target])
+ images.append(item)
+
+ return images
+
+
+class DatasetFolder(Dataset):
+ """A generic data loader where the samples are arranged in this way:
+
+ root/class_a/1.ext
+ root/class_a/2.ext
+ root/class_a/3.ext
+
+ root/class_b/123.ext
+ root/class_b/456.ext
+ root/class_b/789.ext
+
+ Args:
+ root (string): Root directory path.
+ loader (callable, optional): A function to load a sample given its path.
+ extensions (tuple[string], optional): A list of allowed extensions.
+ both extensions and is_valid_file should not be passed.
+ transform (callable, optional): A function/transform that takes in
+ a sample and returns a transformed version.
+ target_transform (callable, optional): A function/transform that takes
+ in the target and transforms it.
+ is_valid_file (callable, optional): A function that takes path of a file
+ and check if the file is a valid file (used to check of corrupt files)
+ both extensions and is_valid_file should not be passed.
+
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ samples (list): List of (sample path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
+ """
+
+ def __init__(self,
+ root,
+ loader=None,
+ extensions=None,
+ transform=None,
+ target_transform=None,
+ is_valid_file=None):
+ self.root = root
+ if extensions is None:
+ extensions = IMG_EXTENSIONS
+ classes, class_to_idx = self._find_classes(self.root)
+ samples = make_dataset(self.root, class_to_idx, extensions,
+ is_valid_file)
+ if len(samples) == 0:
+ raise (RuntimeError(
+ "Found 0 files in subfolders of: " + self.root + "\n"
+ "Supported extensions are: " + ",".join(extensions)))
+
+ self.loader = cv2_loader if loader is None else loader
+ self.extensions = extensions
+
+        self.transform = transform
+        self.target_transform = target_transform
+
+        self.classes = classes
+        self.class_to_idx = class_to_idx
+        self.samples = samples
+        self.targets = [s[1] for s in samples]
+
+ def _find_classes(self, dir):
+ """
+ Finds the class folders in a dataset.
+
+ Args:
+ dir (string): Root directory path.
+
+ Returns:
+ tuple: (classes, class_to_idx) where classes are relative to (dir),
+ and class_to_idx is a dictionary.
+
+ """
+ if sys.version_info >= (3, 5):
+ # Faster and available in Python 3.5 and above
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+ else:
+ classes = [
+ d for d in os.listdir(dir)
+ if os.path.isdir(os.path.join(dir, d))
+ ]
+ classes.sort()
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ return classes, class_to_idx
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self):
+ return len(self.samples)
+
+
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
+ '.tiff', '.webp')
+
+
+def cv2_loader(path):
+ return cv2.imread(path)
diff --git a/image_classification/README.MD b/image_classification/README.MD
new file mode 100644
index 0000000000000000000000000000000000000000..9be3362090e97e64ee5c09ead6247f5ecb217781
--- /dev/null
+++ b/image_classification/README.MD
@@ -0,0 +1,92 @@
+# Image Classification with the High-Level API
+
+## Dataset Preparation
+Before training, make sure the [ImageNet dataset](http://image-net.org/download) has been downloaded and extracted to a suitable directory. The prepared dataset should be organized in the following directory structure:
+
+```bash
+/path/to/imagenet
+ train
+ n01440764
+ xxx.jpg
+ ...
+ n01443537
+ xxx.jpg
+ ...
+ ...
+ val
+ n01440764
+ xxx.jpg
+ ...
+ n01443537
+ xxx.jpg
+ ...
+ ...
+```
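+
+With this layout, every class subdirectory under `train` (or `val`) becomes one label. The `DatasetFolder` helper in `datasets/folder.py` (added in this change) walks the tree and builds `(image_path, class_index)` samples; a rough, hypothetical equivalent of what it does:
+
+```python
+import os
+
+def make_samples(split_dir):
+    # every subdirectory (e.g. n01440764) is one class
+    classes = sorted(d for d in os.listdir(split_dir)
+                     if os.path.isdir(os.path.join(split_dir, d)))
+    class_to_idx = {name: idx for idx, name in enumerate(classes)}
+    samples = []
+    for name in classes:
+        for dirpath, _, fnames in sorted(os.walk(os.path.join(split_dir, name))):
+            for fname in sorted(fnames):
+                samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
+    return samples
+
+# e.g. samples = make_samples('/path/to/imagenet/train')
+```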
+
+
+## Training
+### Single-GPU Training
+Run the following command to start training:
+```bash
+python -u main.py --arch resnet50 /path/to/imagenet -d
+```
+`-d` enables dynamic graph (dygraph) mode; static graph mode is the default.
+
+### Multi-GPU Training
+Run the following command to start training:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 -d /path/to/imagenet
+```
+
+## Evaluation
+
+### Single-GPU Evaluation
+Run the following command to evaluate:
+```bash
+python -u main.py --arch resnet50 -d --eval-only /path/to/imagenet
+```
+
+### Multi-GPU Evaluation
+Run the following command to evaluate on multiple GPUs:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch resnet50 --eval-only /path/to/imagenet
+```
+
+
+## Parameters
+
+
+* **arch**: name of the model to train or evaluate
+* **device**: device to use for training, 'gpu' or 'cpu'; default: 'gpu'
+* **dynamic**: whether to train in dynamic graph (dygraph) mode
+* **epoch**: number of training epochs; default: 90
+* **learning-rate**: learning rate; default: 0.1
+* **batch-size**: batch size per GPU; default: 64
+* **output-dir**: directory to save model files; default: 'output'
+* **num-workers**: number of dataloader workers; default: 4
+* **resume**: checkpoint path to resume training from; default: None
+* **eval-only**: only run evaluation
+* **lr-scheduler**: learning rate decay strategy; default: piecewise (see the sketch after this list)
+* **milestones**: boundaries of the piecewise learning rate decay; default: [30, 60, 80]
+* **weight-decay**: weight regularization coefficient; default: 1e-4
+* **momentum**: momentum of the SGD optimizer; default: 0.9
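+
+The default schedule combines five epochs of linear warmup with piecewise decay at the listed milestones (see `make_optimizer` in main.py, added in this change). A rough sketch of the resulting per-iteration learning rate, with a hypothetical helper name:
+
+```python
+def lr_at(step, step_per_epoch, base_lr=0.1, milestones=(30, 60, 80), warmup_epochs=5):
+    # linear warmup from 0 to base_lr over the first warmup_epochs epochs
+    warmup_steps = warmup_epochs * step_per_epoch
+    if step < warmup_steps:
+        return base_lr * step / float(warmup_steps)
+    # piecewise decay: divide the learning rate by 10 at every milestone epoch
+    boundaries = [m * step_per_epoch for m in milestones]
+    return base_lr * (0.1 ** sum(step >= b for b in boundaries))
+```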
+
+
+## Models
+
+| Model | top1 acc | top5 acc |
+| --- | --- | --- |
+| [ResNet50](https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams) | 76.28 | 93.04 |
+| [vgg16](https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams) | 71.84 | 90.71 |
+| [mobilenet_v1](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams) | 71.25 | 89.92 |
+| [mobilenet_v2](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams) | 72.27 | 90.66 |
+
+To reproduce the results above, refer to the scripts under the scripts directory.
+
+
+## References
+- ResNet: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+- MobileNetV1: [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861), Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
+- MobileNetV2: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/pdf/1801.04381v4.pdf), Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
+- VGG: [Very Deep Convolutional Networks for Large-scale Image Recognition](https://arxiv.org/pdf/1409.1556), Karen Simonyan, Andrew Zisserman
+
diff --git a/image_classification/imagenet_dataset.py b/image_classification/imagenet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..948ac5b8bb4c360bc2ea52d819c2958da52ef68f
--- /dev/null
+++ b/image_classification/imagenet_dataset.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import math
+import random
+import numpy as np
+
+from datasets.folder import DatasetFolder
+
+
+def center_crop_resize(img):
+ h, w = img.shape[:2]
+ c = int(224 / 256 * min((h, w)))
+ i = (h + 1 - c) // 2
+ j = (w + 1 - c) // 2
+ img = img[i:i + c, j:j + c, :]
+ return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
+
+
+def random_crop_resize(img):
+ height, width = img.shape[:2]
+ area = height * width
+
+ for attempt in range(10):
+ target_area = random.uniform(0.08, 1.) * area
+ log_ratio = (math.log(3 / 4), math.log(4 / 3))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= width and h <= height:
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ img = img[i:i + h, j:j + w, :]
+ return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR)
+
+ return center_crop_resize(img)
+
+
+def random_flip(img):
+ if np.random.randint(0, 2) == 1:
+ img = img[:, ::-1, :]
+ return img
+
+
+def normalize_permute(img):
+ # transpose and convert to RGB from BGR
+ img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
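+    # per-channel ImageNet mean / std on the 0-255 scale:
+    # (0.485, 0.456, 0.406) * 255 and (0.229, 0.224, 0.225) * 255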
+ mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+ std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
+ invstd = 1. / std
+ for v, m, s in zip(img, mean, invstd):
+ v.__isub__(m).__imul__(s)
+ return img
+
+
+def compose(functions):
+ def process(sample):
+ img, label = sample
+ for fn in functions:
+ img = fn(img)
+ return img, label
+
+ return process
+
+
+class ImageNetDataset(DatasetFolder):
+ def __init__(self, path, mode='train'):
+ super(ImageNetDataset, self).__init__(path)
+ self.mode = mode
+ if self.mode == 'train':
+ self.transform = compose([
+ cv2.imread, random_crop_resize, random_flip, normalize_permute
+ ])
+ else:
+ self.transform = compose(
+ [cv2.imread, center_crop_resize, normalize_permute])
+
+ def __getitem__(self, idx):
+ img, label = self.samples[idx]
+ return self.transform((img, [label]))
+
+ def __len__(self):
+ return len(self.samples)
diff --git a/image_classification/main.py b/image_classification/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..781824fa60f9d703187697825595d81889b9c53c
--- /dev/null
+++ b/image_classification/main.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import contextlib
+import os
+import sys
+sys.path.append('../')
+
+import time
+import math
+import numpy as np
+import models
+import paddle.fluid as fluid
+
+from model import CrossEntropy, Input, set_device
+from imagenet_dataset import ImageNetDataset
+from distributed import DistributedBatchSampler
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from metrics import Accuracy
+from paddle.fluid.io import BatchSampler, DataLoader
+
+
+def make_optimizer(step_per_epoch, parameter_list=None):
+ base_lr = FLAGS.lr
+ lr_scheduler = FLAGS.lr_scheduler
+ momentum = FLAGS.momentum
+ weight_decay = FLAGS.weight_decay
+
+ if lr_scheduler == 'piecewise':
+ milestones = FLAGS.milestones
+ boundaries = [step_per_epoch * e for e in milestones]
+ values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+ learning_rate = fluid.layers.piecewise_decay(
+ boundaries=boundaries, values=values)
+ elif lr_scheduler == 'cosine':
+ learning_rate = fluid.layers.cosine_decay(base_lr, step_per_epoch,
+ FLAGS.epoch)
+ else:
+ raise ValueError(
+ "Expected lr_scheduler in ['piecewise', 'cosine'], but got {}".
+ format(lr_scheduler))
+
+ learning_rate = fluid.layers.linear_lr_warmup(
+ learning_rate=learning_rate,
+ warmup_steps=5 * step_per_epoch,
+ start_lr=0.,
+ end_lr=base_lr)
+
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ regularization=fluid.regularizer.L2Decay(weight_decay),
+ parameter_list=parameter_list)
+
+ return optimizer
+
+
+def main():
+ device = set_device(FLAGS.device)
+ fluid.enable_dygraph(device) if FLAGS.dynamic else None
+
+ model = models.__dict__[FLAGS.arch](pretrained=FLAGS.eval_only and
+ not FLAGS.resume)
+
+ if FLAGS.resume is not None:
+ model.load(FLAGS.resume)
+
+ inputs = [Input([None, 3, 224, 224], 'float32', name='image')]
+ labels = [Input([None, 1], 'int64', name='label')]
+
+ train_dataset = ImageNetDataset(
+ os.path.join(FLAGS.data, 'train'), mode='train')
+ val_dataset = ImageNetDataset(os.path.join(FLAGS.data, 'val'), mode='val')
+
+ optim = make_optimizer(
+ np.ceil(
+ len(train_dataset) * 1. / FLAGS.batch_size / ParallelEnv().nranks),
+ parameter_list=model.parameters())
+
+ model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 5)), inputs, labels)
+
+ if FLAGS.eval_only:
+ model.evaluate(
+ val_dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.num_workers)
+ return
+
+ output_dir = os.path.join(FLAGS.output_dir, FLAGS.arch,
+ time.strftime('%Y-%m-%d-%H-%M',
+ time.localtime()))
+ if ParallelEnv().local_rank == 0 and not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ model.fit(train_dataset,
+ val_dataset,
+ batch_size=FLAGS.batch_size,
+ epochs=FLAGS.epoch,
+ save_dir=output_dir,
+ num_workers=FLAGS.num_workers)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("Resnet Training on ImageNet")
+ parser.add_argument(
+ 'data',
+ metavar='DIR',
+ help='path to dataset '
+        '(should have subdirectories named "train" and "val")')
+ parser.add_argument(
+ "--arch", type=str, default='resnet50', help="model name")
+ parser.add_argument(
+ "--device", type=str, default='gpu', help="device to run, cpu or gpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "-e", "--epoch", default=90, type=int, help="number of epoch")
+ parser.add_argument(
+ '--lr',
+ '--learning-rate',
+ default=0.1,
+ type=float,
+ metavar='LR',
+ help='initial learning rate')
+ parser.add_argument(
+ "-b", "--batch-size", default=64, type=int, help="batch size")
+ parser.add_argument(
+ "-n", "--num-workers", default=4, type=int, help="dataloader workers")
+ parser.add_argument(
+ "--output-dir", type=str, default='output', help="save dir")
+ parser.add_argument(
+ "-r",
+ "--resume",
+ default=None,
+ type=str,
+ help="checkpoint path to resume")
+ parser.add_argument(
+ "--eval-only", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "--lr-scheduler",
+ default='piecewise',
+ type=str,
+ help="learning rate scheduler")
+ parser.add_argument(
+ "--milestones",
+ nargs='+',
+ type=int,
+ default=[30, 60, 80],
+ help="piecewise decay milestones")
+ parser.add_argument(
+ "--weight-decay", default=1e-4, type=float, help="weight decay")
+ parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
+ FLAGS = parser.parse_args()
+ assert FLAGS.data, "error: must provide data path"
+ main()
diff --git a/lac.py b/lac.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd380686256b2039f6aa7f2289639559969a6a8
--- /dev/null
+++ b/lac.py
@@ -0,0 +1,728 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+lexical analysis network structure
+"""
+
+from __future__ import division
+from __future__ import print_function
+
+import io
+import os
+import sys
+import math
+import argparse
+import numpy as np
+
+from metrics import Metric
+from model import Model, Input, Loss, set_device
+
+import paddle.fluid as fluid
+from paddle.fluid.optimizer import AdamOptimizer
+from paddle.fluid.initializer import NormalInitializer
+from paddle.fluid.dygraph.nn import Embedding, Linear, GRUUnit
+
+
+class DynamicGRU(fluid.dygraph.Layer):
+ def __init__(self,
+ size,
+ h_0=None,
+ param_attr=None,
+ bias_attr=None,
+ is_reverse=False,
+ gate_activation='sigmoid',
+ candidate_activation='tanh',
+ origin_mode=False,
+ init_size=None):
+ super(DynamicGRU, self).__init__()
+
+ self.gru_unit = GRUUnit(
+ size * 3,
+ param_attr=param_attr,
+ bias_attr=bias_attr,
+ activation=candidate_activation,
+ gate_activation=gate_activation,
+ origin_mode=origin_mode)
+
+ self.size = size
+ self.h_0 = h_0
+ self.is_reverse = is_reverse
+
+ def forward(self, inputs):
+ hidden = self.h_0
+ res = []
+
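+        # Unroll the GRU over the time dimension: feed one time step at a
+        # time into the GRUUnit (walking backwards when is_reverse is set)
+        # and stack the per-step hidden states back into a
+        # [batch, seq_len, hidden] tensor.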
+ for i in range(inputs.shape[1]):
+ if self.is_reverse:
+ i = inputs.shape[1] - 1 - i
+ input_ = inputs[:, i:i + 1, :]
+ input_ = fluid.layers.reshape(
+ input_, [-1, input_.shape[2]], inplace=False)
+ hidden, reset, gate = self.gru_unit(input_, hidden)
+ hidden_ = fluid.layers.reshape(
+ hidden, [-1, 1, hidden.shape[1]], inplace=False)
+ res.append(hidden_)
+ if self.is_reverse:
+ res = res[::-1]
+ res = fluid.layers.concat(res, axis=1)
+ return res
+
+
+class BiGRU(fluid.dygraph.Layer):
+ def __init__(self, input_dim, grnn_hidden_dim, init_bound, h_0=None):
+ super(BiGRU, self).__init__()
+
+ self.pre_gru = Linear(
+ input_dim=input_dim,
+ output_dim=grnn_hidden_dim * 3,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=-init_bound, high=init_bound),
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=1e-4)))
+
+ self.gru = DynamicGRU(
+ size=grnn_hidden_dim,
+ h_0=h_0,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=-init_bound, high=init_bound),
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=1e-4)))
+
+ self.pre_gru_r = Linear(
+ input_dim=input_dim,
+ output_dim=grnn_hidden_dim * 3,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=-init_bound, high=init_bound),
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=1e-4)))
+
+ self.gru_r = DynamicGRU(
+ size=grnn_hidden_dim,
+ is_reverse=True,
+ h_0=h_0,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=-init_bound, high=init_bound),
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=1e-4)))
+
+ def forward(self, input_feature):
+ res_pre_gru = self.pre_gru(input_feature)
+ res_gru = self.gru(res_pre_gru)
+ res_pre_gru_r = self.pre_gru_r(input_feature)
+ res_gru_r = self.gru_r(res_pre_gru_r)
+ bi_merge = fluid.layers.concat(input=[res_gru, res_gru_r], axis=-1)
+ return bi_merge
+
+
+class Linear_chain_crf(fluid.dygraph.Layer):
+ def __init__(self, param_attr, size=None, is_test=False, dtype='float32'):
+ super(Linear_chain_crf, self).__init__()
+
+ self._param_attr = param_attr
+ self._dtype = dtype
+ self._size = size
+ self._is_test = is_test
+ self._transition = self.create_parameter(
+ attr=self._param_attr,
+ shape=[self._size + 2, self._size],
+ dtype=self._dtype)
+
+ @property
+ def weight(self):
+ return self._transition
+
+ @weight.setter
+ def weight(self, value):
+ self._transition = value
+
+ def forward(self, input, label, length=None):
+
+ alpha = self._helper.create_variable_for_type_inference(
+ dtype=self._dtype)
+ emission_exps = self._helper.create_variable_for_type_inference(
+ dtype=self._dtype)
+ transition_exps = self._helper.create_variable_for_type_inference(
+ dtype=self._dtype)
+ log_likelihood = self._helper.create_variable_for_type_inference(
+ dtype=self._dtype)
+ this_inputs = {
+ "Emission": [input],
+ "Transition": self._transition,
+ "Label": [label]
+ }
+ if length:
+ this_inputs['Length'] = [length]
+ self._helper.append_op(
+ type='linear_chain_crf',
+ inputs=this_inputs,
+ outputs={
+ "Alpha": [alpha],
+ "EmissionExps": [emission_exps],
+ "TransitionExps": transition_exps,
+ "LogLikelihood": log_likelihood
+ },
+ attrs={"is_test": self._is_test, })
+ return log_likelihood
+
+
+class Crf_decoding(fluid.dygraph.Layer):
+ def __init__(self, param_attr, size=None, is_test=False, dtype='float32'):
+ super(Crf_decoding, self).__init__()
+
+ self._dtype = dtype
+ self._size = size
+ self._is_test = is_test
+ self._param_attr = param_attr
+ self._transition = self.create_parameter(
+ attr=self._param_attr,
+ shape=[self._size + 2, self._size],
+ dtype=self._dtype)
+
+ @property
+ def weight(self):
+ return self._transition
+
+ @weight.setter
+ def weight(self, value):
+ self._transition = value
+
+ def forward(self, input, label=None, length=None):
+
+ viterbi_path = self._helper.create_variable_for_type_inference(
+ dtype=self._dtype)
+ this_inputs = {
+ "Emission": [input],
+ "Transition": self._transition,
+ "Label": label
+ }
+ if length:
+ this_inputs['Length'] = [length]
+ self._helper.append_op(
+ type='crf_decoding',
+ inputs=this_inputs,
+ outputs={"ViterbiPath": [viterbi_path]},
+ attrs={"is_test": self._is_test, })
+ return viterbi_path
+
+
+class Chunk_eval(fluid.dygraph.Layer):
+ def __init__(self,
+ num_chunk_types,
+ chunk_scheme,
+ excluded_chunk_types=None):
+ super(Chunk_eval, self).__init__()
+ self.num_chunk_types = num_chunk_types
+ self.chunk_scheme = chunk_scheme
+ self.excluded_chunk_types = excluded_chunk_types
+
+ def forward(self, input, label, seq_length=None):
+ precision = self._helper.create_variable_for_type_inference(
+ dtype="float32")
+ recall = self._helper.create_variable_for_type_inference(
+ dtype="float32")
+ f1_score = self._helper.create_variable_for_type_inference(
+ dtype="float32")
+ num_infer_chunks = self._helper.create_variable_for_type_inference(
+ dtype="int64")
+ num_label_chunks = self._helper.create_variable_for_type_inference(
+ dtype="int64")
+ num_correct_chunks = self._helper.create_variable_for_type_inference(
+ dtype="int64")
+
+ this_input = {"Inference": input, "Label": label[0]}
+ if seq_length:
+ this_input["SeqLength"] = seq_length[0]
+ self._helper.append_op(
+ type='chunk_eval',
+ inputs=this_input,
+ outputs={
+ "Precision": [precision],
+ "Recall": [recall],
+ "F1-Score": [f1_score],
+ "NumInferChunks": [num_infer_chunks],
+ "NumLabelChunks": [num_label_chunks],
+ "NumCorrectChunks": [num_correct_chunks]
+ },
+ attrs={
+ "num_chunk_types": self.num_chunk_types,
+ "chunk_scheme": self.chunk_scheme,
+ "excluded_chunk_types": self.excluded_chunk_types or []
+ })
+ return (num_infer_chunks, num_label_chunks, num_correct_chunks)
+
+
+class LAC(Model):
+ def __init__(self, args, vocab_size, num_labels, length=None):
+ super(LAC, self).__init__()
+ """
+ define the lexical analysis network structure
+ word: stores the input of the model
+ for_infer: a boolean value, indicating if the model to be created is for training or predicting.
+
+ return:
+ for infer: return the prediction
+ otherwise: return the prediction
+ """
+ self.word_emb_dim = args.word_emb_dim
+ self.vocab_size = vocab_size
+ self.num_labels = num_labels
+ self.grnn_hidden_dim = args.grnn_hidden_dim
+ self.emb_lr = args.emb_learning_rate if 'emb_learning_rate' in dir(
+ args) else 1.0
+        self.crf_lr = args.crf_learning_rate if 'crf_learning_rate' in dir(
+ args) else 1.0
+ self.bigru_num = args.bigru_num
+ self.init_bound = 0.1
+
+ self.word_embedding = Embedding(
+ size=[self.vocab_size, self.word_emb_dim],
+ dtype='float32',
+ param_attr=fluid.ParamAttr(
+ learning_rate=self.emb_lr,
+ name="word_emb",
+ initializer=fluid.initializer.Uniform(
+ low=-self.init_bound, high=self.init_bound)))
+
+ h_0 = fluid.layers.create_global_var(
+ shape=[args.batch_size, self.grnn_hidden_dim],
+ value=0.0,
+ dtype='float32',
+ persistable=True,
+ force_cpu=True,
+ name='h_0')
+
+ self.bigru_units = []
+ for i in range(self.bigru_num):
+ if i == 0:
+ self.bigru_units.append(
+ self.add_sublayer(
+ "bigru_units%d" % i,
+ BiGRU(
+ self.grnn_hidden_dim,
+ self.grnn_hidden_dim,
+ self.init_bound,
+ h_0=h_0)))
+ else:
+ self.bigru_units.append(
+ self.add_sublayer(
+ "bigru_units%d" % i,
+ BiGRU(
+ self.grnn_hidden_dim * 2,
+ self.grnn_hidden_dim,
+ self.init_bound,
+ h_0=h_0)))
+
+ self.fc = Linear(
+ input_dim=self.grnn_hidden_dim * 2,
+ output_dim=self.num_labels,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=-self.init_bound, high=self.init_bound),
+ regularizer=fluid.regularizer.L2DecayRegularizer(
+ regularization_coeff=1e-4)))
+
+ self.linear_chain_crf = Linear_chain_crf(
+ param_attr=fluid.ParamAttr(
+ name='linear_chain_crfw', learning_rate=self.crf_lr),
+ size=self.num_labels)
+
+ self.crf_decoding = Crf_decoding(
+ param_attr=fluid.ParamAttr(
+ name='crfw', learning_rate=self.crf_lr),
+ size=self.num_labels)
+
+ def forward(self, word, target, lengths):
+ """
+ Configure the network
+ """
+ word_embed = self.word_embedding(word)
+ input_feature = word_embed
+
+ for i in range(self.bigru_num):
+ bigru_output = self.bigru_units[i](input_feature)
+ input_feature = bigru_output
+
+ emission = self.fc(bigru_output)
+
+ crf_cost = self.linear_chain_crf(
+ input=emission, label=target, length=lengths)
+ avg_cost = fluid.layers.mean(x=crf_cost)
+ self.crf_decoding.weight = self.linear_chain_crf.weight
+ crf_decode = self.crf_decoding(input=emission, length=lengths)
+ return crf_decode, avg_cost, lengths
+
+
+class LacLoss(Loss):
+ def __init__(self):
+ super(LacLoss, self).__init__()
+ pass
+
+ def forward(self, outputs, labels):
+ avg_cost = outputs[1]
+ return avg_cost
+
+
+class ChunkEval(Metric):
+ def __init__(self, num_labels, name=None, *args, **kwargs):
+ super(ChunkEval, self).__init__(*args, **kwargs)
+ self._init_name(name)
+ self.chunk_eval = Chunk_eval(
+ int(math.ceil((num_labels - 1) / 2.0)), "IOB")
+ self.reset()
+
+ def add_metric_op(self, pred, label, *args, **kwargs):
+ crf_decode = pred[0]
+ lengths = pred[2]
+ (num_infer_chunks, num_label_chunks,
+ num_correct_chunks) = self.chunk_eval(
+ input=crf_decode, label=label, seq_length=lengths)
+ return [num_infer_chunks, num_label_chunks, num_correct_chunks]
+
+ def update(self, num_infer_chunks, num_label_chunks, num_correct_chunks,
+ *args, **kwargs):
+ self.infer_chunks_total += num_infer_chunks
+ self.label_chunks_total += num_label_chunks
+ self.correct_chunks_total += num_correct_chunks
+ precision = float(
+ num_correct_chunks) / num_infer_chunks if num_infer_chunks else 0
+ recall = float(
+ num_correct_chunks) / num_label_chunks if num_label_chunks else 0
+ f1_score = float(2 * precision * recall) / (
+ precision + recall) if num_correct_chunks else 0
+ return [precision, recall, f1_score]
+
+ def reset(self):
+ self.infer_chunks_total = 0
+ self.label_chunks_total = 0
+ self.correct_chunks_total = 0
+
+ def accumulate(self):
+ precision = float(
+ self.correct_chunks_total
+ ) / self.infer_chunks_total if self.infer_chunks_total else 0
+ recall = float(
+ self.correct_chunks_total
+ ) / self.label_chunks_total if self.label_chunks_total else 0
+ f1_score = float(2 * precision * recall) / (
+ precision + recall) if self.correct_chunks_total else 0
+ res = [precision, recall, f1_score]
+ return res
+
+ def _init_name(self, name):
+ name = name or 'chunk eval'
+ self._name = ['precision', 'recall', 'F1']
+
+ def name(self):
+ return self._name
+
+
+class LacDataset(object):
+ """
+ Load lexical analysis dataset
+ """
+
+ def __init__(self, args):
+ self.word_dict_path = args.word_dict_path
+ self.label_dict_path = args.label_dict_path
+ self.word_rep_dict_path = args.word_rep_dict_path
+ self._load_dict()
+
+ def _load_dict(self):
+ self.word2id_dict = self.load_kv_dict(
+ self.word_dict_path, reverse=True, value_func=np.int64)
+ self.id2word_dict = self.load_kv_dict(self.word_dict_path)
+ self.label2id_dict = self.load_kv_dict(
+ self.label_dict_path, reverse=True, value_func=np.int64)
+ self.id2label_dict = self.load_kv_dict(self.label_dict_path)
+ if self.word_rep_dict_path is None:
+ self.word_replace_dict = dict()
+ else:
+ self.word_replace_dict = self.load_kv_dict(self.word_rep_dict_path)
+
+ def load_kv_dict(self,
+ dict_path,
+ reverse=False,
+ delimiter="\t",
+ key_func=None,
+ value_func=None):
+ """
+ Load key-value dict from file
+ """
+ result_dict = {}
+ for line in io.open(dict_path, "r", encoding='utf8'):
+ terms = line.strip("\n").split(delimiter)
+ if len(terms) != 2:
+ continue
+ if reverse:
+ value, key = terms
+ else:
+ key, value = terms
+ if key in result_dict:
+ raise KeyError("key duplicated with [%s]" % (key))
+ if key_func:
+ key = key_func(key)
+ if value_func:
+ value = value_func(value)
+ result_dict[key] = value
+ return result_dict
+
+ @property
+ def vocab_size(self):
+ return len(self.word2id_dict.values())
+
+ @property
+ def num_labels(self):
+ return len(self.label2id_dict.values())
+
+ def get_num_examples(self, filename):
+ """num of line of file"""
+ return sum(1 for line in io.open(filename, "r", encoding='utf8'))
+
+ def word_to_ids(self, words):
+ """convert word to word index"""
+ word_ids = []
+ for word in words:
+ word = self.word_replace_dict.get(word, word)
+ if word not in self.word2id_dict:
+ word = "OOV"
+ word_id = self.word2id_dict[word]
+ word_ids.append(word_id)
+
+ return word_ids
+
+ def label_to_ids(self, labels):
+ """convert label to label index"""
+ label_ids = []
+ for label in labels:
+ if label not in self.label2id_dict:
+ label = "O"
+ label_id = self.label2id_dict[label]
+ label_ids.append(label_id)
+ return label_ids
+
+ def file_reader(self,
+ filename,
+ mode="train",
+ batch_size=32,
+ max_seq_len=126):
+ """
+ yield (word_idx, target_idx) one by one from file,
+ or yield (word_idx, ) in `infer` mode
+ """
+
+ def wrapper():
+ fread = io.open(filename, "r", encoding="utf-8")
+ headline = next(fread)
+ headline = headline.strip().split('\t')
+ assert len(headline) == 2 and headline[0] == "text_a" and headline[
+ 1] == "label"
+ buf = []
+ for line in fread:
+ words, labels = line.strip("\n").split("\t")
+ if len(words) < 1:
+ continue
+ word_ids = self.word_to_ids(words.split("\002"))
+ label_ids = self.label_to_ids(labels.split("\002"))
+ assert len(word_ids) == len(label_ids)
+ word_ids = word_ids[0:max_seq_len]
+ words_len = np.int64(len(word_ids))
+ word_ids += [0 for _ in range(max_seq_len - words_len)]
+ label_ids = label_ids[0:max_seq_len]
+ label_ids += [0 for _ in range(max_seq_len - words_len)]
+ assert len(word_ids) == len(label_ids)
+ yield word_ids, label_ids, words_len
+ fread.close()
+
+ return wrapper
+
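+# Expected data file format (sketch): a header line "text_a<TAB>label", then one
+# sample per line with its tokens and tags joined by the "\002" separator, e.g.
+# (tag names below are placeholders):
+#
+#     text_a<TAB>label
+#     word1\002word2\002word3<TAB>TAG1\002TAG2\002TAG3
+#
+# file_reader() maps words/labels to ids, truncates/pads every sample to
+# max_seq_len, and also yields the true sequence length so that the CRF and
+# ChunkEval can ignore the padding positions.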
+
+def create_lexnet_data_generator(args, reader, file_name, place, mode="train"):
+ def wrapper():
+ batch_words, batch_labels, seq_lens = [], [], []
+        for epoch in range(args.epoch):
+ for instance in reader.file_reader(
+ file_name, mode, max_seq_len=args.max_seq_len)():
+ words, labels, words_len = instance
+ if len(seq_lens) < args.batch_size:
+ batch_words.append(words)
+ batch_labels.append(labels)
+ seq_lens.append(words_len)
+ if len(seq_lens) == args.batch_size:
+ yield batch_words, batch_labels, seq_lens, batch_labels
+ batch_words, batch_labels, seq_lens = [], [], []
+
+ if len(seq_lens) > 0:
+ yield batch_words, batch_labels, seq_lens, batch_labels
+ batch_words, batch_labels, seq_lens = [], [], []
+
+ return wrapper
+
+
+def create_dataloader(generator, place, feed_list=None):
+ if not feed_list:
+ data_loader = fluid.io.DataLoader.from_generator(
+ capacity=50,
+ use_double_buffer=True,
+ iterable=True,
+ return_list=True)
+ else:
+ data_loader = fluid.io.DataLoader.from_generator(
+ feed_list=feed_list,
+ capacity=50,
+ use_double_buffer=True,
+ iterable=True,
+ return_list=True)
+ data_loader.set_batch_generator(generator, places=place)
+ return data_loader
+
+
+def main(args):
+ place = set_device(args.device)
+ fluid.enable_dygraph(place) if args.dynamic else None
+
+ inputs = [
+ Input(
+ [None, args.max_seq_len], 'int64', name='words'), Input(
+ [None, args.max_seq_len], 'int64', name='target'), Input(
+ [None], 'int64', name='length')
+ ]
+ labels = [Input([None, args.max_seq_len], 'int64', name='labels')]
+
+ feed = [x.forward() for x in inputs + labels]
+ dataset = LacDataset(args)
+ train_path = os.path.join(args.data, "train.tsv")
+ test_path = os.path.join(args.data, "test.tsv")
+
+ if args.dynamic:
+ feed_list = None
+ else:
+ feed_list = feed
+ train_generator = create_lexnet_data_generator(
+ args, reader=dataset, file_name=train_path, place=place, mode="train")
+ test_generator = create_lexnet_data_generator(
+ args, reader=dataset, file_name=test_path, place=place, mode="test")
+
+ train_dataset = create_dataloader(
+ train_generator, place, feed_list=feed_list)
+ test_dataset = create_dataloader(
+ test_generator, place, feed_list=feed_list)
+
+ vocab_size = dataset.vocab_size
+ num_labels = dataset.num_labels
+ model = LAC(args, vocab_size, num_labels)
+
+ optim = AdamOptimizer(
+ learning_rate=args.base_learning_rate,
+ parameter_list=model.parameters())
+
+ model.prepare(
+ optim,
+ LacLoss(),
+ ChunkEval(num_labels),
+ inputs=inputs,
+ labels=labels,
+ device=args.device)
+
+ if args.resume is not None:
+ model.load(args.resume)
+
+ model.fit(train_dataset,
+ test_dataset,
+ epochs=args.epoch,
+ batch_size=args.batch_size,
+ eval_freq=args.eval_freq,
+ save_freq=args.save_freq,
+ save_dir=args.save_dir)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("LAC training")
+ parser.add_argument(
+ "-dir", "--data", default=None, type=str, help='path to LAC dataset')
+ parser.add_argument(
+ "-wd",
+ "--word_dict_path",
+ default=None,
+ type=str,
+ help='word dict path')
+ parser.add_argument(
+ "-ld",
+ "--label_dict_path",
+ default=None,
+ type=str,
+ help='label dict path')
+ parser.add_argument(
+ "-wrd",
+ "--word_rep_dict_path",
+ default=None,
+ type=str,
+ help='The path of the word replacement Dictionary.')
+ parser.add_argument(
+ "-dev",
+ "--device",
+ type=str,
+ default='gpu',
+ help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "-e", "--epoch", default=10, type=int, help="number of epoch")
+ parser.add_argument(
+ '-lr',
+ '--base_learning_rate',
+ default=1e-3,
+ type=float,
+ metavar='LR',
+ help='initial learning rate')
+ parser.add_argument(
+ "--word_emb_dim",
+ default=128,
+ type=int,
+ help='word embedding dimension')
+ parser.add_argument(
+ "--grnn_hidden_dim", default=128, type=int, help="hidden dimension")
+ parser.add_argument(
+ "--bigru_num", default=2, type=int, help='the number of bi-rnn')
+ parser.add_argument("-elr", "--emb_learning_rate", default=1.0, type=float)
+ parser.add_argument("-clr", "--crf_learning_rate", default=1.0, type=float)
+ parser.add_argument(
+ "-b", "--batch_size", default=300, type=int, help="batch size")
+ parser.add_argument(
+ "--max_seq_len", default=126, type=int, help="max sequence length")
+ parser.add_argument(
+ "-n", "--num_devices", default=1, type=int, help="number of devices")
+ parser.add_argument(
+ "-r",
+ "--resume",
+ default=None,
+ type=str,
+ help="checkpoint path to resume")
+ parser.add_argument(
+ "-o",
+ "--save_dir",
+ default="./model",
+ type=str,
+ help="save model path")
+ parser.add_argument(
+ "-sf", "--save_freq", default=1, type=int, help="save frequency")
+ parser.add_argument(
+ "-ef", "--eval_freq", default=1, type=int, help="eval frequency")
+
+ args = parser.parse_args()
+ print(args)
+ main(args)
diff --git a/model.py b/model.py
index dea21bb98329404d02c10e2a563f21d76f7851e1..6fecbf1d29fa3c37ad3073fae0fcdcd819b52937 100644
--- a/model.py
+++ b/model.py
@@ -42,6 +42,14 @@ __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device']
def set_device(device):
+ """
+ Args:
+ device (str): specify device type, 'cpu' or 'gpu'.
+
+ Returns:
+ fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place.
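+
+    Examples (usage sketch):
+
+        place = set_device('gpu')  # a fluid.CUDAPlace on GPU machines
+        fluid.enable_dygraph(place)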
+ """
+
assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
"Expected device in ['cpu', 'gpu'], but got {}".format(device)
@@ -114,9 +122,9 @@ class Loss(object):
def forward(self, outputs, labels):
raise NotImplementedError()
- def __call__(self, outputs, labels):
+ def __call__(self, outputs, labels=None):
labels = to_list(labels)
- if in_dygraph_mode():
+ if in_dygraph_mode() and labels:
labels = [to_variable(l) for l in labels]
losses = to_list(self.forward(to_list(outputs), labels))
if self.average:
@@ -410,7 +418,8 @@ class StaticGraphAdapter(object):
and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
- self.model._optimizer._learning_rate_map[prog] = lr_var
+ new_lr_var = prog.global_block().vars[lr_var.name]
+ self.model._optimizer._learning_rate_map[prog] = new_lr_var
losses = []
metrics = []
@@ -852,8 +861,6 @@ class Model(fluid.dygraph.Layer):
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
- if loss_function and not isinstance(labels, (list, Input)):
- raise TypeError("'labels' must be list in static graph mode")
metrics = metrics or []
for metric in to_list(metrics):
@@ -1083,7 +1090,11 @@ class Model(fluid.dygraph.Layer):
return eval_result
- def predict(self, test_data, batch_size=1, num_workers=0):
+ def predict(self,
+ test_data,
+ batch_size=1,
+ num_workers=0,
+ stack_outputs=True):
"""
FIXME: add more comments and usage
Args:
@@ -1096,6 +1107,12 @@ class Model(fluid.dygraph.Layer):
num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
+            stack_outputs (bool): whether to stack each output field like a batch. If an
+                output field of a sample has shape [X, Y] and test_data contains N samples,
+                the predicted output field will have shape [N, X, Y] when stack_outputs is
+                True, and will be a length-N list of [X, Y] arrays when stack_outputs is
+                False. stack_outputs=False is meant for LoDTensor outputs; setting it to
+                True is recommended when the outputs contain no LoDTensor. Default: True
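+
+            A usage sketch (illustrative):
+
+                preds = model.predict(test_data, batch_size=4, stack_outputs=False)
+                # preds is a list over output fields; with stack_outputs=False each
+                # field stays a Python list with one entry per batch, which is the
+                # safe choice when an output is a LoDTensor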
"""
if fluid.in_dygraph_mode():
@@ -1122,19 +1139,16 @@ class Model(fluid.dygraph.Layer):
if not isinstance(test_loader, Iterable):
loader = test_loader()
- outputs = None
+ outputs = []
for data in tqdm.tqdm(loader):
- if not fluid.in_dygraph_mode():
- data = data[0]
-
- outs = self.test(*data)
+ data = flatten(data)
+ outputs.append(self.test(data[:len(self._inputs)]))
- if outputs is None:
- outputs = outs
- else:
- outputs = [
- np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
- ]
+            # NOTE: for LoDTensor outputs, we should not stack them,
+            # since stacking may lose their detail (LoD) info
+ outputs = list(zip(*outputs))
+ if stack_outputs:
+ outputs = [np.stack(outs, axis=0) for outs in outputs]
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ad506c20a7395108bb1999806d8667fbb074dd
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from . import resnet
+from . import vgg
+from . import mobilenetv1
+from . import mobilenetv2
+from . import darknet
+from . import yolov3
+from . import tsm
+
+from .resnet import *
+from .mobilenetv1 import *
+from .mobilenetv2 import *
+from .vgg import *
+from .darknet import *
+from .yolov3 import *
+from .tsm import *
+
+__all__ = resnet.__all__ \
+ + vgg.__all__ \
+ + mobilenetv1.__all__ \
+ + mobilenetv2.__all__ \
+ + darknet.__all__ \
+ + yolov3.__all__ \
+ + tsm.__all__
diff --git a/models/darknet.py b/models/darknet.py
new file mode 100755
index 0000000000000000000000000000000000000000..095cf7d63c628483b3b0842f4c54d81bba75ceb6
--- /dev/null
+++ b/models/darknet.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.regularizer import L2Decay
+
+from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = ['DarkNet53', 'ConvBNLayer', 'darknet53']
+
+# {num_layers: (url, md5)}
+pretrain_infos = {
+ 53: ('https://paddlemodels.bj.bcebos.com/hapi/darknet53.pdparams',
+ '2506357a5c31e865785112fc614a487d')
+}
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ ch_in,
+ ch_out,
+ filter_size=3,
+ stride=1,
+ groups=1,
+ padding=0,
+ act="leaky"):
+ super(ConvBNLayer, self).__init__()
+
+ self.conv = Conv2D(
+ num_channels=ch_in,
+ num_filters=ch_out,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.Normal(0., 0.02)),
+ bias_attr=False,
+ act=None)
+ self.batch_norm = BatchNorm(
+ num_channels=ch_out,
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.Normal(0., 0.02),
+ regularizer=L2Decay(0.)),
+ bias_attr=ParamAttr(
+ initializer=fluid.initializer.Constant(0.0),
+ regularizer=L2Decay(0.)))
+
+ self.act = act
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.batch_norm(out)
+ if self.act == 'leaky':
+ out = fluid.layers.leaky_relu(x=out, alpha=0.1)
+ return out
+
+class DownSample(fluid.dygraph.Layer):
+ def __init__(self,
+ ch_in,
+ ch_out,
+ filter_size=3,
+ stride=2,
+ padding=1):
+
+ super(DownSample, self).__init__()
+
+ self.conv_bn_layer = ConvBNLayer(
+ ch_in=ch_in,
+ ch_out=ch_out,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding)
+        self.ch_out = ch_out
+
+    def forward(self, inputs):
+ out = self.conv_bn_layer(inputs)
+ return out
+
+class BasicBlock(fluid.dygraph.Layer):
+ def __init__(self, ch_in, ch_out):
+ super(BasicBlock, self).__init__()
+
+ self.conv1 = ConvBNLayer(
+ ch_in=ch_in,
+ ch_out=ch_out,
+ filter_size=1,
+ stride=1,
+ padding=0)
+ self.conv2 = ConvBNLayer(
+ ch_in=ch_out,
+ ch_out=ch_out*2,
+ filter_size=3,
+ stride=1,
+            padding=1)
+
+    def forward(self, inputs):
+ conv1 = self.conv1(inputs)
+ conv2 = self.conv2(conv1)
+ out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None)
+ return out
+
+class LayerWarp(fluid.dygraph.Layer):
+ def __init__(self, ch_in, ch_out, count):
+        super(LayerWarp, self).__init__()
+
+ self.basicblock0 = BasicBlock(ch_in, ch_out)
+ self.res_out_list = []
+        for i in range(1, count):
+ res_out = self.add_sublayer("basic_block_%d" % (i),
+ BasicBlock(
+ ch_out*2,
+ ch_out))
+ self.res_out_list.append(res_out)
+        self.ch_out = ch_out
+
+    def forward(self, inputs):
+ y = self.basicblock0(inputs)
+ for basic_block_i in self.res_out_list:
+ y = basic_block_i(y)
+ return y
+
+
+DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
+
+
+class DarkNet53(Model):
+ def __init__(self, num_layers=53, ch_in=3):
+ super(DarkNet53, self).__init__()
+ assert num_layers in DarkNet_cfg.keys(), \
+ "only support num_layers in {} currently" \
+ .format(DarkNet_cfg.keys())
+ self.stages = DarkNet_cfg[num_layers]
+ self.stages = self.stages[0:5]
+
+ self.conv0 = ConvBNLayer(
+ ch_in=ch_in,
+ ch_out=32,
+ filter_size=3,
+ stride=1,
+ padding=1)
+
+ self.downsample0 = DownSample(
+ ch_in=32,
+ ch_out=32 * 2)
+ self.darknet53_conv_block_list = []
+ self.downsample_list = []
+        ch_in = [64, 128, 256, 512, 1024]
+ for i, stage in enumerate(self.stages):
+ conv_block = self.add_sublayer(
+ "stage_%d" % (i),
+ LayerWarp(
+ int(ch_in[i]),
+ 32*(2**i),
+ stage))
+ self.darknet53_conv_block_list.append(conv_block)
+ for i in range(len(self.stages) - 1):
+ downsample = self.add_sublayer(
+ "stage_%d_downsample" % i,
+ DownSample(
+ ch_in = 32*(2**(i+1)),
+ ch_out = 32*(2**(i+2))))
+ self.downsample_list.append(downsample)
+
+    def forward(self, inputs):
+
+ out = self.conv0(inputs)
+ out = self.downsample0(out)
+ blocks = []
+ for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
+ out = conv_block_i(out)
+ blocks.append(out)
+ if i < len(self.stages) - 1:
+ out = self.downsample_list[i](out)
+ return blocks[-1:-4:-1]
+
+
+def _darknet(num_layers=53, input_channels=3, pretrained=True):
+ model = DarkNet53(num_layers, input_channels)
+ if pretrained:
+        assert num_layers in pretrain_infos.keys(), \
+        "DarkNet{} does not have pretrained weights yet, " \
+        "pretrained should be set to False".format(num_layers)
+ weight_path = get_weights_path(*(pretrain_infos[num_layers]))
+ assert weight_path.endswith('.pdparams'), \
+ "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+ return model
+
+
+def darknet53(input_channels=3, pretrained=True):
+ return _darknet(53, input_channels, pretrained)
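+
+
+# Usage sketch (illustrative): darknet53() is typically consumed as a backbone,
+# e.g. in models/yolov3.py:
+#
+#     backbone = darknet53(pretrained=False)
+#     c5, c4, c3 = backbone(image)  # three feature maps, coarsest first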
diff --git a/models/download.py b/models/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d3fba390647c494448b83295901a8973d2aba8
--- /dev/null
+++ b/models/download.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import os.path as osp
+import shutil
+import requests
+import tqdm
+import hashlib
+import time
+
+from paddle.fluid.dygraph.parallel import ParallelEnv
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['get_weights_path']
+
+WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
+
+DOWNLOAD_RETRY_LIMIT = 3
+
+
+def get_weights_path(url, md5sum=None):
+    """Get the weights path under WEIGHTS_HOME; if the file does not exist
+    there, download it from url first.
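+
+    Example (illustrative; the URL and md5 below are the resnet50 entry
+    from models/resnet.py):
+
+        path = get_weights_path(
+            'https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
+            '0884c9087266496c41c60d14a96f8530')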
+ """
+ path, _ = get_path(url, WEIGHTS_HOME, md5sum)
+ return path
+
+
+def map_path(url, root_dir):
+ # parse path after download under root_dir
+ fname = osp.split(url)[-1]
+ fpath = fname
+ return osp.join(root_dir, fpath)
+
+
+def get_path(url, root_dir, md5sum=None, check_exist=True):
+    """ Download from given url to root_dir.
+    If the file or directory specified by url already exists under
+    root_dir, return the path directly; otherwise download it
+    from url and return the path.
+
+ url (str): download url
+ root_dir (str): root dir for downloading, it should be
+ WEIGHTS_HOME or DATASET_HOME
+ md5sum (str): md5 sum of download package
+ """
+ # parse path after download to decompress under root_dir
+ fullpath = map_path(url, root_dir)
+
+ exist_flag = False
+ if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
+ exist_flag = True
+ if ParallelEnv().local_rank == 0:
+ logger.info("Found {}".format(fullpath))
+ else:
+ if ParallelEnv().local_rank == 0:
+ fullpath = _download(url, root_dir, md5sum)
+ else:
+ while not os.path.exists(fullpath):
+ time.sleep(1)
+ return fullpath, exist_flag
+
+
+def _download(url, path, md5sum=None):
+ """
+ Download from url, save to path.
+
+ url (str): download url
+ path (str): download to given path
+ """
+ if not osp.exists(path):
+ os.makedirs(path)
+
+ fname = osp.split(url)[-1]
+ fullname = osp.join(path, fname)
+ retry_cnt = 0
+
+ while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
+ if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+ retry_cnt += 1
+ else:
+ raise RuntimeError("Download from {} failed. "
+ "Retry limit reached".format(url))
+ if ParallelEnv().local_rank == 0:
+ logger.info("Downloading {} from {}".format(fname, url))
+
+ req = requests.get(url, stream=True)
+ if req.status_code != 200:
+ raise RuntimeError("Downloading from {} failed with code "
+ "{}!".format(url, req.status_code))
+
+        # To guard against interrupted downloads, download to
+        # tmp_fullname first, then move tmp_fullname to fullname
+        # once the download has finished
+ tmp_fullname = fullname + "_tmp"
+ total_size = req.headers.get('content-length')
+ with open(tmp_fullname, 'wb') as f:
+ if total_size:
+ for chunk in tqdm.tqdm(
+ req.iter_content(chunk_size=1024),
+ total=(int(total_size) + 1023) // 1024,
+ unit='KB'):
+ f.write(chunk)
+ else:
+ for chunk in req.iter_content(chunk_size=1024):
+ if chunk:
+ f.write(chunk)
+ shutil.move(tmp_fullname, fullname)
+
+ return fullname
+
+
+def _md5check(fullname, md5sum=None):
+ if md5sum is None:
+ return True
+ if ParallelEnv().local_rank == 0:
+ logger.info("File {} md5 checking...".format(fullname))
+ md5 = hashlib.md5()
+ with open(fullname, 'rb') as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ md5.update(chunk)
+ calc_md5sum = md5.hexdigest()
+
+ if calc_md5sum != md5sum:
+ if ParallelEnv().local_rank == 0:
+ logger.info("File {} md5 check failed, {}(calc) != "
+ "{}(base)".format(fullname, calc_md5sum, md5sum))
+ return False
+ return True
diff --git a/models/mobilenetv1.py b/models/mobilenetv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e7959b1b9bf78e30ac80f874262234f66ff22e
--- /dev/null
+++ b/models/mobilenetv1.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = ['MobileNetV1', 'mobilenet_v1']
+
+model_urls = {
+ 'mobilenetv1_1.0':
+ ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams',
+ 'bf0d25cb0bed1114d9dac9384ce2b4a6')
+}
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ act='relu',
+ use_cudnn=True,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+
+ self._conv = Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ act=None,
+ use_cudnn=use_cudnn,
+ param_attr=ParamAttr(
+ initializer=MSRA(), name=self.full_name() + "_weights"),
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(
+ num_filters,
+ act=act,
+ param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
+ bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
+ moving_mean_name=self.full_name() + "_bn" + '_mean',
+ moving_variance_name=self.full_name() + "_bn" + '_variance')
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ return y
+
+
+class DepthwiseSeparable(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters1,
+ num_filters2,
+ num_groups,
+ stride,
+ scale,
+ name=None):
+ super(DepthwiseSeparable, self).__init__()
+
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=3,
+ stride=stride,
+ padding=1,
+ num_groups=int(num_groups * scale),
+ use_cudnn=False)
+
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0)
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1(Model):
+ """MobileNetV1 model from
+    `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" <https://arxiv.org/abs/1704.04861>`_.
+
+ Args:
+ scale (float): scale of channels in each layer. Default: 1.0.
+ class_dim (int): output dim of last fc layer. Default: 1000.
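+
+    Examples:
+        .. code-block:: python
+
+            # usage sketch (assumes dygraph mode or a prepared hapi pipeline)
+            model = MobileNetV1(scale=1.0, class_dim=1000)
+            # with optional ImageNet weights via the helper below:
+            # model = mobilenet_v1(pretrained=True)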
+ """
+
+ def __init__(self, scale=1.0, class_dim=1000):
+ super(MobileNetV1, self).__init__()
+ self.scale = scale
+ self.dwsl = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=3,
+ filter_size=3,
+ channels=3,
+ num_filters=int(32 * scale),
+ stride=2,
+ padding=1)
+
+ dws21 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(32 * scale),
+ num_filters1=32,
+ num_filters2=64,
+ num_groups=32,
+ stride=1,
+ scale=scale),
+ name="conv2_1")
+ self.dwsl.append(dws21)
+
+ dws22 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(64 * scale),
+ num_filters1=64,
+ num_filters2=128,
+ num_groups=64,
+ stride=2,
+ scale=scale),
+ name="conv2_2")
+ self.dwsl.append(dws22)
+
+ dws31 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=128,
+ num_groups=128,
+ stride=1,
+ scale=scale),
+ name="conv3_1")
+ self.dwsl.append(dws31)
+
+ dws32 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=2,
+ scale=scale),
+ name="conv3_2")
+ self.dwsl.append(dws32)
+
+ dws41 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=256,
+ num_groups=256,
+ stride=1,
+ scale=scale),
+ name="conv4_1")
+ self.dwsl.append(dws41)
+
+ dws42 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=2,
+ scale=scale),
+ name="conv4_2")
+ self.dwsl.append(dws42)
+
+ for i in range(5):
+ tmp = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ scale=scale),
+ name="conv5_" + str(i + 1))
+ self.dwsl.append(tmp)
+
+ dws56 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=2,
+ scale=scale),
+ name="conv5_6")
+ self.dwsl.append(dws56)
+
+ dws6 = self.add_sublayer(
+ sublayer=DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=1,
+ scale=scale),
+ name="conv6")
+ self.dwsl.append(dws6)
+
+ self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
+
+ self.out = Linear(
+ int(1024 * scale),
+ class_dim,
+ act='softmax',
+ param_attr=ParamAttr(
+ initializer=MSRA(), name=self.full_name() + "fc7_weights"),
+ bias_attr=ParamAttr(name="fc7_offset"))
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ for dws in self.dwsl:
+ y = dws(y)
+ y = self.pool2d_avg(y)
+ y = fluid.layers.reshape(y, shape=[-1, 1024])
+ y = self.out(y)
+ return y
+
+
+def _mobilenet(arch, pretrained=False, **kwargs):
+ model = MobileNetV1(**kwargs)
+ if pretrained:
+        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
+ arch)
+ weight_path = get_weights_path(model_urls[arch][0],
+ model_urls[arch][1])
+ assert weight_path.endswith(
+ '.pdparams'), "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+
+ return model
+
+
+def mobilenet_v1(pretrained=False, scale=1.0):
+ model = _mobilenet('mobilenetv1_' + str(scale), pretrained, scale=scale)
+ return model
diff --git a/models/mobilenetv2.py b/models/mobilenetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0079ee79d932a76dc75548b7641526bc80019011
--- /dev/null
+++ b/models/mobilenetv2.py
@@ -0,0 +1,252 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = ['MobileNetV2', 'mobilenet_v2']
+
+model_urls = {
+ 'mobilenetv2_1.0':
+ ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
+ '8ff74f291f72533f2a7956a4efff9d88')
+}
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ filter_size,
+ num_filters,
+ stride,
+ padding,
+ channels=None,
+ num_groups=1,
+ use_cudnn=True):
+ super(ConvBNLayer, self).__init__()
+
+ tmp_param = ParamAttr(name=self.full_name() + "_weights")
+ self._conv = Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ act=None,
+ use_cudnn=use_cudnn,
+ param_attr=tmp_param,
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(
+ num_filters,
+ param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
+ bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
+ moving_mean_name=self.full_name() + "_bn" + '_mean',
+ moving_variance_name=self.full_name() + "_bn" + '_variance')
+
+ def forward(self, inputs, if_act=True):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ if if_act:
+ y = fluid.layers.relu6(y)
+ return y
+
+
+class InvertedResidualUnit(fluid.dygraph.Layer):
+ def __init__(
+ self,
+ num_channels,
+ num_in_filter,
+ num_filters,
+ stride,
+ filter_size,
+ padding,
+ expansion_factor, ):
+ super(InvertedResidualUnit, self).__init__()
+ num_expfilter = int(round(num_in_filter * expansion_factor))
+ self._expand_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_expfilter,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ num_groups=1)
+
+ self._bottleneck_conv = ConvBNLayer(
+ num_channels=num_expfilter,
+ num_filters=num_expfilter,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ num_groups=num_expfilter,
+ use_cudnn=False)
+
+ self._linear_conv = ConvBNLayer(
+ num_channels=num_expfilter,
+ num_filters=num_filters,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ num_groups=1)
+
+ def forward(self, inputs, ifshortcut):
+ y = self._expand_conv(inputs, if_act=True)
+ y = self._bottleneck_conv(y, if_act=True)
+ y = self._linear_conv(y, if_act=False)
+ if ifshortcut:
+ y = fluid.layers.elementwise_add(inputs, y)
+ return y
+
+
+class InvresiBlocks(fluid.dygraph.Layer):
+ def __init__(self, in_c, t, c, n, s):
+ super(InvresiBlocks, self).__init__()
+
+ self._first_block = InvertedResidualUnit(
+ num_channels=in_c,
+ num_in_filter=in_c,
+ num_filters=c,
+ stride=s,
+ filter_size=3,
+ padding=1,
+ expansion_factor=t)
+
+ self._inv_blocks = []
+ for i in range(1, n):
+ tmp = self.add_sublayer(
+ sublayer=InvertedResidualUnit(
+ num_channels=c,
+ num_in_filter=c,
+ num_filters=c,
+ stride=1,
+ filter_size=3,
+ padding=1,
+ expansion_factor=t),
+ name=self.full_name() + "_" + str(i + 1))
+ self._inv_blocks.append(tmp)
+
+ def forward(self, inputs):
+ y = self._first_block(inputs, ifshortcut=False)
+ for inv_block in self._inv_blocks:
+ y = inv_block(y, ifshortcut=True)
+ return y
+
+
+class MobileNetV2(Model):
+ """MobileNetV2 model from
+    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
+
+ Args:
+ scale (float): scale of channels in each layer. Default: 1.0.
+ class_dim (int): output dim of last fc layer. Default: 1000.
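+
+    Examples:
+        .. code-block:: python
+
+            # usage sketch (assumes dygraph mode or a prepared hapi pipeline)
+            model = MobileNetV2(scale=1.0, class_dim=1000)
+            # with optional ImageNet weights via the helper below:
+            # model = mobilenet_v2(pretrained=True)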
+ """
+
+ def __init__(self, scale=1.0, class_dim=1000):
+ super(MobileNetV2, self).__init__()
+ self.scale = scale
+ self.class_dim = class_dim
+
+ bottleneck_params_list = [
+ (1, 16, 1, 1),
+ (6, 24, 2, 2),
+ (6, 32, 3, 2),
+ (6, 64, 4, 2),
+ (6, 96, 3, 1),
+ (6, 160, 3, 2),
+ (6, 320, 1, 1),
+ ]
+
+ #1. conv1
+ self._conv1 = ConvBNLayer(
+ num_channels=3,
+ num_filters=int(32 * scale),
+ filter_size=3,
+ stride=2,
+ padding=1)
+
+ #2. bottleneck sequences
+ self._invl = []
+ i = 1
+ in_c = int(32 * scale)
+ for layer_setting in bottleneck_params_list:
+ t, c, n, s = layer_setting
+ i += 1
+ tmp = self.add_sublayer(
+ sublayer=InvresiBlocks(
+ in_c=in_c, t=t, c=int(c * scale), n=n, s=s),
+ name='conv' + str(i))
+ self._invl.append(tmp)
+ in_c = int(c * scale)
+
+ #3. last_conv
+ self._out_c = int(1280 * scale) if scale > 1.0 else 1280
+ self._conv9 = ConvBNLayer(
+ num_channels=in_c,
+ num_filters=self._out_c,
+ filter_size=1,
+ stride=1,
+ padding=0)
+
+ #4. pool
+ self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
+
+ #5. fc
+ tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
+ self._fc = Linear(
+ self._out_c,
+ class_dim,
+ act='softmax',
+ param_attr=tmp_param,
+ bias_attr=ParamAttr(name="fc10_offset"))
+
+ def forward(self, inputs):
+ y = self._conv1(inputs, if_act=True)
+ for inv in self._invl:
+ y = inv(y)
+ y = self._conv9(y, if_act=True)
+ y = self._pool2d_avg(y)
+ y = fluid.layers.reshape(y, shape=[-1, self._out_c])
+ y = self._fc(y)
+ return y
+
+
+def _mobilenet(arch, pretrained=False, **kwargs):
+ model = MobileNetV2(**kwargs)
+ if pretrained:
+        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
+ arch)
+ weight_path = get_weights_path(model_urls[arch][0],
+ model_urls[arch][1])
+ assert weight_path.endswith(
+ '.pdparams'), "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+
+ return model
+
+
+def mobilenet_v2(pretrained=False, scale=1.0):
+ """MobileNetV2
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = _mobilenet('mobilenetv2_' + str(scale), pretrained, scale=scale)
+ return model
diff --git a/models/resnet.py b/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2cf4b603e6890510e9fafb65bcb96ab52cd2771
--- /dev/null
+++ b/models/resnet.py
@@ -0,0 +1,293 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle.fluid as fluid
+
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+from paddle.fluid.dygraph.container import Sequential
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = [
+ 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
+]
+
+model_urls = {
+ 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
+ '0884c9087266496c41c60d14a96f8530')
+}
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ stride=1,
+ groups=1,
+ act=None):
+ super(ConvBNLayer, self).__init__()
+
+ self._conv = Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=groups,
+ act=None,
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(num_filters, act=act)
+
+ def forward(self, inputs):
+ x = self._conv(inputs)
+ x = self._batch_norm(x)
+
+ return x
+
+
+class BasicBlock(fluid.dygraph.Layer):
+
+ expansion = 1
+
+ def __init__(self, num_channels, num_filters, stride, shortcut=True):
+ super(BasicBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=3,
+ act='relu')
+ self.conv1 = ConvBNLayer(
+ num_channels=num_filters,
+ num_filters=num_filters,
+ filter_size=3,
+ stride=stride,
+ act='relu')
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=1,
+ stride=stride)
+
+ self.shortcut = shortcut
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+
+ y = short + conv1
+
+ return fluid.layers.relu(y)
+
+
+class BottleneckBlock(fluid.dygraph.Layer):
+
+ expansion = 4
+
+ def __init__(self, num_channels, num_filters, stride, shortcut=True):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=1,
+ act='relu')
+ self.conv1 = ConvBNLayer(
+ num_channels=num_filters,
+ num_filters=num_filters,
+ filter_size=3,
+ stride=stride,
+ act='relu')
+ self.conv2 = ConvBNLayer(
+ num_channels=num_filters,
+ num_filters=num_filters * self.expansion,
+ filter_size=1,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters * self.expansion,
+ filter_size=1,
+ stride=stride)
+
+ self.shortcut = shortcut
+
+ self._num_channels_out = num_filters * self.expansion
+
+ def forward(self, inputs):
+ x = self.conv0(inputs)
+ conv1 = self.conv1(x)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+
+ x = fluid.layers.elementwise_add(x=short, y=conv2)
+
+ return fluid.layers.relu(x)
+
+
+class ResNet(Model):
+ """ResNet model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
+
+ Args:
+ Block (BasicBlock|BottleneckBlock): block module of model.
+ depth (int): layers of resnet, default: 50.
+ num_classes (int): output dim of last fc layer, default: 1000.
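+
+    Examples:
+        .. code-block:: python
+
+            # usage sketch; the resnet50()/resnet101()/... helpers below pick
+            # the matching block type and depth for you
+            model = ResNet(BottleneckBlock, depth=50, num_classes=1000)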
+ """
+
+ def __init__(self, Block, depth=50, num_classes=1000):
+ super(ResNet, self).__init__()
+
+ layer_config = {
+ 18: [2, 2, 2, 2],
+ 34: [3, 4, 6, 3],
+ 50: [3, 4, 6, 3],
+ 101: [3, 4, 23, 3],
+ 152: [3, 8, 36, 3],
+ }
+ assert depth in layer_config.keys(), \
+ "supported depth are {} but input layer is {}".format(
+ layer_config.keys(), depth)
+
+ layers = layer_config[depth]
+
+ in_channels = 64
+ out_channels = [64, 128, 256, 512]
+
+ self.conv = ConvBNLayer(
+ num_channels=3,
+ num_filters=64,
+ filter_size=7,
+ stride=2,
+ act='relu')
+ self.pool = Pool2D(
+ pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+
+ self.layers = []
+ for idx, num_blocks in enumerate(layers):
+ blocks = []
+ shortcut = False
+ for b in range(num_blocks):
+ if b == 1:
+ in_channels = out_channels[idx] * Block.expansion
+ block = Block(
+ num_channels=in_channels,
+ num_filters=out_channels[idx],
+ stride=2 if b == 0 and idx != 0 else 1,
+ shortcut=shortcut)
+ blocks.append(block)
+ shortcut = True
+ layer = self.add_sublayer("layer_{}".format(idx),
+ Sequential(*blocks))
+ self.layers.append(layer)
+
+ self.global_pool = Pool2D(
+ pool_size=7, pool_type='avg', global_pooling=True)
+
+ stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
+ self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
+ self.fc = Linear(
+ self.fc_input_dim,
+ num_classes,
+ act='softmax',
+ param_attr=fluid.param_attr.ParamAttr(
+ initializer=fluid.initializer.Uniform(-stdv, stdv)))
+
+ def forward(self, inputs):
+ x = self.conv(inputs)
+ x = self.pool(x)
+ for layer in self.layers:
+ x = layer(x)
+ x = self.global_pool(x)
+ x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
+ x = self.fc(x)
+ return x
+
+
+def _resnet(arch, Block, depth, pretrained):
+ model = ResNet(Block, depth)
+ if pretrained:
+        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
+ arch)
+ weight_path = get_weights_path(model_urls[arch][0],
+ model_urls[arch][1])
+ assert weight_path.endswith(
+ '.pdparams'), "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+ return model
+
+
+def resnet18(pretrained=False):
+ """ResNet 18-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _resnet('resnet18', BasicBlock, 18, pretrained)
+
+
+def resnet34(pretrained=False):
+ """ResNet 34-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _resnet('resnet34', BasicBlock, 34, pretrained)
+
+
+def resnet50(pretrained=False):
+ """ResNet 50-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _resnet('resnet50', BottleneckBlock, 50, pretrained)
+
+
+def resnet101(pretrained=False):
+ """ResNet 101-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _resnet('resnet101', BottleneckBlock, 101, pretrained)
+
+
+def resnet152(pretrained=False):
+ """ResNet 152-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _resnet('resnet152', BottleneckBlock, 152, pretrained)
diff --git a/models/tsm.py b/models/tsm.py
new file mode 100644
index 0000000000000000000000000000000000000000..91acd16b288e7e0803e0448f0e93a484b0b92c17
--- /dev/null
+++ b/models/tsm.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import math
+import paddle.fluid as fluid
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = ["TSM_ResNet", "tsm_resnet50"]
+
+# {num_layers: (url, md5)}
+pretrain_infos = {
+ 50: ('https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams',
+ '5755dc538e422589f417f7b38d7cc3c7')
+}
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ filter_size,
+ stride=1,
+ groups=1,
+ act=None):
+ super(ConvBNLayer, self).__init__()
+
+ self._conv = Conv2D(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=None,
+ act=None,
+ param_attr=fluid.param_attr.ParamAttr(),
+ bias_attr=False)
+
+ self._batch_norm = BatchNorm(
+ num_filters,
+ act=act,
+ param_attr=fluid.param_attr.ParamAttr(),
+ bias_attr=fluid.param_attr.ParamAttr())
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+
+ return y
+
+
+class BottleneckBlock(fluid.dygraph.Layer):
+ def __init__(self,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True,
+ seg_num=8):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv0 = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=1,
+ act='relu')
+ self.conv1 = ConvBNLayer(
+ num_channels=num_filters,
+ num_filters=num_filters,
+ filter_size=3,
+ stride=stride,
+ act='relu')
+ self.conv2 = ConvBNLayer(
+ num_channels=num_filters,
+ num_filters=num_filters * 4,
+ filter_size=1,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=num_filters * 4,
+ filter_size=1,
+ stride=stride)
+ self.shortcut = shortcut
+ self.seg_num = seg_num
+ self._num_channels_out = int(num_filters * 4)
+
+ def forward(self, inputs):
+ shifts = fluid.layers.temporal_shift(inputs, self.seg_num, 1.0 / 8)
+ y = self.conv0(shifts)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+ y = fluid.layers.elementwise_add(x=short, y=conv2, act="relu")
+ return y
+
+
+class TSM_ResNet(Model):
+ """
+ TSM network with ResNet as backbone
+
+ Args:
+        num_layers (int): ResNet layer number; only 50 is supported
+            currently. Default 50.
+ seg_num (int): segment number of each video sample. Default 8.
+ num_classes (int): video class number. Default 400.
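+
+    Examples:
+        .. code-block:: python
+
+            # usage sketch; the input is expected as [N, seg_num, 3, H, W]
+            model = TSM_ResNet(num_layers=50, seg_num=8, num_classes=400)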
+ """
+ def __init__(self, num_layers=50, seg_num=8, num_classes=400):
+ super(TSM_ResNet, self).__init__()
+
+ self.layers = num_layers
+ self.seg_num = seg_num
+ self.class_dim = num_classes
+
+ if self.layers == 50:
+ depth = [3, 4, 6, 3]
+ else:
+ raise NotImplementedError
+ num_filters = [64, 128, 256, 512]
+
+ self.conv = ConvBNLayer(
+ num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
+ self.pool2d_max = Pool2D(
+ pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
+
+ self.bottleneck_block_list = []
+ num_channels = 64
+
+ for block in range(len(depth)):
+ shortcut = False
+ for i in range(depth[block]):
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ num_channels=num_channels,
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut,
+ seg_num=self.seg_num))
+ num_channels = int(bottleneck_block._num_channels_out)
+ self.bottleneck_block_list.append(bottleneck_block)
+ shortcut = True
+ self.pool2d_avg = Pool2D(
+ pool_size=7, pool_type='avg', global_pooling=True)
+
+ stdv = 1.0 / math.sqrt(2048 * 1.0)
+
+ self.out = Linear(
+ 2048,
+ self.class_dim,
+ act="softmax",
+ param_attr=fluid.param_attr.ParamAttr(
+ initializer=fluid.initializer.Uniform(-stdv, stdv)),
+ bias_attr=fluid.param_attr.ParamAttr(
+ learning_rate=2.0, regularizer=fluid.regularizer.L2Decay(0.)))
+
+ def forward(self, inputs):
+ y = fluid.layers.reshape(
+ inputs, [-1, inputs.shape[2], inputs.shape[3], inputs.shape[4]])
+ y = self.conv(y)
+ y = self.pool2d_max(y)
+ for bottleneck_block in self.bottleneck_block_list:
+ y = bottleneck_block(y)
+ y = self.pool2d_avg(y)
+ y = fluid.layers.dropout(y, dropout_prob=0.5)
+ y = fluid.layers.reshape(y, [-1, self.seg_num, y.shape[1]])
+ y = fluid.layers.reduce_mean(y, dim=1)
+ y = fluid.layers.reshape(y, shape=[-1, 2048])
+ y = self.out(y)
+ return y
+
+
+def _tsm_resnet(num_layers, seg_num=8, num_classes=400, pretrained=True):
+ model = TSM_ResNet(num_layers, seg_num, num_classes)
+ if pretrained:
+        assert num_layers in pretrain_infos.keys(), \
+        "TSM-ResNet{} does not have pretrained weights yet, " \
+        "pretrained should be set to False".format(num_layers)
+ weight_path = get_weights_path(*(pretrain_infos[num_layers]))
+ assert weight_path.endswith('.pdparams'), \
+ "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+ return model
+
+
+def tsm_resnet50(seg_num=8, num_classes=400, pretrained=True):
+ return _tsm_resnet(50, seg_num, num_classes, pretrained)
diff --git a/models/vgg.py b/models/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8ca21f0c370c1963b1c7c61aca101abe63d179b
--- /dev/null
+++ b/models/vgg.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
+from paddle.fluid.dygraph.container import Sequential
+
+from model import Model
+from .download import get_weights_path
+
+__all__ = [
+ 'VGG',
+ 'vgg11',
+ 'vgg11_bn',
+ 'vgg13',
+ 'vgg13_bn',
+ 'vgg16',
+ 'vgg16_bn',
+ 'vgg19_bn',
+ 'vgg19',
+]
+
+model_urls = {
+ 'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
+ 'c788f453a3b999063e8da043456281ee')
+}
+
+
+class Classifier(fluid.dygraph.Layer):
+ def __init__(self, num_classes):
+ super(Classifier, self).__init__()
+ self.linear1 = Linear(512 * 7 * 7, 4096)
+ self.linear2 = Linear(4096, 4096)
+ self.linear3 = Linear(4096, num_classes, act='softmax')
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = fluid.layers.relu(x)
+ x = fluid.layers.dropout(x, 0.5)
+ x = self.linear2(x)
+ x = fluid.layers.relu(x)
+ x = fluid.layers.dropout(x, 0.5)
+ out = self.linear3(x)
+ return out
+
+
+class VGG(Model):
+ """VGG model from
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
+
+ Args:
+        features (fluid.dygraph.Layer): vgg features created by the make_layers function.
+ num_classes (int): output dim of last fc layer. Default: 1000.
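+
+    Examples:
+        .. code-block:: python
+
+            # usage sketch; vgg16() below builds the 'D' configuration for you
+            model = VGG(make_layers(cfgs['D']), num_classes=1000)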
+ """
+
+ def __init__(self, features, num_classes=1000):
+ super(VGG, self).__init__()
+ self.features = features
+ classifier = Classifier(num_classes)
+ self.classifier = self.add_sublayer("classifier",
+ Sequential(classifier))
+
+ def forward(self, x):
+ x = self.features(x)
+ x = fluid.layers.flatten(x, 1)
+ x = self.classifier(x)
+ return x
+
+
+def make_layers(cfg, batch_norm=False):
+ layers = []
+ in_channels = 3
+
+ for v in cfg:
+ if v == 'M':
+ layers += [Pool2D(pool_size=2, pool_stride=2)]
+ else:
+ if batch_norm:
+ conv2d = Conv2D(in_channels, v, filter_size=3, padding=1)
+ layers += [conv2d, BatchNorm(v, act='relu')]
+ else:
+ conv2d = Conv2D(
+ in_channels, v, filter_size=3, padding=1, act='relu')
+ layers += [conv2d]
+ in_channels = v
+ return Sequential(*layers)
+
+
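+# cfg encoding (consumed by make_layers above): integers are 3x3 conv output
+# channels, 'M' inserts a 2x2 max-pool; configurations 'A'/'B'/'D'/'E'
+# correspond to VGG-11/13/16/19 respectively.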
+cfgs = {
+ 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'B':
+ [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'D': [
+ 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
+ 512, 512, 512, 'M'
+ ],
+ 'E': [
+ 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512,
+ 512, 'M', 512, 512, 512, 512, 'M'
+ ],
+}
+
+
+def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
+ model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+
+ if pretrained:
+        assert arch in model_urls, "{} model does not have a pretrained model now, you should set pretrained=False".format(
+ arch)
+ weight_path = get_weights_path(model_urls[arch][0],
+ model_urls[arch][1])
+ assert weight_path.endswith(
+ '.pdparams'), "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+
+ return model
+
+
+def vgg11(pretrained=False, **kwargs):
+ """VGG 11-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg11', 'A', False, pretrained, **kwargs)
+
+
+def vgg11_bn(pretrained=False, **kwargs):
+ """VGG 11-layer model with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg11_bn', 'A', True, pretrained, **kwargs)
+
+
+def vgg13(pretrained=False, **kwargs):
+ """VGG 13-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg13', 'B', False, pretrained, **kwargs)
+
+
+def vgg13_bn(pretrained=False, **kwargs):
+ """VGG 13-layer model with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg13_bn', 'B', True, pretrained, **kwargs)
+
+
+def vgg16(pretrained=False, **kwargs):
+ """VGG 16-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg16', 'D', False, pretrained, **kwargs)
+
+
+def vgg16_bn(pretrained=False, **kwargs):
+ """VGG 16-layer with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg16_bn', 'D', True, pretrained, **kwargs)
+
+
+def vgg19(pretrained=False, **kwargs):
+ """VGG 19-layer model
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg19', 'E', False, pretrained, **kwargs)
+
+
+def vgg19_bn(pretrained=False, **kwargs):
+ """VGG 19-layer model with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ return _vgg('vgg19_bn', 'E', True, pretrained, **kwargs)
diff --git a/models/yolov3.py b/models/yolov3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2bbc88ee27cb08269bd2a986ff7b55b4f199999
--- /dev/null
+++ b/models/yolov3.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Conv2D
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.regularizer import L2Decay
+
+from model import Model, Loss
+from .darknet import darknet53, ConvBNLayer
+from .download import get_weights_path
+
+__all__ = ['YoloLoss', 'YOLOv3', 'yolov3_darknet53']
+
+# {num_layers: (url, md5)}
+pretrain_infos = {
+ 53: ('https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams',
+ 'aed7dd45124ff2e844ae3bd5ba6c91d2')
+}
+
+
+class YoloDetectionBlock(fluid.dygraph.Layer):
+ def __init__(self, ch_in, channel):
+ super(YoloDetectionBlock, self).__init__()
+
+ assert channel % 2 == 0, \
+ "channel {} cannot be divided by 2".format(channel)
+
+ self.conv0 = ConvBNLayer(
+ ch_in=ch_in,
+ ch_out=channel,
+ filter_size=1,
+ stride=1,
+ padding=0)
+ self.conv1 = ConvBNLayer(
+ ch_in=channel,
+ ch_out=channel*2,
+ filter_size=3,
+ stride=1,
+ padding=1)
+ self.conv2 = ConvBNLayer(
+ ch_in=channel*2,
+ ch_out=channel,
+ filter_size=1,
+ stride=1,
+ padding=0)
+ self.conv3 = ConvBNLayer(
+ ch_in=channel,
+ ch_out=channel*2,
+ filter_size=3,
+ stride=1,
+ padding=1)
+ self.route = ConvBNLayer(
+ ch_in=channel*2,
+ ch_out=channel,
+ filter_size=1,
+ stride=1,
+ padding=0)
+ self.tip = ConvBNLayer(
+ ch_in=channel,
+ ch_out=channel*2,
+ filter_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, inputs):
+ out = self.conv0(inputs)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ out = self.conv3(out)
+ route = self.route(out)
+ tip = self.tip(route)
+ return route, tip
+
+
+class YOLOv3(Model):
+ def __init__(self, num_classes=80, model_mode='train'):
+ super(YOLOv3, self).__init__()
+ self.num_classes = num_classes
+        assert str.lower(model_mode) in ['train', 'eval', 'test'], \
+            "model_mode should be 'train', 'eval' or 'test', but got " \
+            "{}".format(model_mode)
+ self.model_mode = str.lower(model_mode)
+ self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
+ 59, 119, 116, 90, 156, 198, 373, 326]
+ self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+ self.valid_thresh = 0.005
+ self.nms_thresh = 0.45
+ self.nms_topk = 400
+ self.nms_posk = 100
+ self.draw_thresh = 0.5
+
+ self.backbone = darknet53(pretrained=(model_mode=='train'))
+ self.block_outputs = []
+ self.yolo_blocks = []
+ self.route_blocks = []
+
+ for idx, num_chan in enumerate([1024, 768, 384]):
+ yolo_block = self.add_sublayer(
+ "yolo_detecton_block_{}".format(idx),
+ YoloDetectionBlock(num_chan, 512 // (2**idx)))
+ self.yolo_blocks.append(yolo_block)
+
+ num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
+
+ block_out = self.add_sublayer(
+ "block_out_{}".format(idx),
+ Conv2D(num_channels=1024 // (2**idx),
+ num_filters=num_filters,
+ filter_size=1,
+ act=None,
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.Normal(0., 0.02)),
+ bias_attr=ParamAttr(
+ initializer=fluid.initializer.Constant(0.0),
+ regularizer=L2Decay(0.))))
+ self.block_outputs.append(block_out)
+ if idx < 2:
+ route = self.add_sublayer(
+ "route2_{}".format(idx),
+ ConvBNLayer(ch_in=512 // (2**idx),
+ ch_out=256 // (2**idx),
+ filter_size=1,
+ act='leaky_relu'))
+ self.route_blocks.append(route)
+
+ def forward(self, img_info, inputs):
+ outputs = []
+ boxes = []
+ scores = []
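+        # the first detection head predicts on the 32x downsampled feature
+        # map; the ratio is halved for each subsequent, finer-grained head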
+ downsample = 32
+
+ feats = self.backbone(inputs)
+ route = None
+ for idx, feat in enumerate(feats):
+ if idx > 0:
+ feat = fluid.layers.concat(input=[route, feat], axis=1)
+ route, tip = self.yolo_blocks[idx](feat)
+ block_out = self.block_outputs[idx](tip)
+ outputs.append(block_out)
+
+ if idx < 2:
+ route = self.route_blocks[idx](route)
+ route = fluid.layers.resize_nearest(route, scale=2)
+
+ if self.model_mode != 'train':
+ anchor_mask = self.anchor_masks[idx]
+ mask_anchors = []
+ for m in anchor_mask:
+ mask_anchors.append(self.anchors[2 * m])
+ mask_anchors.append(self.anchors[2 * m + 1])
+ img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
+ img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
+ b, s = fluid.layers.yolo_box(
+ x=block_out,
+ img_size=img_shape,
+ anchors=mask_anchors,
+ class_num=self.num_classes,
+ conf_thresh=self.valid_thresh,
+ downsample_ratio=downsample)
+
+ boxes.append(b)
+ scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
+
+ downsample //= 2
+
+ if self.model_mode == 'train':
+ return outputs
+
+ preds = [img_id[0, :],
+ fluid.layers.multiclass_nms(
+ bboxes=fluid.layers.concat(boxes, axis=1),
+ scores=fluid.layers.concat(scores, axis=2),
+ score_threshold=self.valid_thresh,
+ nms_top_k=self.nms_topk,
+ keep_top_k=self.nms_posk,
+ nms_threshold=self.nms_thresh,
+ background_label=-1)]
+
+ if self.model_mode == 'test':
+ return preds
+
+ # model_mode == "eval"
+ return outputs + preds
+
+class YoloLoss(Loss):
+ def __init__(self, num_classes=80, num_max_boxes=50):
+ super(YoloLoss, self).__init__()
+ self.num_classes = num_classes
+ self.num_max_boxes = num_max_boxes
+ self.ignore_thresh = 0.7
+ self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
+ 59, 119, 116, 90, 156, 198, 373, 326]
+ self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+ def forward(self, outputs, labels):
+ downsample = 32
+ gt_box, gt_label, gt_score = labels
+ losses = []
+
+ for idx, out in enumerate(outputs):
+            if idx == 3: break  # only the first 3 outputs are YOLO head feature maps; eval mode appends extra prediction outputs
+ anchor_mask = self.anchor_masks[idx]
+ loss = fluid.layers.yolov3_loss(
+ x=out,
+ gt_box=gt_box,
+ gt_label=gt_label,
+ gt_score=gt_score,
+ anchor_mask=anchor_mask,
+ downsample_ratio=downsample,
+ anchors=self.anchors,
+ class_num=self.num_classes,
+ ignore_thresh=self.ignore_thresh,
+ use_label_smooth=True)
+ loss = fluid.layers.reduce_mean(loss)
+ losses.append(loss)
+ downsample //= 2
+ return losses
+
+
+def _yolov3_darknet(num_layers=53, num_classes=80,
+ model_mode='train', pretrained=True):
+ model = YOLOv3(num_classes, model_mode)
+ if pretrained:
+ assert num_layers in pretrain_infos.keys(), \
+ "YOLOv3-DarkNet{} do not have pretrained weights now, " \
+ "pretrained should be set as False".format(num_layers)
+ weight_path = get_weights_path(*(pretrain_infos[num_layers]))
+ assert weight_path.endswith('.pdparams'), \
+ "suffix of weight must be .pdparams"
+ model.load(weight_path[:-9])
+ return model
+
+
+def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True):
+ return _yolov3_darknet(53, num_classes, model_mode, pretrained)
diff --git a/text.py b/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b6cac2fea9539b5836d0ada1184ab50e2424e1e
--- /dev/null
+++ b/text.py
@@ -0,0 +1,1000 @@
+import collections
+import copy
+import six
+import sys
+from functools import partial, reduce
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.layers.utils as utils
+from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
+from paddle.fluid.dygraph import to_variable, Embedding, Linear, LayerNorm
+from paddle.fluid.data_feeder import convert_dtype
+
+from paddle.fluid import layers
+from paddle.fluid.dygraph import Layer
+from paddle.fluid.layers import BeamSearchDecoder
+
+__all__ = [
+ 'RNNCell', 'BasicLSTMCell', 'BasicGRUCell', 'RNN', 'DynamicDecode',
+ 'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
+ 'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
+ 'TransformerDecoder', 'TransformerBeamSearchDecoder'
+]
+
+
+class RNNCell(Layer):
+ def get_initial_states(self,
+ batch_ref,
+ shape=None,
+ dtype=None,
+ init_value=0,
+ batch_dim_idx=0):
+ """
+ Generate initialized states according to provided shape, data type and
+ value.
+
+ Parameters:
+ batch_ref: A (possibly nested structure of) tensor variable[s].
+ The first dimension of the tensor will be used as batch size to
+ initialize states.
+            shape: A (possibly nested structure of) shape[s], where a shape is
+                represented as a list/tuple of integers. -1 (for batch size) will
+                be automatically inserted if a shape does not start with it. If
+                None, property `state_shape` will be used. The default value is None.
+            dtype: A (possibly nested structure of) data type[s]. The structure
+                must be same as that of `shape`, except when all tensors in the
+                states have the same data type, in which case a single data type
+                can be used. If None and property `cell.state_shape` is not
+                available, float32 will be used as the data type. The default
+                value is None.
+            init_value: A float value used to initialize states.
+            batch_dim_idx: An integer indicating which dimension of `batch_ref`
+                represents the batch size. The default value is 0.
+
+ Returns:
+ Variable: tensor variable[s] packed in the same structure provided \
+ by shape, representing the initialized states.
+ """
+ # TODO: use inputs and batch_size
+ batch_ref = flatten(batch_ref)[0]
+
+ def _is_shape_sequence(seq):
+ if sys.version_info < (3, ):
+ integer_types = (
+ int,
+ long, )
+ else:
+ integer_types = (int, )
+ """For shape, list/tuple of integer is the finest-grained objection"""
+ if (isinstance(seq, list) or isinstance(seq, tuple)):
+ if reduce(
+ lambda flag, x: isinstance(x, integer_types) and flag,
+ seq, True):
+ return False
+ # TODO: Add check for the illegal
+ if isinstance(seq, dict):
+ return True
+ return (isinstance(seq, collections.Sequence) and
+ not isinstance(seq, six.string_types))
+
+ class Shape(object):
+ def __init__(self, shape):
+ self.shape = shape if shape[0] == -1 else ([-1] + list(shape))
+
+ # nested structure of shapes
+ states_shapes = self.state_shape if shape is None else shape
+ is_sequence_ori = utils.is_sequence
+ utils.is_sequence = _is_shape_sequence
+ states_shapes = map_structure(lambda shape: Shape(shape),
+ states_shapes)
+ utils.is_sequence = is_sequence_ori
+
+ # nested structure of dtypes
+ try:
+ states_dtypes = self.state_dtype if dtype is None else dtype
+ except NotImplementedError: # use fp32 as default
+ states_dtypes = "float32"
+ if len(flatten(states_dtypes)) == 1:
+ dtype = flatten(states_dtypes)[0]
+ states_dtypes = map_structure(lambda shape: dtype, states_shapes)
+
+ init_states = map_structure(
+ lambda shape, dtype: fluid.layers.fill_constant_batch_size_like(
+ input=batch_ref,
+ shape=shape.shape,
+ dtype=dtype,
+ value=init_value,
+ input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
+ return init_states
+
+ @property
+ def state_shape(self):
+ """
+ Abstract method (property).
+ Used to initialize states.
+        A (possibly nested structure of) shape[s], where a shape is represented
+        as a list/tuple of integers (-1 for batch size would be automatically
+        inserted into a shape if the shape does not start with it).
+ Not necessary to be implemented if states are not initialized by
+ `get_initial_states` or the `shape` argument is provided when using
+ `get_initial_states`.
+ """
+ raise NotImplementedError(
+ "Please add implementaion for `state_shape` in the used cell.")
+
+ @property
+ def state_dtype(self):
+ """
+ Abstract method (property).
+ Used to initialize states.
+        A (possibly nested structure of) data type[s]. The structure must be
+        same as that of `shape`, except when all tensors in the states have the
+        same data type, in which case a single data type can be used.
+ Not necessary to be implemented if states are not initialized
+ by `get_initial_states` or the `dtype` argument is provided when using
+ `get_initial_states`.
+ """
+ raise NotImplementedError(
+ "Please add implementaion for `state_dtype` in the used cell.")
+
+
+class BasicLSTMCell(RNNCell):
+ """
+ ****
+    BasicLSTMCell class, using basic operators to build an LSTM.
+ The algorithm can be described as the code below.
+ .. math::
+ i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
+ f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
+ o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
+ \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
+ c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
+ h_t &= o_t \odot tanh(c_t)
+ - $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
+ of weights from the input gate to the input)
+    - The $b$ terms denote bias vectors (e.g. $b_i$ is the input gate bias).
+ - sigmoid is the logistic sigmoid function.
+ - $i, f, o$ and $c$ are the input gate, forget gate, output gate,
+ and cell activation vectors, respectively, all of which have the same size as
+ the cell output activation vector $h$.
+ - The :math:`\odot` is the element-wise product of the vectors.
+    - :math:`tanh` is the activation function.
+ - :math:`\\tilde{c_t}` is also called candidate hidden state,
+ which is computed based on the current input and the previous hidden state.
+ Args:
+        input_size (integer): The input size used in the Unit.
+ hidden_size (integer): The hidden size used in the Unit.
+ param_attr(ParamAttr|None): The parameter attribute for the learnable
+ weight matrix. Note:
+ If it is set to None or one attribute of ParamAttr, lstm_unit will
+ create ParamAttr as param_attr. If the Initializer of the param_attr
+ is not set, the parameter is initialized with Xavier. Default: None.
+ bias_attr (ParamAttr|None): The parameter attribute for the bias
+ of LSTM unit.
+ If it is set to None or one attribute of ParamAttr, lstm_unit will
+ create ParamAttr as bias_attr. If the Initializer of the bias_attr
+ is not set, the bias is initialized as zero. Default: None.
+ gate_activation (function|None): The activation function for gates (actGate).
+ Default: 'fluid.layers.sigmoid'
+ activation (function|None): The activation function for cells (actNode).
+ Default: 'fluid.layers.tanh'
+ forget_bias(float|1.0): forget bias used when computing forget gate
+ dtype(string): data type used in this unit
+ """
+
+ def __init__(self,
+ input_size,
+ hidden_size,
+ param_attr=None,
+ bias_attr=None,
+ gate_activation=None,
+ activation=None,
+ forget_bias=1.0,
+ dtype='float32'):
+ super(BasicLSTMCell, self).__init__()
+
+ self._hidden_size = hidden_size
+ self._param_attr = param_attr
+ self._bias_attr = bias_attr
+ self._gate_activation = gate_activation or layers.sigmoid
+ self._activation = activation or layers.tanh
+ self._forget_bias = layers.fill_constant(
+ [1], dtype=dtype, value=forget_bias)
+ self._forget_bias.stop_gradient = False
+ self._dtype = dtype
+ self._input_size = input_size
+
+ self._weight = self.create_parameter(
+ attr=self._param_attr,
+ shape=[
+ self._input_size + self._hidden_size, 4 * self._hidden_size
+ ],
+ dtype=self._dtype)
+
+ self._bias = self.create_parameter(
+ attr=self._bias_attr,
+ shape=[4 * self._hidden_size],
+ dtype=self._dtype,
+ is_bias=True)
+
+ def forward(self, input, state):
+ pre_hidden, pre_cell = state
+ concat_input_hidden = layers.concat([input, pre_hidden], 1)
+ gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
+
+ gate_input = layers.elementwise_add(gate_input, self._bias)
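+        # split the fused projection into input gate (i), cell candidate (j),
+        # forget gate (f) and output gate (o)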
+ i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
+ new_cell = layers.elementwise_add(
+ layers.elementwise_mul(
+ pre_cell,
+ layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
+ layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
+ new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
+
+ return new_hidden, [new_hidden, new_cell]
+
+ @property
+ def state_shape(self):
+ return [[self._hidden_size], [self._hidden_size]]
+
+
+class BasicGRUCell(RNNCell):
+ """
+ ****
+ BasicGRUUnit class, using basic operators to build GRU
+ The algorithm can be described as the equations below.
+
+ .. math::
+ u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
+
+ r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
+
+ m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
+
+ h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
+
+ Args:
+        input_size (integer): The input size used in the Unit.
+        hidden_size (integer): The hidden size used in the Unit.
+ param_attr(ParamAttr|None): The parameter attribute for the learnable
+ weight matrix. Note:
+ If it is set to None or one attribute of ParamAttr, gru_unit will
+ create ParamAttr as param_attr. If the Initializer of the param_attr
+ is not set, the parameter is initialized with Xavier. Default: None.
+ bias_attr (ParamAttr|None): The parameter attribute for the bias
+ of GRU unit.
+ If it is set to None or one attribute of ParamAttr, gru_unit will
+ create ParamAttr as bias_attr. If the Initializer of the bias_attr
+ is not set, the bias is initialized zero. Default: None.
+ gate_activation (function|None): The activation function for gates (actGate).
+ Default: 'fluid.layers.sigmoid'
+ activation (function|None): The activation function for cell (actNode).
+ Default: 'fluid.layers.tanh'
+ dtype(string): data type used in this unit
+ """
+
+ def __init__(self,
+ input_size,
+ hidden_size,
+ param_attr=None,
+ bias_attr=None,
+ gate_activation=None,
+ activation=None,
+ dtype='float32'):
+ super(BasicGRUCell, self).__init__()
+ self._input_size = input_size
+        self._hidden_size = hidden_size
+ self._param_attr = param_attr
+ self._bias_attr = bias_attr
+ self._gate_activation = gate_activation or layers.sigmoid
+ self._activation = activation or layers.tanh
+ self._dtype = dtype
+
+ if self._param_attr is not None and self._param_attr.name is not None:
+ gate_param_attr = copy.deepcopy(self._param_attr)
+ candidate_param_attr = copy.deepcopy(self._param_attr)
+ gate_param_attr.name += "_gate"
+ candidate_param_attr.name += "_candidate"
+ else:
+ gate_param_attr = self._param_attr
+ candidate_param_attr = self._param_attr
+
+ self._gate_weight = self.create_parameter(
+ attr=gate_param_attr,
+            shape=[self._input_size + self._hidden_size, 2 * self._hidden_size],
+ dtype=self._dtype)
+
+ self._candidate_weight = self.create_parameter(
+ attr=candidate_param_attr,
+            shape=[self._input_size + self._hidden_size, self._hidden_size],
+ dtype=self._dtype)
+
+ if self._bias_attr is not None and self._bias_attr.name is not None:
+ gate_bias_attr = copy.deepcopy(self._bias_attr)
+ candidate_bias_attr = copy.deepcopy(self._bias_attr)
+ gate_bias_attr.name += "_gate"
+ candidate_bias_attr.name += "_candidate"
+ else:
+ gate_bias_attr = self._bias_attr
+ candidate_bias_attr = self._bias_attr
+
+ self._gate_bias = self.create_parameter(
+ attr=gate_bias_attr,
+            shape=[2 * self._hidden_size],
+ dtype=self._dtype,
+ is_bias=True)
+ self._candidate_bias = self.create_parameter(
+ attr=candidate_bias_attr,
+            shape=[self._hidden_size],
+ dtype=self._dtype,
+ is_bias=True)
+
+ def forward(self, input, state):
+ pre_hidden = state
+ concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
+
+ gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
+
+ gate_input = layers.elementwise_add(gate_input, self._gate_bias)
+
+ gate_input = self._gate_activation(gate_input)
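+        # split the activated projection into reset gate (r) and update gate (u)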
+ r, u = layers.split(gate_input, num_or_sections=2, dim=1)
+
+ r_hidden = r * pre_hidden
+
+ candidate = layers.matmul(
+ layers.concat([input, r_hidden], 1), self._candidate_weight)
+ candidate = layers.elementwise_add(candidate, self._candidate_bias)
+
+ c = self._activation(candidate)
+ new_hidden = u * pre_hidden + (1 - u) * c
+
+ return new_hidden
+
+ @property
+ def state_shape(self):
+ return [self._hidden_size]
+
+
+class RNN(fluid.dygraph.Layer):
+ def __init__(self, cell, is_reverse=False, time_major=False):
+ super(RNN, self).__init__()
+ self.cell = cell
+ if not hasattr(self.cell, "call"):
+ self.cell.call = self.cell.forward
+ self.is_reverse = is_reverse
+ self.time_major = time_major
+        self.batch_index, self.time_step_index = (1, 0) if time_major else (0, 1)
+
+ def forward(self,
+ inputs,
+ initial_states=None,
+ sequence_length=None,
+ **kwargs):
+ if fluid.in_dygraph_mode():
+
+ class ArrayWrapper(object):
+ def __init__(self, x):
+ self.array = [x]
+
+ def append(self, x):
+ self.array.append(x)
+ return self
+
+ def _maybe_copy(state, new_state, step_mask):
+ # TODO: use where_op
+ new_state = fluid.layers.elementwise_mul(
+ new_state, step_mask,
+ axis=0) - fluid.layers.elementwise_mul(
+ state, (step_mask - 1), axis=0)
+ return new_state
+
+ flat_inputs = flatten(inputs)
+ batch_size, time_steps = (
+ flat_inputs[0].shape[self.batch_index],
+ flat_inputs[0].shape[self.time_step_index])
+
+ if initial_states is None:
+ initial_states = self.cell.get_initial_states(
+ batch_ref=inputs, batch_dim_idx=self.batch_index)
+
+ if not self.time_major:
+ inputs = map_structure(
+ lambda x: fluid.layers.transpose(x, [1, 0] + list(
+ range(2, len(x.shape)))), inputs)
+
+ if sequence_length:
+ mask = fluid.layers.sequence_mask(
+ sequence_length,
+ maxlen=time_steps,
+ dtype=flatten(initial_states)[0].dtype)
+ mask = fluid.layers.transpose(mask, [1, 0])
+
+ if self.is_reverse:
+ inputs = map_structure(
+ lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
+ mask = fluid.layers.reverse(
+ mask, axis=[0]) if sequence_length else None
+
+ states = initial_states
+ outputs = []
+ for i in range(time_steps):
+ step_inputs = map_structure(lambda x: x[i], inputs)
+ step_outputs, new_states = self.cell(step_inputs, states,
+ **kwargs)
+ if sequence_length:
+ new_states = map_structure(
+ partial(
+ _maybe_copy, step_mask=mask[i]),
+ states,
+ new_states)
+ states = new_states
+ outputs = map_structure(
+ lambda x: ArrayWrapper(x),
+ step_outputs) if i == 0 else map_structure(
+ lambda x, x_array: x_array.append(x), step_outputs,
+ outputs)
+
+ final_outputs = map_structure(
+ lambda x: fluid.layers.stack(x.array,
+ axis=self.time_step_index),
+ outputs)
+
+ if self.is_reverse:
+ final_outputs = map_structure(
+ lambda x: fluid.layers.reverse(x,
+ axis=self.time_step_index),
+ final_outputs)
+
+ final_states = new_states
+ else:
+ final_outputs, final_states = fluid.layers.rnn(
+ self.cell,
+ inputs,
+ initial_states=initial_states,
+ sequence_length=sequence_length,
+ time_major=self.time_major,
+ is_reverse=self.is_reverse,
+ **kwargs)
+ return final_outputs, final_states
+
+
+class DynamicDecode(Layer):
+ def __init__(self,
+ decoder,
+ max_step_num=None,
+ output_time_major=False,
+ impute_finished=False,
+ is_test=False,
+ return_length=False):
+ super(DynamicDecode, self).__init__()
+ self.decoder = decoder
+ self.max_step_num = max_step_num
+ self.output_time_major = output_time_major
+ self.impute_finished = impute_finished
+ self.is_test = is_test
+ self.return_length = return_length
+
+ def forward(self, inits=None, **kwargs):
+ if fluid.in_dygraph_mode():
+
+ class ArrayWrapper(object):
+ def __init__(self, x):
+ self.array = [x]
+
+ def append(self, x):
+ self.array.append(x)
+ return self
+
+ def __getitem__(self, item):
+ return self.array.__getitem__(item)
+
+ def _maybe_copy(state, new_state, step_mask):
+ # TODO: use where_op
+ state_dtype = state.dtype
+ if convert_dtype(state_dtype) in ["bool"]:
+ state = layers.cast(state, dtype="float32")
+ new_state = layers.cast(new_state, dtype="float32")
+ if step_mask.dtype != state.dtype:
+ step_mask = layers.cast(step_mask, dtype=state.dtype)
+                # otherwise the renamed bool gradients would be summed up and
+                # lead to a sum(bool) error
+ step_mask.stop_gradient = True
+ new_state = layers.elementwise_mul(
+ state, step_mask, axis=0) - layers.elementwise_mul(
+ new_state, (step_mask - 1), axis=0)
+ if convert_dtype(state_dtype) in ["bool"]:
+ new_state = layers.cast(new_state, dtype=state_dtype)
+ return new_state
+
+ initial_inputs, initial_states, initial_finished = self.decoder.initialize(
+ inits)
+ inputs, states, finished = (initial_inputs, initial_states,
+ initial_finished)
+ cond = layers.logical_not((layers.reduce_all(initial_finished)))
+ sequence_lengths = layers.cast(
+ layers.zeros_like(initial_finished), "int64")
+ outputs = None
+
+ step_idx = 0
+ step_idx_tensor = layers.fill_constant(
+ shape=[1], dtype="int64", value=step_idx)
+ while cond.numpy():
+ (step_outputs, next_states, next_inputs,
+ next_finished) = self.decoder.step(step_idx_tensor, inputs,
+ states, **kwargs)
+ if not self.decoder.tracks_own_finished:
+                    # BeamSearchDecoder tracks its own finished status, since
+                    # beams would be reordered and the finished status of each
+                    # entry might change. Otherwise, perform logical OR, which
+                    # does not change entries that are already finished.
+ next_finished = layers.logical_or(next_finished, finished)
+                # keep states.finished/finished consistent with next_finished
+ layers.assign(next_finished, finished)
+ next_sequence_lengths = layers.elementwise_add(
+ sequence_lengths,
+ layers.cast(
+ layers.logical_not(finished), sequence_lengths.dtype))
+
+ if self.impute_finished: # rectify the states for the finished.
+ next_states = map_structure(
+ lambda x, y: _maybe_copy(x, y, finished), states,
+ next_states)
+ outputs = map_structure(
+ lambda x: ArrayWrapper(x),
+ step_outputs) if step_idx == 0 else map_structure(
+ lambda x, x_array: x_array.append(x), step_outputs,
+ outputs)
+ inputs, states, finished, sequence_lengths = (
+ next_inputs, next_states, next_finished,
+ next_sequence_lengths)
+
+ layers.increment(x=step_idx_tensor, value=1.0, in_place=True)
+ step_idx += 1
+
+ layers.logical_not(layers.reduce_all(finished), cond)
+ if self.max_step_num is not None and step_idx > self.max_step_num:
+ break
+
+ final_outputs = map_structure(
+ lambda x: fluid.layers.stack(x.array, axis=0), outputs)
+ final_states = states
+
+ try:
+ final_outputs, final_states = self.decoder.finalize(
+ final_outputs, final_states, sequence_lengths)
+ except NotImplementedError:
+ pass
+
+ if not self.output_time_major:
+ final_outputs = map_structure(
+ lambda x: layers.transpose(x, [1, 0] + list(
+ range(2, len(x.shape)))), final_outputs)
+
+ return (final_outputs, final_states,
+ sequence_lengths) if self.return_length else (
+ final_outputs, final_states)
+ else:
+ return fluid.layers.dynamic_decode(
+ self.decoder,
+ inits,
+ max_step_num=self.max_step_num,
+ output_time_major=self.output_time_major,
+ impute_finished=self.impute_finished,
+ is_test=self.is_test,
+ return_length=self.return_length,
+ **kwargs)
+
+
+class TransfomerCell(object):
+ """
+    Wraps the Transformer decoder with inputs=(trg_word, trg_pos) and
+    states=cache so that it can be used as an RNNCell.
+ """
+
+ def __init__(self, decoder):
+ self.decoder = decoder
+
+ def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
+ static_caches):
+ trg_word, trg_pos = inputs
+ for cache, static_cache in zip(states, static_caches):
+ cache.update(static_cache)
+ logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
+ enc_output, states)
+ new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
+ return logits, new_states
+
+
+class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
+ def __init__(self, cell, start_token, end_token, beam_size,
+ var_dim_in_state):
+ super(TransformerBeamSearchDecoder,
+ self).__init__(cell, start_token, end_token, beam_size)
+ self.cell = cell
+ self.var_dim_in_state = var_dim_in_state
+
+ def _merge_batch_beams_with_var_dim(self, x):
+        # the cache length is 0 initially and grows as decoding proceeds,
+        # so the reshape has to be done carefully
+ var_dim_in_state = self.var_dim_in_state + 1 # count in beam dim
+ x = layers.transpose(x,
+ list(range(var_dim_in_state, len(x.shape))) +
+ list(range(0, var_dim_in_state)))
+ x = layers.reshape(
+ x, [0] * (len(x.shape) - var_dim_in_state
+ ) + [self.batch_size * self.beam_size] +
+ [int(size) for size in x.shape[-var_dim_in_state + 2:]])
+ x = layers.transpose(
+ x,
+ list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
+ list(range(0, (len(x.shape) + 1 - var_dim_in_state))))
+ return x
+
+ def _split_batch_beams_with_var_dim(self, x):
+ var_dim_size = layers.shape(x)[self.var_dim_in_state]
+ x = layers.reshape(
+ x, [-1, self.beam_size] +
+ [int(size)
+ for size in x.shape[1:self.var_dim_in_state]] + [var_dim_size] +
+ [int(size) for size in x.shape[self.var_dim_in_state + 1:]])
+ return x
+
+ def step(self, time, inputs, states, **kwargs):
+ # compared to RNN, Transformer has 3D data at every decoding step
+ inputs = layers.reshape(inputs, [-1, 1]) # token
+ pos = layers.ones_like(inputs) * time # pos
+ cell_states = map_structure(self._merge_batch_beams_with_var_dim,
+ states.cell_states)
+
+ cell_outputs, next_cell_states = self.cell((inputs, pos), cell_states,
+ **kwargs)
+ cell_outputs = map_structure(self._split_batch_beams, cell_outputs)
+ next_cell_states = map_structure(self._split_batch_beams_with_var_dim,
+ next_cell_states)
+
+ beam_search_output, beam_search_state = self._beam_search_step(
+ time=time,
+ logits=cell_outputs,
+ next_cell_states=next_cell_states,
+ beam_state=states)
+ next_inputs, finished = (beam_search_output.predicted_ids,
+ beam_search_state.finished)
+
+ return (beam_search_output, beam_search_state, next_inputs, finished)
+
+
+### Transformer Modules ###
+class PrePostProcessLayer(Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, process_cmd, d_model, dropout_rate):
+ super(PrePostProcessLayer, self).__init__()
+ self.process_cmd = process_cmd
+ self.functors = []
+ for cmd in self.process_cmd:
+ if cmd == "a": # add residual connection
+ self.functors.append(lambda x, y: x + y if y else x)
+ elif cmd == "n": # add layer normalization
+ self.functors.append(
+ self.add_sublayer(
+ "layer_norm_%d" % len(
+ self.sublayers(include_sublayers=False)),
+ LayerNorm(
+ normalized_shape=d_model,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(1.)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.)))))
+ elif cmd == "d": # add dropout
+ self.functors.append(lambda x: layers.dropout(
+ x, dropout_prob=dropout_rate, is_test=False)
+ if dropout_rate else x)
+
+ def forward(self, x, residual=None):
+ for i, cmd in enumerate(self.process_cmd):
+ if cmd == "a":
+ x = self.functors[i](x, residual)
+ else:
+ x = self.functors[i](x)
+ return x
+
+
+class MultiHeadAttention(Layer):
+ """
+ Multi-Head Attention
+ """
+
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+ self.d_model = d_model
+ self.dropout_rate = dropout_rate
+ self.q_fc = Linear(
+ input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+ self.k_fc = Linear(
+ input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+ self.v_fc = Linear(
+ input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+ self.proj_fc = Linear(
+ input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
+
+ def _prepare_qkv(self, queries, keys, values, cache=None):
+ if keys is None: # self-attention
+ keys, values = queries, queries
+ static_kv = False
+ else: # cross-attention
+ static_kv = True
+
+ q = self.q_fc(queries)
+ q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+ q = layers.transpose(x=q, perm=[0, 2, 1, 3])
+
+ if cache is not None and static_kv and "static_k" in cache:
+            # for encoder-decoder attention in inference with cached static_k/static_v
+ k = cache["static_k"]
+ v = cache["static_v"]
+ else:
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+ v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+
+ if cache is not None:
+ if static_kv and not "static_k" in cache:
+                # for encoder-decoder attention in inference without cached static_k/static_v
+ cache["static_k"], cache["static_v"] = k, v
+ elif not static_kv:
+ # for decoder self-attention in inference
+ cache_k, cache_v = cache["k"], cache["v"]
+ k = layers.concat([cache_k, k], axis=2)
+ v = layers.concat([cache_v, v], axis=2)
+ cache["k"], cache["v"] = k, v
+
+ return q, k, v
+
+ def forward(self, queries, keys, values, attn_bias, cache=None):
+        # compute q, k, v
+ q, k, v = self._prepare_qkv(queries, keys, values, cache)
+
+ # scale dot product attention
+ product = layers.matmul(
+ x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
+ if attn_bias:
+ product += attn_bias
+ weights = layers.softmax(product)
+ if self.dropout_rate:
+ weights = layers.dropout(
+ weights, dropout_prob=self.dropout_rate, is_test=False)
+
+ out = layers.matmul(weights, v)
+
+ # combine heads
+ out = layers.transpose(out, perm=[0, 2, 1, 3])
+ out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+ # project to output
+ out = self.proj_fc(out)
+ return out
+
+ def cal_kv(self, keys, values):
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+ v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+ return k, v
+
+
+class FFN(Layer):
+ """
+ Feed-Forward Network
+ """
+
+ def __init__(self, d_inner_hid, d_model, dropout_rate):
+ super(FFN, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.fc1 = Linear(
+ input_dim=d_model, output_dim=d_inner_hid, act="relu")
+ self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
+
+ def forward(self, x):
+ hidden = self.fc1(x)
+ if self.dropout_rate:
+ hidden = layers.dropout(
+ hidden, dropout_prob=self.dropout_rate, is_test=False)
+ out = self.fc2(hidden)
+ return out
+
+
+class TransformerEncoderLayer(Layer):
+ """
+ EncoderLayer
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(TransformerEncoderLayer, self).__init__()
+
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ attn_output = self.self_attn(
+ self.preprocesser1(enc_input), None, None, attn_bias)
+ attn_output = self.postprocesser1(attn_output, enc_input)
+
+ ffn_output = self.ffn(self.preprocesser2(attn_output))
+ ffn_output = self.postprocesser2(ffn_output, attn_output)
+ return ffn_output
+
+
+class TransformerEncoder(Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(TransformerEncoder, self).__init__()
+
+ self.encoder_layers = list()
+ for i in range(n_layer):
+ self.encoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ TransformerEncoderLayer(
+ n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ for encoder_layer in self.encoder_layers:
+ enc_output = encoder_layer(enc_input, attn_bias)
+ enc_input = enc_output
+
+ return self.processer(enc_output)
+
+
+class TransformerDecoderLayer(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+ super(TransformerDecoderLayer, self).__init__()
+
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self,
+ dec_input,
+ enc_output,
+ self_attn_bias,
+ cross_attn_bias,
+ cache=None):
+ self_attn_output = self.self_attn(
+ self.preprocesser1(dec_input), None, None, self_attn_bias, cache)
+ self_attn_output = self.postprocesser1(self_attn_output, dec_input)
+
+ cross_attn_output = self.cross_attn(
+ self.preprocesser2(self_attn_output), enc_output, enc_output,
+ cross_attn_bias, cache)
+ cross_attn_output = self.postprocesser2(cross_attn_output,
+ self_attn_output)
+
+ ffn_output = self.ffn(self.preprocesser3(cross_attn_output))
+ ffn_output = self.postprocesser3(ffn_output, cross_attn_output)
+
+ return ffn_output
+
+
+class TransformerDecoder(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout, relu_dropout,
+ preprocess_cmd, postprocess_cmd):
+ super(TransformerDecoder, self).__init__()
+
+ self.decoder_layers = list()
+ for i in range(n_layer):
+ self.decoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ TransformerDecoderLayer(
+ n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self,
+ dec_input,
+ enc_output,
+ self_attn_bias,
+ cross_attn_bias,
+ caches=None):
+ for i, decoder_layer in enumerate(self.decoder_layers):
+ dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
+ cross_attn_bias, None
+ if caches is None else caches[i])
+ dec_input = dec_output
+
+ return self.processer(dec_output)
+
+ def prepare_static_cache(self, enc_output):
+ return [
+ dict(
+ zip(("static_k", "static_v"),
+ decoder_layer.cross_attn.cal_kv(enc_output, enc_output)))
+ for decoder_layer in self.decoder_layers
+ ]
diff --git a/transformer/README.md b/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2c4c22b91788a091fc9c08e303e5bcae7d80a4de
--- /dev/null
+++ b/transformer/README.md
@@ -0,0 +1,305 @@
+## Transformer
+
+Below is a brief overview of the directory structure of this example:
+
+```text
+.
+├── images # images used in this document
+├── utils # utility package
+├── gen_data.sh # data generation script
+├── predict.py # prediction script
+├── reader.py # data reading interface
+├── README.md # this document
+├── train.py # training script
+├── model.py # model definition
+└── transformer.yaml # configuration file
+```
+
+## Model Introduction
+
+Machine translation (MT) is the process of using a computer to convert one natural language (the source language) into another natural language (the target language): the input is a sentence in the source language and the output is the corresponding sentence in the target language.
+
+This project is a PaddlePaddle implementation of Transformer, the mainstream model for machine translation. It covers model training, prediction and the use of custom data, and users can build their own translation models on top of the released code.
+
+
+## Quick Start
+
+### Installation
+
+1. Install PaddlePaddle
+
+ This project requires PaddlePaddle 1.7 or later (or an appropriate develop version); please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it
+
+2. Download the code
+
+ Clone the repository to your local machine
+ ```shell
+ git clone https://github.com/PaddlePaddle/hapi
+ cd hapi/transformer
+ ```
+
+3. Environment dependencies
+
+ See the PaddlePaddle [installation notes](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.6/beginners_guide/install/index_cn.html) for details
+
+
+### Data Preparation
+
+Public dataset: the WMT shared task is the most authoritative international evaluation in machine translation, and its English-German task provides a medium-sized dataset that is used in many papers, including the Transformer paper. We provide the [WMT'16 EN-DE dataset](http://www.statmt.org/wmt16/translation-task.html) as an example. Run the `gen_data.sh` script to download and preprocess the WMT'16 EN-DE dataset (this takes quite a while, so running it in the background is recommended). Preprocessing mainly consists of tokenization and [BPE (byte-pair encoding)](https://arxiv.org/pdf/1508.07909). On success, a `gen_data` folder is generated with the following structure:
+
+```text
+.
+├── wmt16_ende_data # WMT16 EN-DE translation data
+├── wmt16_ende_data_bpe # BPE-encoded WMT16 EN-DE translation data
+├── mosesdecoder # Moses machine translation toolkit, including tokenization and BLEU evaluation scripts
+└── subword-nmt # BPE encoding code
+```
+
+We also provide a preprocessed copy of the WMT'16 EN-DE data for [download](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz), which contains the vocabulary (`vocab_all.bpe.32000`), the BPE data needed for training (`train.tok.clean.bpe.32000.en-de`), the BPE data needed for prediction (`newstest2016.tok.bpe.32000.en-de` and similar files) and the corresponding tokenized data needed to evaluate the predictions (`newstest2016.tok.de` and similar files).
+
+
+Custom data: to use your own data, the format directly supported by this project is tab-separated (\t) pairs of source- and target-language sentences, with the tokens inside a sentence separated by spaces. Provide data files in this format (they may be split into multiple parts; the reader supports file wildcards) and the corresponding vocabulary files, and the code can be run directly.
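+
+For illustration, a single line in this format can be parsed as follows (the sentence pair below is made up; real data would come from your own corpus):
+
+```python
+# source and target sentences separated by a tab, tokens separated by spaces
+line = "two men are playing football .\tzwei Männer spielen Fußball ."
+
+src, trg = line.split("\t")
+print(src.split())  # ['two', 'men', 'are', 'playing', 'football', '.']
+print(trg.split())  # ['zwei', 'Männer', 'spielen', 'Fußball', '.']
+```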
+
+### Single-Machine Training
+
+#### Single GPU
+
+Using the provided EN-DE translation data as an example, model training can be started with the following command:
+
+```sh
+# setting visible devices for training
+export CUDA_VISIBLE_DEVICES=0
+
+python -u train.py \
+ --epoch 30 \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+ --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 4096
+```
+
+The command above passes in parameters such as the number of training epochs (`epoch`) and the training data file path (make sure it is set correctly; wildcards are supported). For more parameters and the supported model hyper-parameters, see the `transformer.yaml` configuration file, which provides the Transformer base model configuration by default; to adjust it, edit the configuration file or pass values on the command line (command-line values override the configuration file). The big model from the Transformer paper can be trained with the following command:
+
+```sh
+# setting visible devices for training
+export CUDA_VISIBLE_DEVICES=0
+
+python -u train.py \
+ --epoch 30 \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+ --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 4096 \
+ --n_head 16 \
+ --d_model 1024 \
+ --d_inner_hid 4096 \
+ --prepostprocess_dropout 0.3
+```
+
+In addition, if `save_model` is provided when training (default: trained_models), the model trained so far is saved to that directory at the end of every epoch (two files are written: `epoch_id.pdparams` with the model parameters and `epoch_id.pdopt` with the optimizer state). Every fixed number of iterations (set by `print_step`, default 100), a log like the following is printed to standard output:
+
+```txt
+step 100/1 - loss: 9.165776 - normalized loss: 7.790036 - ppl: 9564.142578 - 247ms/step
+step 200/1 - loss: 8.037900 - normalized loss: 6.662160 - ppl: 3096.104492 - 227ms/step
+step 300/1 - loss: 7.668307 - normalized loss: 6.292567 - ppl: 2139.457031 - 221ms/step
+step 400/1 - loss: 7.598633 - normalized loss: 6.222893 - ppl: 1995.466797 - 218ms/step
+```
+
+Training on CPU is also possible (set `--use_cuda False`), but it is noticeably slower.
+
+#### Multiple GPUs
+
+Multi-process multi-GPU training is supported and can be launched as follows:
+
+```sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
+ --epoch 30 \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+ --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 4096 \
+ --print_step 100
+```
+
+#### Static-Graph Training
+
+Training runs in dynamic-graph mode by default; set the `eager_run` parameter to False to train in static-graph mode, as follows:
+
+```sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
+ --epoch 30 \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+ --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 4096 \
+ --print_step 100 \
+ --eager_run False
+```
+
+
+### Model Inference
+
+Using the EN-DE translation data as an example, once the model is trained, the text in a given file can be translated with the following command:
+
+```sh
+# setting visible devices for prediction
+export CUDA_VISIBLE_DEVICES=0
+
+python -u predict.py \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 32 \
+ --init_from_params base_model_dygraph/step_100000/transformer \
+ --beam_size 5 \
+ --max_out_len 255 \
+ --output_file predict.txt
+```
+
+Translations of the text in the file specified by `predict_file` are written to the file specified by `output_file`. When running prediction, `init_from_params` must be set to the path of the model files (without extension); for other parameters, consult the comments in `transformer.yaml` and adjust as needed. Note that any model hyper-parameters set at prediction time must match the settings used during training; for example, if the big model settings were used for training, prediction should use a command similar to the following:
+
+```sh
+# setting visible devices for prediction
+export CUDA_VISIBLE_DEVICES=0
+
+python -u predict.py \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 32 \
+ --init_from_params base_model_dygraph/step_100000/transformer \
+ --beam_size 5 \
+ --max_out_len 255 \
+ --output_file predict.txt \
+ --n_head 16 \
+ --d_model 1024 \
+ --d_inner_hid 4096 \
+ --prepostprocess_dropout 0.3
+```
+
+Similar to training, prediction can also be run in static-graph mode, as follows:
+
+```sh
+# setting visible devices for prediction
+export CUDA_VISIBLE_DEVICES=0
+
+python -u predict.py \
+ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+ --special_token '<s>' '<e>' '<unk>' \
+ --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+ --batch_size 32 \
+ --init_from_params base_model_dygraph/step_100000/transformer \
+ --beam_size 5 \
+ --max_out_len 255 \
+ --output_file predict.txt \
+ --eager_run False
+```
+
+### Model Evaluation
+
+Each line of the prediction output is the highest-scoring translation of the corresponding input line. For BPE-encoded data, the predicted translations are also in BPE form, and they must be restored to the original (i.e. tokenized) data before a correct evaluation can be carried out. The evaluation procedure is as follows (BLEU is the standard automatic evaluation metric for translation tasks):
+
+```sh
+# restore the predictions in predict.txt to tokenized data
+sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
+# download the BLEU evaluation tool first if it is not available
+# git clone https://github.com/moses-smt/mosesdecoder.git
+# using the EN-DE newstest2014 test data as an example
+perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
+```
+Output similar to the following should appear:
+```
+BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
+```
+
+With the contents provided in this project, the EN-DE base model and big model, after training for 100K iterations on 8 GPUs, reach roughly the following BLEU scores:
+
+| Test set | newstest2014 | newstest2015 | newstest2016 |
+|-|-|-|-|
+| Base | 26.35 | 29.07 | 33.30 |
+| Big | 27.07 | 30.09 | 34.38 |
+
+### Pre-trained Models
+
+We provide the parameters of the [base model](https://transformer-res.bj.bcebos.com/base_model_dygraph.tar.gz) and [big model](https://transformer-res.bj.bcebos.com/big_model_dygraph.tar.gz) that achieve the BLEU scores above for download (note that the models were trained and tested with the downloadable data provided above).
+
+## Advanced Usage
+
+### Background
+
+Transformer is a novel network architecture proposed in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) for sequence-to-sequence (Seq2Seq) learning tasks such as machine translation (MT); it relies entirely on the attention mechanism to model the sequence-to-sequence transformation [1].
+
+Compared with the recurrent neural networks (RNNs) widely used in earlier Seq2Seq models, using (self-)attention to transform the input sequence into the output sequence has the following main advantages:
+
+- Lower computational complexity
+  - For a sequence of length n with feature size d, an RNN costs `O(n * d * d)` (n time steps, each computing a d-dimensional matrix-vector product), while self-attention costs `O(n * n * d)` (d-dimensional dot products, or other similarity functions, computed between every pair of time steps); n is usually smaller than d (see the sketch after this list).
+- Higher parallelism
+  - In an RNN, the computation at the current time step depends on the result of the previous time step; in self-attention, each time step depends only on the input rather than on earlier outputs, so all time steps can be computed fully in parallel.
+- Easier learning of long-range dependencies
+  - In an RNN, relating two positions that are n apart takes n steps; in self-attention any two positions are directly connected, and the shorter the path, the easier it is for the signal to propagate.
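+
+As a rough back-of-the-envelope illustration of the first point (the values of n and d below are made up for illustration, not measurements):
+
+```python
+# per-layer operation counts, ignoring constant factors and other sub-layers
+n, d = 50, 512                 # sequence length and feature size
+rnn_ops = n * d * d            # n sequential d x d matrix-vector products
+self_attn_ops = n * n * d      # pairwise d-dimensional dot products
+print(rnn_ops, self_attn_ops)  # 13107200 1280000
+```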
+
+The self-attention-based sequence modeling modules introduced in Transformer have since been widely adopted in semantic representation models such as BERT [2], achieving remarkable results.
+
+
+### Model Overview
+
+Transformer also adopts the encoder-decoder framework typical of Seq2Seq models; the overall network structure is shown in Figure 1.
+
+![Transformer network architecture](images/transformer_network.png)
+
+Figure 1. Transformer network architecture
+
+As can be seen, unlike previous Seq2Seq models, the encoder and decoder of Transformer no longer use RNN structures.
+
+### Model Features
+
+The Transformer encoder is a stack of several identical layers, each consisting mainly of two sub-layers: multi-head attention (Multi-Head Attention) and a fully connected feed-forward (Feed-Forward) network.
+- Multi-Head Attention is used here to implement self-attention. Compared with a plain attention mechanism, it applies several linear projections to the input, computes attention for each projection separately, then concatenates all results and applies another linear projection to produce the output. See Figure 2; the attention used is dot-product attention, with the dot products scaled to keep large values from pushing the softmax into its saturated region (a minimal sketch follows this list).
+- The Feed-Forward network performs the same computation at every position of the sequence (position-wise); it consists of two linear transformations with a ReLU activation in between.
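+
+A minimal NumPy sketch of the scaled dot-product attention described in the first bullet above (single head, no mask; the shapes are made up for illustration, and the full multi-head implementation used by this project is `MultiHeadAttention` in `text.py`):
+
+```python
+import numpy as np
+
+def scaled_dot_product_attention(q, k, v):
+    # q, k, v: [seq_len, d_key]; scale by sqrt(d_key) so large dot products
+    # do not saturate the softmax
+    d_key = q.shape[-1]
+    scores = q @ k.T / np.sqrt(d_key)
+    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
+    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
+    return weights @ v
+
+q = k = v = np.random.rand(5, 64).astype("float32")  # self-attention
+out = scaled_dot_product_attention(q, k, v)          # shape: (5, 64)
+```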
+
+In addition, each sub-layer is followed by a residual connection [3] and layer normalization [4] to ease gradient propagation and model convergence.
+
+![Multi-Head Attention](images/multi_head_attention.png)
+
+Figure 2. Multi-Head Attention
+
+
+The decoder has a structure similar to the encoder, except that each decoder layer contains one additional Multi-Head Attention sub-layer attending over the encoder output; this encoder-decoder attention also exists in other Seq2Seq models.
+
+## FAQ
+
+**Q:** Why does the prediction output contain fewer samples than the input?
+**A:** If the longest sample exceeds the default `max_length` in `transformer.yaml`, remember to increase `--max_length` at run time; otherwise over-length samples are filtered out.
+
+**Q:** What if the maximum length at prediction time exceeds the maximum length used during training?
+**A:** The `max_length` set at training time determines the size of the position encoding saved with the model. If sequences at prediction time exceed `max_length`, increase the value; a larger position encoding table will be regenerated.
+
+
+## References
+1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
+2. Devlin J, Chang M W, Lee K, et al. [Bert: Pre-training of deep bidirectional transformers for language understanding](https://arxiv.org/abs/1810.04805)[J]. arXiv preprint arXiv:1810.04805, 2018.
+3. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
+4. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
+5. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
+
+
+## Authors
+- [guoshengCS](https://github.com/guoshengCS)
+
+## How to Contribute Code
+
+If you can fix an issue or add a new feature, you are welcome to submit a PR. If the PR is accepted, we will score the contribution by quality and difficulty (0-5 points, the higher the better). Once you have accumulated 10 points, you can contact us for an interview opportunity or a recommendation letter.
diff --git a/transformer/gen_data.sh b/transformer/gen_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e00ae05d9c5cc59b7b401428f6e1252397debfe9
--- /dev/null
+++ b/transformer/gen_data.sh
@@ -0,0 +1,220 @@
+#! /usr/bin/env bash
+
+set -e
+
+OUTPUT_DIR=$PWD/gen_data
+
+###############################################################################
+# change these variables for other WMT data
+###############################################################################
+OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt16_ende_data"
+OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt16_ende_data_bpe"
+LANG1="en"
+LANG2="de"
+# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
+TRAIN_DATA=(
+'http://www.statmt.org/europarl/v7/de-en.tgz'
+'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
+'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
+'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
+'http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz'
+'news-commentary-v11.de-en.en' 'news-commentary-v11.de-en.de'
+)
+# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
+DEV_TEST_DATA=(
+'http://data.statmt.org/wmt16/translation-task/dev.tgz'
+'newstest201[45]-deen-ref.en.sgm' 'newstest201[45]-deen-src.de.sgm'
+'http://data.statmt.org/wmt16/translation-task/test.tgz'
+'newstest2016-deen-ref.en.sgm' 'newstest2016-deen-src.de.sgm'
+)
+###############################################################################
+
+###############################################################################
+# change these variables for other WMT data
+###############################################################################
+# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
+# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
+# LANG1="en"
+# LANG2="fr"
+# # each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
+# TRAIN_DATA=(
+# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
+# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
+# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
+# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
+# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
+# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
+# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
+# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
+# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
+# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
+# )
+# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
+# DEV_TEST_DATA=(
+# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
+# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
+# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
+# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
+# )
+###############################################################################
+
+mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
+
+# Extract training data
+for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
+ data_url=${TRAIN_DATA[i]}
+ data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
+ data=${data_tgz%.*} # training-parallel-commoncrawl
+ data_lang1=${TRAIN_DATA[i+1]}
+ data_lang2=${TRAIN_DATA[i+2]}
+ if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
+ echo "Download "${data_url}
+ wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
+ fi
+
+ if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
+ echo "Extract "${data_tgz}
+ mkdir -p ${OUTPUT_DIR_DATA}/${data}
+ tar_type=${data_tgz:0-3}
+ if [ ${tar_type} == "tar" ]; then
+ tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
+ else
+ tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
+ fi
+ fi
+ # concatenate all training data
+ for data_lang in $data_lang1 $data_lang2; do
+ for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
+ data_dir=`dirname $f`
+ data_file=`basename $f`
+ f_base=${f%.*}
+ f_ext=${f##*.}
+ if [ $f_ext == "gz" ]; then
+ gunzip $f
+ l=${f_base##*.}
+ f_base=${f_base%.*}
+ else
+ l=${f_ext}
+ fi
+
+ if [ $i -eq 0 ]; then
+ cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
+ else
+ cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
+ fi
+ done
+ done
+done
+
+# Clone mosesdecoder
+if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
+ echo "Cloning moses for data processing"
+ git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
+fi
+
+# Extract develop and test data
+dev_test_data=""
+for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
+ data_url=${DEV_TEST_DATA[i]}
+ data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
+ data=${data_tgz%.*} # training-parallel-commoncrawl
+ data_lang1=${DEV_TEST_DATA[i+1]}
+ data_lang2=${DEV_TEST_DATA[i+2]}
+ if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
+ echo "Download "${data_url}
+ wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
+ fi
+
+ if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
+ echo "Extract "${data_tgz}
+ mkdir -p ${OUTPUT_DIR_DATA}/${data}
+ tar_type=${data_tgz:0-3}
+ if [ ${tar_type} == "tar" ]; then
+ tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
+ else
+ tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
+ fi
+ fi
+
+ for data_lang in $data_lang1 $data_lang2; do
+ for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
+ data_dir=`dirname $f`
+ data_file=`basename $f`
+ data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
+ l=`echo ${data_file} | cut -d '.' -f 2` # en
+ dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
+ if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
+ ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
+ < $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
+ fi
+ done
+ done
+done
+
+# Tokenize data
+for l in ${LANG1} ${LANG2}; do
+ for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.$l$"`; do
+ f_base=${f%.*} # dir/train dir/newstest2016
+ f_out=$f_base.tok.$l
+ if [ ! -e $f_out ]; then
+ echo "Tokenize "$f
+ ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l $l -threads 8 < $f > $f_out
+ fi
+ done
+done
+
+# Clean data
+for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
+ f_base=${f%.*} # dir/train dir/train.tok
+ f_out=${f_base}.clean
+ if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
+ echo "Clean "${f_base}
+ ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 80
+ fi
+done
+
+# Clone subword-nmt and generate BPE data
+if [ ! -d ${OUTPUT_DIR}/subword-nmt ]; then
+ git clone https://github.com/rsennrich/subword-nmt.git ${OUTPUT_DIR}/subword-nmt
+fi
+
+# Generate BPE data and vocabulary
+for num_operations in 32000; do
+ if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
+ echo "Learn BPE with ${num_operations} merge operations"
+ cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
+ ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
+ fi
+
+ for l in ${LANG1} ${LANG2}; do
+ for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
+ f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
+ f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
+ f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
+ if [ ! -e $f_out ]; then
+ echo "Apply BPE to "$f
+ ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
+ fi
+ done
+ done
+
+ if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
+ echo "Create vocabulary for BPE data"
+ cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
+ ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
+ fi
+done
+
+# Adapt to the reader
+for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
+ f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
+ f_out=${f_base}.${LANG1}-${LANG2}
+ if [ ! -e $f_out ]; then
+ paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
+ fi
+done
+if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
+  sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
+fi
+
+echo "All done."
diff --git a/transformer/images/multi_head_attention.png b/transformer/images/multi_head_attention.png
new file mode 100644
index 0000000000000000000000000000000000000000..427fb6b32aaeb7013066a167aab4fb97c024c2d6
Binary files /dev/null and b/transformer/images/multi_head_attention.png differ
diff --git a/transformer/images/transformer_network.png b/transformer/images/transformer_network.png
new file mode 100644
index 0000000000000000000000000000000000000000..34be0e5c7e2b08f858683d86353db5e81049c7ca
Binary files /dev/null and b/transformer/images/transformer_network.png differ
diff --git a/transformer/predict.py b/transformer/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a47ccdaef7426505ab69ee93ec20bfb2f765513
--- /dev/null
+++ b/transformer/predict.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import six
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from functools import partial
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.io import DataLoader
+from paddle.fluid.layers.utils import flatten
+
+from utils.configure import PDConfig
+from utils.check import check_gpu, check_version
+
+from model import Input, set_device
+from reader import prepare_infer_input, Seq2SeqDataset, Seq2SeqBatchSampler
+from transformer import InferTransformer, position_encoding_init
+
+
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
+ output_eos=False):
+ """
+ Post-process the decoded sequence.
+ """
+ eos_pos = len(seq) - 1
+ for i, idx in enumerate(seq):
+ if idx == eos_idx:
+ eos_pos = i
+ break
+ seq = [
+ idx for idx in seq[:eos_pos + 1]
+ if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
+ ]
+ return seq
+
+
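+# Illustrative example: with bos_idx=0 and eos_idx=1,
+# post_process_seq([0, 5, 7, 1, 1], 0, 1) returns [5, 7] -- the ids up to the
+# first end token, with the special tokens themselves stripped.
+
+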
+def do_predict(args):
+ device = set_device("gpu" if args.use_cuda else "cpu")
+ fluid.enable_dygraph(device) if args.eager_run else None
+
+ inputs = [
+ Input(
+ [None, None], "int64", name="src_word"),
+ Input(
+ [None, None], "int64", name="src_pos"),
+ Input(
+ [None, args.n_head, None, None],
+ "float32",
+ name="src_slf_attn_bias"),
+ Input(
+ [None, args.n_head, None, None],
+ "float32",
+ name="trg_src_attn_bias"),
+ ]
+
+ # define data
+ dataset = Seq2SeqDataset(
+ fpattern=args.predict_file,
+ src_vocab_fpath=args.src_vocab_fpath,
+ trg_vocab_fpath=args.trg_vocab_fpath,
+ token_delimiter=args.token_delimiter,
+ start_mark=args.special_token[0],
+ end_mark=args.special_token[1],
+ unk_mark=args.special_token[2],
+ byte_data=True)
+ args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+ args.unk_idx = dataset.get_vocab_summary()
+ trg_idx2word = Seq2SeqDataset.load_dict(
+ dict_path=args.trg_vocab_fpath, reverse=True, byte_data=True)
+ batch_sampler = Seq2SeqBatchSampler(
+ dataset=dataset,
+ use_token_batch=False,
+ batch_size=args.batch_size,
+ max_length=args.max_length)
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ collate_fn=partial(
+ prepare_infer_input,
+ bos_idx=args.bos_idx,
+ eos_idx=args.eos_idx,
+ src_pad_idx=args.eos_idx,
+ n_head=args.n_head),
+ num_workers=0,
+ return_list=True)
+
+ # define model
+ transformer = InferTransformer(
+ args.src_vocab_size,
+ args.trg_vocab_size,
+ args.max_length + 1,
+ args.n_layer,
+ args.n_head,
+ args.d_key,
+ args.d_value,
+ args.d_model,
+ args.d_inner_hid,
+ args.prepostprocess_dropout,
+ args.attention_dropout,
+ args.relu_dropout,
+ args.preprocess_cmd,
+ args.postprocess_cmd,
+ args.weight_sharing,
+ args.bos_idx,
+ args.eos_idx,
+ beam_size=args.beam_size,
+ max_out_len=args.max_out_len)
+ transformer.prepare(inputs=inputs)
+
+ # load the trained model
+ assert args.init_from_params, (
+ "Please set init_from_params to load the infer model.")
+ transformer.load(args.init_from_params)
+
+    # TODO: use model.predict when variable-length output is supported
+    # write the n_best beams for each instance to the output file
+    with open(args.output_file, "wb") as f:
+        for data in data_loader():
+            finished_seq = transformer.test(inputs=flatten(data))[0]
+            finished_seq = np.transpose(finished_seq, [0, 2, 1])
+            for ins in finished_seq:
+                for beam_idx, beam in enumerate(ins):
+                    if beam_idx >= args.n_best:
+                        break
+                    id_list = post_process_seq(beam, args.bos_idx,
+                                               args.eos_idx)
+                    word_list = [trg_idx2word[idx] for idx in id_list]
+                    sequence = b" ".join(word_list) + b"\n"
+                    f.write(sequence)
+
+
+if __name__ == "__main__":
+ args = PDConfig(yaml_file="./transformer.yaml")
+ args.build()
+ args.Print()
+ check_gpu(args.use_cuda)
+ check_version()
+
+ do_predict(args)
diff --git a/transformer/reader.py b/transformer/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..66fb8dc02b99f345f337d8a91b6c7eeaff71fe18
--- /dev/null
+++ b/transformer/reader.py
@@ -0,0 +1,500 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import six
+import os
+import io
+import itertools
+from functools import partial
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.io import BatchSampler, DataLoader, Dataset
+
+
+def create_data_loader(args, device):
+ data_loaders = [None, None]
+ data_files = [args.training_file, args.validation_file
+ ] if args.validation_file else [args.training_file]
+ for i, data_file in enumerate(data_files):
+ dataset = Seq2SeqDataset(
+ fpattern=data_file,
+ src_vocab_fpath=args.src_vocab_fpath,
+ trg_vocab_fpath=args.trg_vocab_fpath,
+ token_delimiter=args.token_delimiter,
+ start_mark=args.special_token[0],
+ end_mark=args.special_token[1],
+ unk_mark=args.special_token[2],
+ byte_data=True)
+ args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+ args.unk_idx = dataset.get_vocab_summary()
+ batch_sampler = Seq2SeqBatchSampler(
+ dataset=dataset,
+ use_token_batch=args.use_token_batch,
+ batch_size=args.batch_size,
+ pool_size=args.pool_size,
+ sort_type=args.sort_type,
+ shuffle=args.shuffle,
+ shuffle_batch=args.shuffle_batch,
+ max_length=args.max_length,
+ distribute_mode=True
+            if i == 0 else False)  # every device evaluates the full eval data
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ collate_fn=partial(
+ prepare_train_input,
+ bos_idx=args.bos_idx,
+ eos_idx=args.eos_idx,
+ src_pad_idx=args.eos_idx,
+ trg_pad_idx=args.eos_idx,
+ n_head=args.n_head),
+ num_workers=0, # TODO: use multi-process
+ return_list=True)
+ data_loaders[i] = data_loader
+ return data_loaders
+
+
+def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
+ n_head):
+ """
+ Put all padded data needed by training into a list.
+ """
+ src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+ [inst[0] + [eos_idx] for inst in insts],
+ src_pad_idx,
+ n_head,
+ is_target=False)
+ src_word = src_word.reshape(-1, src_max_len)
+ src_pos = src_pos.reshape(-1, src_max_len)
+ trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
+ [[bos_idx] + inst[1] for inst in insts],
+ trg_pad_idx,
+ n_head,
+ is_target=True)
+ trg_word = trg_word.reshape(-1, trg_max_len)
+ trg_pos = trg_pos.reshape(-1, trg_max_len)
+
+ trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+ [1, 1, trg_max_len, 1]).astype("float32")
+
+ lbl_word, lbl_weight, num_token = pad_batch_data(
+ [inst[1] + [eos_idx] for inst in insts],
+ trg_pad_idx,
+ n_head,
+ is_target=False,
+ is_label=True,
+ return_attn_bias=False,
+ return_max_len=False,
+ return_num_token=True)
+ lbl_word = lbl_word.reshape(-1, 1)
+ lbl_weight = lbl_weight.reshape(-1, 1)
+
+ data_inputs = [
+ src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+ trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+ ]
+
+ return data_inputs
+
+
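+# Illustrative note: the order of data_inputs returned above is assumed to
+# match the `inputs` + `labels` declared in train.py, i.e. src_word, src_pos,
+# src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
+# then lbl_word and lbl_weight.
+
+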
+def prepare_infer_input(insts, bos_idx, eos_idx, src_pad_idx, n_head):
+ """
+ Put all padded data needed by beam search decoder into a list.
+ """
+ src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+ [inst[0] + [eos_idx] for inst in insts],
+ src_pad_idx,
+ n_head,
+ is_target=False)
+ trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+ [1, 1, 1, 1]).astype("float32")
+ src_word = src_word.reshape(-1, src_max_len)
+ src_pos = src_pos.reshape(-1, src_max_len)
+
+ data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias]
+ return data_inputs
+
+
+def pad_batch_data(insts,
+ pad_idx,
+ n_head,
+ is_target=False,
+ is_label=False,
+ return_attn_bias=True,
+ return_max_len=True,
+ return_num_token=False):
+ """
+ Pad the instances to the max sequence length in batch, and generate the
+ corresponding position data and attention bias.
+ """
+ return_list = []
+ max_len = max(len(inst) for inst in insts)
+    # Any token included in dict can be used to pad, since the paddings' loss
+    # will be masked out by weights and has no effect on parameter gradients.
+ inst_data = np.array(
+ [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+ return_list += [inst_data.astype("int64").reshape([-1, 1])]
+ if is_label: # label weight
+ inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
+ for inst in insts])
+ return_list += [inst_weight.astype("float32").reshape([-1, 1])]
+ else: # position data
+ inst_pos = np.array([
+ list(range(0, len(inst))) + [0] * (max_len - len(inst))
+ for inst in insts
+ ])
+ return_list += [inst_pos.astype("int64").reshape([-1, 1])]
+ if return_attn_bias:
+ if is_target:
+ # This is used to avoid attention on paddings and subsequent
+ # words.
+ slf_attn_bias_data = np.ones(
+ (inst_data.shape[0], max_len, max_len))
+ slf_attn_bias_data = np.triu(slf_attn_bias_data,
+ 1).reshape([-1, 1, max_len, max_len])
+ slf_attn_bias_data = np.tile(slf_attn_bias_data,
+ [1, n_head, 1, 1]) * [-1e9]
+ else:
+ # This is used to avoid attention on paddings.
+ slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
+ (max_len - len(inst))
+ for inst in insts])
+ slf_attn_bias_data = np.tile(
+ slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
+ [1, n_head, max_len, 1])
+ return_list += [slf_attn_bias_data.astype("float32")]
+ if return_max_len:
+ return_list += [max_len]
+ if return_num_token:
+ num_token = 0
+ for inst in insts:
+ num_token += len(inst)
+ return_list += [num_token]
+ return return_list if len(return_list) > 1 else return_list[0]
+
+
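+# Shape note (illustrative): for a batch with max length L and n_head H,
+# pad_batch_data returns word ids and positions flattened to [batch * L, 1]
+# and a self-attention bias of shape [batch, H, L, L] whose masked slots
+# (padding columns, or future positions when is_target) hold -1e9, so softmax
+# gives them ~0 weight.
+
+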
+class SortType(object):
+ GLOBAL = 'global'
+ POOL = 'pool'
+ NONE = "none"
+
+
+class Converter(object):
+ def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
+ self._vocab = vocab
+ self._beg = beg
+ self._end = end
+ self._unk = unk
+ self._delimiter = delimiter
+ self._add_beg = add_beg
+ self._add_end = add_end
+
+ def __call__(self, sentence):
+ return ([self._beg] if self._add_beg else []) + [
+ self._vocab.get(w, self._unk)
+ for w in sentence.split(self._delimiter)
+ ] + ([self._end] if self._add_end else [])
+
+
+class ComposedConverter(object):
+ def __init__(self, converters):
+ self._converters = converters
+
+ def __call__(self, fields):
+ return [
+ converter(field)
+ for field, converter in zip(fields, self._converters)
+ ]
+
+
+class SentenceBatchCreator(object):
+ def __init__(self, batch_size):
+ self.batch = []
+ self._batch_size = batch_size
+
+ def append(self, info):
+ self.batch.append(info)
+ if len(self.batch) == self._batch_size:
+ tmp = self.batch
+ self.batch = []
+ return tmp
+
+
+class TokenBatchCreator(object):
+ def __init__(self, batch_size):
+ self.batch = []
+ self.max_len = -1
+ self._batch_size = batch_size
+
+ def append(self, info):
+ cur_len = info.max_len
+ max_len = max(self.max_len, cur_len)
+ if max_len * (len(self.batch) + 1) > self._batch_size:
+ result = self.batch
+ self.batch = [info]
+ self.max_len = cur_len
+ return result
+ else:
+ self.max_len = max_len
+ self.batch.append(info)
+
+
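+# Example (illustrative): with use_token_batch and batch_size=4096 tokens,
+# TokenBatchCreator closes the current batch as soon as adding one more
+# sentence would make max_len * (num_sentences + 1) exceed 4096, so batches of
+# long sentences contain fewer sentences than batches of short ones.
+
+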
+class SampleInfo(object):
+ def __init__(self, i, lens):
+ self.i = i
+ # take bos and eos into account
+ self.min_len = min(lens[0] + 1, lens[1] + 2)
+ self.max_len = max(lens[0] + 1, lens[1] + 2)
+
+
+class MinMaxFilter(object):
+ def __init__(self, max_len, min_len, underlying_creator):
+ self._min_len = min_len
+ self._max_len = max_len
+ self._creator = underlying_creator
+
+ def append(self, info):
+ if info.max_len > self._max_len or info.min_len < self._min_len:
+ return
+ else:
+ return self._creator.append(info)
+
+ @property
+ def batch(self):
+ return self._creator.batch
+
+
+class Seq2SeqDataset(Dataset):
+ def __init__(self,
+ src_vocab_fpath,
+ trg_vocab_fpath,
+ fpattern,
+ field_delimiter="\t",
+ token_delimiter=" ",
+ start_mark="",
+ end_mark="",
+ unk_mark="",
+ only_src=False,
+ trg_fpattern=None,
+ byte_data=False):
+ if byte_data:
+            # The WMT16 BPE data used here seems to include bytes that cannot
+            # be decoded by utf8. Thus convert str to bytes and use byte data.
+ field_delimiter = field_delimiter.encode("utf8")
+ token_delimiter = token_delimiter.encode("utf8")
+ start_mark = start_mark.encode("utf8")
+ end_mark = end_mark.encode("utf8")
+ unk_mark = unk_mark.encode("utf8")
+ self._byte_data = byte_data
+ self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+ self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
+ self._bos_idx = self._src_vocab[start_mark]
+ self._eos_idx = self._src_vocab[end_mark]
+ self._unk_idx = self._src_vocab[unk_mark]
+ self._field_delimiter = field_delimiter
+ self._token_delimiter = token_delimiter
+ self.load_src_trg_ids(fpattern, trg_fpattern)
+
+ def load_src_trg_ids(self, fpattern, trg_fpattern=None):
+ src_converter = Converter(
+ vocab=self._src_vocab,
+ beg=self._bos_idx,
+ end=self._eos_idx,
+ unk=self._unk_idx,
+ delimiter=self._token_delimiter,
+ add_beg=False,
+ add_end=False)
+
+ trg_converter = Converter(
+ vocab=self._trg_vocab,
+ beg=self._bos_idx,
+ end=self._eos_idx,
+ unk=self._unk_idx,
+ delimiter=self._token_delimiter,
+ add_beg=False,
+ add_end=False)
+
+ converters = ComposedConverter([src_converter, trg_converter])
+
+ self._src_seq_ids = []
+ self._trg_seq_ids = []
+ self._sample_infos = []
+
+ slots = [self._src_seq_ids, self._trg_seq_ids]
+ for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
+ lens = []
+ for field, slot in zip(converters(line), slots):
+ slot.append(field)
+ lens.append(len(field))
+ self._sample_infos.append(SampleInfo(i, lens))
+
+ def _load_lines(self, fpattern, trg_fpattern=None):
+ fpaths = glob.glob(fpattern)
+        fpaths = sorted(fpaths)  # TODO: Add custom sort
+ assert len(fpaths) > 0, "no matching file to the provided data path"
+
+ (f_mode, f_encoding,
+ endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
+ "\n")
+ if trg_fpattern is None:
+ for fpath in fpaths:
+ with io.open(fpath, f_mode, encoding=f_encoding) as f:
+ for line in f:
+ fields = line.strip(endl).split(self._field_delimiter)
+ yield fields
+ else:
+            # separated source and target language data files
+            # assume we can get aligned data by sorting the two language files
+            # TODO: Need more rigorous check
+ trg_fpaths = glob.glob(trg_fpattern)
+ trg_fpaths = sorted(trg_fpaths)
+            assert len(fpaths) == len(trg_fpaths), (
+                "the number of source language data files must be equal to "
+                "that of target language data files")
+
+ for fpath, trg_fpath in zip(fpaths, trg_fpaths):
+ with io.open(fpath, f_mode, encoding=f_encoding) as f:
+ with io.open(
+ trg_fpath, f_mode, encoding=f_encoding) as trg_f:
+ for line in zip(f, trg_f):
+ fields = [field.strip(endl) for field in line]
+ yield fields
+
+ @staticmethod
+ def load_dict(dict_path, reverse=False, byte_data=False):
+ word_dict = {}
+ (f_mode, f_encoding,
+ endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+ with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
+ for idx, line in enumerate(fdict):
+ if reverse:
+ word_dict[idx] = line.strip(endl)
+ else:
+ word_dict[line.strip(endl)] = idx
+ return word_dict
+
+ def get_vocab_summary(self):
+ return len(self._src_vocab), len(
+ self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
+
+ def __getitem__(self, idx):
+ return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+ ) if self._trg_seq_ids else self._src_seq_ids[idx]
+
+ def __len__(self):
+ return len(self._sample_infos)
+
+
+class Seq2SeqBatchSampler(BatchSampler):
+ def __init__(self,
+ dataset,
+ batch_size,
+ pool_size=10000,
+ sort_type=SortType.NONE,
+ min_length=0,
+ max_length=100,
+ shuffle=False,
+ shuffle_batch=False,
+ use_token_batch=False,
+ clip_last_batch=False,
+ distribute_mode=True,
+ seed=0):
+ for arg, value in locals().items():
+ if arg != "self":
+ setattr(self, "_" + arg, value)
+ self._random = np.random
+ self._random.seed(seed)
+ # for multi-devices
+ self._distribute_mode = distribute_mode
+ self._nranks = ParallelEnv().nranks
+ self._local_rank = ParallelEnv().local_rank
+ self._device_id = ParallelEnv().dev_id
+
+ def __iter__(self):
+ # global sort or global shuffle
+ if self._sort_type == SortType.GLOBAL:
+ infos = sorted(
+ self._dataset._sample_infos, key=lambda x: x.max_len)
+ else:
+ if self._shuffle:
+ infos = self._dataset._sample_infos
+ self._random.shuffle(infos)
+ else:
+ infos = self._dataset._sample_infos
+
+ if self._sort_type == SortType.POOL:
+ reverse = True
+ for i in range(0, len(infos), self._pool_size):
+ # to avoid placing short next to long sentences
+ reverse = not reverse
+ infos[i:i + self._pool_size] = sorted(
+ infos[i:i + self._pool_size],
+ key=lambda x: x.max_len,
+ reverse=reverse)
+
+ batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size) if self._use_token_batch else SentenceBatchCreator(
+            self._batch_size * self._nranks)
+ batch_creator = MinMaxFilter(self._max_length, self._min_length,
+ batch_creator)
+
+ for info in infos:
+ batch = batch_creator.append(info)
+ if batch is not None:
+ batches.append(batch)
+
+ if not self._clip_last_batch and len(batch_creator.batch) != 0:
+ batches.append(batch_creator.batch)
+
+ if self._shuffle_batch:
+ self._random.shuffle(batches)
+
+ if not self._use_token_batch:
+            # when producing batches according to sequence number, to ensure
+            # that neighboring batches, which are fed to different devices and
+            # run in parallel, have similar lengths (and thus similar
+            # computational cost) after shuffling, we take them as a whole
+            # when shuffling and split them here
+ batches = [[
+ batch[self._batch_size * i:self._batch_size * (i + 1)]
+ for i in range(self._nranks)
+ ] for batch in batches]
+ batches = list(itertools.chain.from_iterable(batches))
+
+ # for multi-device
+ for batch_id, batch in enumerate(batches):
+ if not self._distribute_mode or (
+ batch_id % self._nranks == self._local_rank):
+ batch_indices = [info.i for info in batch]
+ yield batch_indices
+ if self._distribute_mode and len(batches) % self._nranks != 0:
+ if self._local_rank >= len(batches) % self._nranks:
+ # use previous data to pad
+ yield batch_indices
+
+ def __len__(self):
+ if not self._use_token_batch:
+ batch_number = (
+ len(self._dataset) + self._batch_size * self._nranks - 1) // (
+ self._batch_size * self._nranks)
+ else:
+ # TODO(guosheng): fix the uncertain length
+ batch_number = 1
+ return batch_number
diff --git a/transformer/train.py b/transformer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..58df6afb2cadc3b471aaffd4ed4caebbbc0bbc3d
--- /dev/null
+++ b/transformer/train.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import six
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.io import DataLoader
+
+from utils.configure import PDConfig
+from utils.check import check_gpu, check_version
+
+from model import Input, set_device
+from callbacks import ProgBarLogger
+from reader import create_data_loader
+from transformer import Transformer, CrossEntropyCriterion
+
+
+class TrainCallback(ProgBarLogger):
+ def __init__(self, args, verbose=2):
+ # TODO(guosheng): save according to step
+ super(TrainCallback, self).__init__(args.print_step, verbose)
+ # the best cross-entropy value with label smoothing
+ loss_normalizer = -(
+ (1. - args.label_smooth_eps) * np.log(
+ (1. - args.label_smooth_eps)) + args.label_smooth_eps *
+ np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
+ self.loss_normalizer = loss_normalizer
+
+ def on_train_begin(self, logs=None):
+ super(TrainCallback, self).on_train_begin(logs)
+ self.train_metrics += ["normalized loss", "ppl"]
+
+ def on_train_batch_end(self, step, logs=None):
+ logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+ logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+ super(TrainCallback, self).on_train_batch_end(step, logs)
+
+ def on_eval_begin(self, logs=None):
+ super(TrainCallback, self).on_eval_begin(logs)
+ self.eval_metrics = list(
+ self.eval_metrics) + ["normalized loss", "ppl"]
+
+ def on_eval_batch_end(self, step, logs=None):
+ logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+ logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+ super(TrainCallback, self).on_eval_batch_end(step, logs)
+
+
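+# Illustrative note: loss_normalizer above is the entropy of the smoothed
+# target distribution, i.e. the lowest cross-entropy achievable under label
+# smoothing, so the reported "normalized loss" measures how far the current
+# loss is above that floor.
+
+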
+def do_train(args):
+ device = set_device("gpu" if args.use_cuda else "cpu")
+ fluid.enable_dygraph(device) if args.eager_run else None
+
+ # set seed for CE
+ random_seed = eval(str(args.random_seed))
+ if random_seed is not None:
+ fluid.default_main_program().random_seed = random_seed
+ fluid.default_startup_program().random_seed = random_seed
+
+ # define inputs
+ inputs = [
+ Input(
+ [None, None], "int64", name="src_word"),
+ Input(
+ [None, None], "int64", name="src_pos"),
+ Input(
+ [None, args.n_head, None, None],
+ "float32",
+ name="src_slf_attn_bias"),
+ Input(
+ [None, None], "int64", name="trg_word"),
+ Input(
+ [None, None], "int64", name="trg_pos"),
+ Input(
+ [None, args.n_head, None, None],
+ "float32",
+ name="trg_slf_attn_bias"),
+ Input(
+ [None, args.n_head, None, None],
+ "float32",
+ name="trg_src_attn_bias"),
+ ]
+ labels = [
+ Input(
+ [None, 1], "int64", name="label"),
+ Input(
+ [None, 1], "float32", name="weight"),
+ ]
+
+    # define data loaders
+ train_loader, eval_loader = create_data_loader(args, device)
+
+ # define model
+ transformer = Transformer(
+ args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
+ args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
+ args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
+ args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
+ args.weight_sharing, args.bos_idx, args.eos_idx)
+
+ transformer.prepare(
+ fluid.optimizer.Adam(
+ learning_rate=fluid.layers.noam_decay(
+ args.d_model,
+ args.warmup_steps,
+ learning_rate=args.learning_rate),
+ beta1=args.beta1,
+ beta2=args.beta2,
+ epsilon=float(args.eps),
+ parameter_list=transformer.parameters()),
+ CrossEntropyCriterion(args.label_smooth_eps),
+ inputs=inputs,
+ labels=labels)
+
+ ## init from some checkpoint, to resume the previous training
+ if args.init_from_checkpoint:
+ transformer.load(args.init_from_checkpoint)
+ ## init from some pretrain models, to better solve the current task
+ if args.init_from_pretrain_model:
+ transformer.load(args.init_from_pretrain_model, reset_optimizer=True)
+
+ # model train
+ transformer.fit(train_data=train_loader,
+ eval_data=eval_loader,
+ epochs=args.epoch,
+ eval_freq=1,
+ save_freq=1,
+ save_dir=args.save_model,
+ callbacks=[TrainCallback(args)])
+
+
+if __name__ == "__main__":
+ args = PDConfig(yaml_file="./transformer.yaml")
+ args.build()
+ args.Print()
+ check_gpu(args.use_cuda)
+ check_version()
+
+ do_train(args)
diff --git a/transformer/transformer.py b/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9caf4b04a1a34c5e856a789fbded8a53e917a3da
--- /dev/null
+++ b/transformer/transformer.py
@@ -0,0 +1,692 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
+from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
+from model import Model, CrossEntropy, Loss
+from text import TransformerBeamSearchDecoder, DynamicDecode
+
+
+def position_encoding_init(n_position, d_pos_vec):
+ """
+ Generate the initial values for the sinusoid position encoding table.
+ """
+ channels = d_pos_vec
+ position = np.arange(n_position)
+ num_timescales = channels // 2
+ log_timescale_increment = (np.log(float(1e4) / float(1)) /
+ (num_timescales - 1))
+    # frequencies decay geometrically from 1 down to 1e-4 across the channels
+    inv_timescales = np.exp(
+        np.arange(num_timescales) * -log_timescale_increment)
+ scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
+ 0)
+ signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
+ signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
+ position_enc = signal
+ return position_enc.astype("float32")
+
+
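+# Sanity-check sketch (illustrative, not used by the model):
+#     table = position_encoding_init(n_position=8, d_pos_vec=4)
+#     assert table.shape == (8, 4)
+#     # position 0 maps to sin(0)=0 in the first half, cos(0)=1 in the second
+#     assert np.allclose(table[0, :2], 0.) and np.allclose(table[0, 2:], 1.)
+
+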
+class NoamDecay(LearningRateDecay):
+ """
+ learning rate scheduler
+ """
+
+ def __init__(self,
+ d_model,
+ warmup_steps,
+ static_lr=2.0,
+ begin=1,
+ step=1,
+ dtype='float32'):
+ super(NoamDecay, self).__init__(begin, step, dtype)
+ self.d_model = d_model
+ self.warmup_steps = warmup_steps
+ self.static_lr = static_lr
+
+ def step(self):
+ a = self.create_lr_var(self.step_num**-0.5)
+ b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
+ lr_value = (self.d_model**-0.5) * layers.elementwise_min(
+ a, b) * self.static_lr
+ return lr_value
+
+
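+# Schedule note (illustrative): NoamDecay.step computes
+#     lr = static_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
+# i.e. a linear warmup for warmup_steps steps followed by an inverse
+# square-root decay, as in the original Transformer paper.
+
+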
+class PrePostProcessLayer(Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, process_cmd, d_model, dropout_rate):
+ super(PrePostProcessLayer, self).__init__()
+ self.process_cmd = process_cmd
+ self.functors = []
+ for cmd in self.process_cmd:
+ if cmd == "a": # add residual connection
+ self.functors.append(
+ lambda x, y: x + y if y is not None else x)
+ elif cmd == "n": # add layer normalization
+ self.functors.append(
+ self.add_sublayer(
+ "layer_norm_%d" % len(
+ self.sublayers(include_sublayers=False)),
+ LayerNorm(
+ normalized_shape=d_model,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(1.)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.)))))
+ elif cmd == "d": # add dropout
+ self.functors.append(lambda x: layers.dropout(
+ x, dropout_prob=dropout_rate, is_test=False)
+ if dropout_rate else x)
+
+ def forward(self, x, residual=None):
+ for i, cmd in enumerate(self.process_cmd):
+ if cmd == "a":
+ x = self.functors[i](x, residual)
+ else:
+ x = self.functors[i](x)
+ return x
+
+
+class MultiHeadAttention(Layer):
+ """
+ Multi-Head Attention
+ """
+
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+ self.d_model = d_model
+ self.dropout_rate = dropout_rate
+ self.q_fc = Linear(
+ input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+ self.k_fc = Linear(
+ input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+ self.v_fc = Linear(
+ input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+ self.proj_fc = Linear(
+ input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
+
+ def _prepare_qkv(self, queries, keys, values, cache=None):
+ if keys is None: # self-attention
+ keys, values = queries, queries
+ static_kv = False
+ else: # cross-attention
+ static_kv = True
+
+ q = self.q_fc(queries)
+ q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+ q = layers.transpose(x=q, perm=[0, 2, 1, 3])
+
+ if cache is not None and static_kv and "static_k" in cache:
+            # for encoder-decoder attention in inference with cached k/v
+ k = cache["static_k"]
+ v = cache["static_v"]
+ else:
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+ v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+
+ if cache is not None:
+            if static_kv and "static_k" not in cache:
+                # for encoder-decoder attention in inference, not yet cached
+ cache["static_k"], cache["static_v"] = k, v
+ elif not static_kv:
+ # for decoder self-attention in inference
+ cache_k, cache_v = cache["k"], cache["v"]
+ k = layers.concat([cache_k, k], axis=2)
+ v = layers.concat([cache_v, v], axis=2)
+ cache["k"], cache["v"] = k, v
+
+ return q, k, v
+
+ def forward(self, queries, keys, values, attn_bias, cache=None):
+        # compute q, k, v
+ q, k, v = self._prepare_qkv(queries, keys, values, cache)
+
+        # scaled dot-product attention
+ product = layers.matmul(
+ x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
+ if attn_bias is not None:
+ product += attn_bias
+ weights = layers.softmax(product)
+ if self.dropout_rate:
+ weights = layers.dropout(
+ weights, dropout_prob=self.dropout_rate, is_test=False)
+
+ out = layers.matmul(weights, v)
+
+ # combine heads
+ out = layers.transpose(out, perm=[0, 2, 1, 3])
+ out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+ # project to output
+ out = self.proj_fc(out)
+ return out
+
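+    # Shape walk-through (illustrative): q, k and v are reshaped to
+    # [batch, n_head, seq_len, d_key or d_value]; the scaled q * k^T plus
+    # attn_bias yields weights of shape [batch, n_head, q_len, k_len]; after
+    # softmax (and optional dropout), weights * v is transposed back and
+    # merged to [batch, q_len, n_head * d_value] before the output projection.
+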
+ def cal_kv(self, keys, values):
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+ v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+ return k, v
+
+
+class FFN(Layer):
+ """
+ Feed-Forward Network
+ """
+
+ def __init__(self, d_inner_hid, d_model, dropout_rate):
+ super(FFN, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.fc1 = Linear(
+ input_dim=d_model, output_dim=d_inner_hid, act="relu")
+ self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
+
+ def forward(self, x):
+ hidden = self.fc1(x)
+ if self.dropout_rate:
+ hidden = layers.dropout(
+ hidden, dropout_prob=self.dropout_rate, is_test=False)
+ out = self.fc2(hidden)
+ return out
+
+
+class EncoderLayer(Layer):
+ """
+ EncoderLayer
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderLayer, self).__init__()
+
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ attn_output = self.self_attn(
+ self.preprocesser1(enc_input), None, None, attn_bias)
+ attn_output = self.postprocesser1(attn_output, enc_input)
+
+ ffn_output = self.ffn(self.preprocesser2(attn_output))
+ ffn_output = self.postprocesser2(ffn_output, attn_output)
+ return ffn_output
+
+
+class Encoder(Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(Encoder, self).__init__()
+
+ self.encoder_layers = list()
+ for i in range(n_layer):
+ self.encoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd,
+ postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ for encoder_layer in self.encoder_layers:
+ enc_output = encoder_layer(enc_input, attn_bias)
+ enc_input = enc_output
+
+ return self.processer(enc_output)
+
+
+class Embedder(Layer):
+ """
+ Word Embedding + Position Encoding
+ """
+
+ def __init__(self, vocab_size, emb_dim, bos_idx=0):
+ super(Embedder, self).__init__()
+
+ self.word_embedder = Embedding(
+ size=[vocab_size, emb_dim],
+ padding_idx=bos_idx,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Normal(0., emb_dim**-0.5)))
+
+ def forward(self, word):
+ word_emb = self.word_embedder(word)
+ return word_emb
+
+
+class WrapEncoder(Layer):
+ """
+ embedder + encoder
+ """
+
+ def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
+ d_value, d_model, d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd, word_embedder):
+ super(WrapEncoder, self).__init__()
+
+ self.emb_dropout = prepostprocess_dropout
+ self.emb_dim = d_model
+ self.word_embedder = word_embedder
+ self.pos_encoder = Embedding(
+ size=[max_length, self.emb_dim],
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.NumpyArrayInitializer(
+ position_encoding_init(max_length, self.emb_dim)),
+ trainable=False))
+
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, src_word, src_pos, src_slf_attn_bias):
+ word_emb = self.word_embedder(src_word)
+ word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
+ pos_enc = self.pos_encoder(src_pos)
+ pos_enc.stop_gradient = True
+ emb = word_emb + pos_enc
+ enc_input = layers.dropout(
+ emb, dropout_prob=self.emb_dropout,
+ is_test=False) if self.emb_dropout else emb
+
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class DecoderLayer(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+ super(DecoderLayer, self).__init__()
+
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self,
+ dec_input,
+ enc_output,
+ self_attn_bias,
+ cross_attn_bias,
+ cache=None):
+ self_attn_output = self.self_attn(
+ self.preprocesser1(dec_input), None, None, self_attn_bias, cache)
+ self_attn_output = self.postprocesser1(self_attn_output, dec_input)
+
+ cross_attn_output = self.cross_attn(
+ self.preprocesser2(self_attn_output), enc_output, enc_output,
+ cross_attn_bias, cache)
+ cross_attn_output = self.postprocesser2(cross_attn_output,
+ self_attn_output)
+
+ ffn_output = self.ffn(self.preprocesser3(cross_attn_output))
+ ffn_output = self.postprocesser3(ffn_output, cross_attn_output)
+
+ return ffn_output
+
+
+class Decoder(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout, relu_dropout,
+ preprocess_cmd, postprocess_cmd):
+ super(Decoder, self).__init__()
+
+ self.decoder_layers = list()
+ for i in range(n_layer):
+ self.decoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ DecoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd,
+ postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self,
+ dec_input,
+ enc_output,
+ self_attn_bias,
+ cross_attn_bias,
+ caches=None):
+ for i, decoder_layer in enumerate(self.decoder_layers):
+ dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
+ cross_attn_bias, None
+ if caches is None else caches[i])
+ dec_input = dec_output
+
+ return self.processer(dec_output)
+
+ def prepare_static_cache(self, enc_output):
+ return [
+ dict(
+ zip(("static_k", "static_v"),
+ decoder_layer.cross_attn.cal_kv(enc_output, enc_output)))
+ for decoder_layer in self.decoder_layers
+ ]
+
+
+class WrapDecoder(Layer):
+ """
+ embedder + decoder
+ """
+
+ def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
+ d_value, d_model, d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd, share_input_output_embed, word_embedder):
+ super(WrapDecoder, self).__init__()
+
+ self.emb_dropout = prepostprocess_dropout
+ self.emb_dim = d_model
+ self.word_embedder = word_embedder
+ self.pos_encoder = Embedding(
+ size=[max_length, self.emb_dim],
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.NumpyArrayInitializer(
+ position_encoding_init(max_length, self.emb_dim)),
+ trainable=False))
+
+ self.decoder = Decoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ if share_input_output_embed:
+ self.linear = lambda x: layers.matmul(x=x,
+ y=self.word_embedder.
+ word_embedder.weight,
+ transpose_y=True)
+ else:
+ self.linear = Linear(
+ input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False)
+
+ def forward(self,
+ trg_word,
+ trg_pos,
+ trg_slf_attn_bias,
+ trg_src_attn_bias,
+ enc_output,
+ caches=None):
+ word_emb = self.word_embedder(trg_word)
+ word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
+ pos_enc = self.pos_encoder(trg_pos)
+ pos_enc.stop_gradient = True
+ emb = word_emb + pos_enc
+ dec_input = layers.dropout(
+ emb, dropout_prob=self.emb_dropout,
+ is_test=False) if self.emb_dropout else emb
+ dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
+ trg_src_attn_bias, caches)
+ dec_output = layers.reshape(
+ dec_output,
+ shape=[-1, dec_output.shape[-1]], )
+ logits = self.linear(dec_output)
+ return logits
+
+
+class CrossEntropyCriterion(Loss):
+ def __init__(self, label_smooth_eps):
+ super(CrossEntropyCriterion, self).__init__()
+ self.label_smooth_eps = label_smooth_eps
+
+ def forward(self, outputs, labels):
+ predict, (label, weights) = outputs[0], labels
+ if self.label_smooth_eps:
+ label = layers.label_smooth(
+ label=layers.one_hot(
+ input=label, depth=predict.shape[-1]),
+ epsilon=self.label_smooth_eps)
+
+ cost = layers.softmax_with_cross_entropy(
+ logits=predict,
+ label=label,
+ soft_label=True if self.label_smooth_eps else False)
+ weighted_cost = cost * weights
+ sum_cost = layers.reduce_sum(weighted_cost)
+ token_num = layers.reduce_sum(weights)
+ token_num.stop_gradient = True
+ avg_cost = sum_cost / token_num
+ return avg_cost
+
+
+class Transformer(Model):
+ """
+ model
+ """
+
+ def __init__(self,
+ src_vocab_size,
+ trg_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_id=0,
+ eos_id=1):
+ super(Transformer, self).__init__()
+ src_word_embedder = Embedder(
+ vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id)
+ self.encoder = WrapEncoder(
+ src_vocab_size, max_length, n_layer, n_head, d_key, d_value,
+ d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder)
+ if weight_sharing:
+ assert src_vocab_size == trg_vocab_size, (
+ "Vocabularies in source and target should be same for weight sharing."
+ )
+ trg_word_embedder = src_word_embedder
+ else:
+ trg_word_embedder = Embedder(
+ vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id)
+ self.decoder = WrapDecoder(
+ trg_vocab_size, max_length, n_layer, n_head, d_key, d_value,
+ d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing,
+ trg_word_embedder)
+
+ self.trg_vocab_size = trg_vocab_size
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+
+ def forward(self, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+ trg_slf_attn_bias, trg_src_attn_bias):
+ enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
+ predict = self.decoder(trg_word, trg_pos, trg_slf_attn_bias,
+ trg_src_attn_bias, enc_output)
+ return predict
+
+
+class TransfomerCell(object):
+ """
+    Wraps the decoder so that, with inputs=(trg_word, trg_pos) and
+    states=cache, the Transformer decoder can be used as an RNNCell.
+ """
+
+ def __init__(self, decoder):
+ self.decoder = decoder
+
+ def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
+ static_caches):
+ trg_word, trg_pos = inputs
+ for cache, static_cache in zip(states, static_caches):
+ cache.update(static_cache)
+ logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
+ enc_output, states)
+ new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
+ return logits, new_states
+
+
+class InferTransformer(Transformer):
+ """
+ model for prediction
+ """
+
+ def __init__(self,
+ src_vocab_size,
+ trg_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_id=0,
+ eos_id=1,
+ beam_size=4,
+ max_out_len=256):
+ args = dict(locals())
+ args.pop("self")
+ args.pop("__class__", None) # py3
+ self.beam_size = args.pop("beam_size")
+ self.max_out_len = args.pop("max_out_len")
+ super(InferTransformer, self).__init__(**args)
+ cell = TransfomerCell(self.decoder)
+ self.beam_search_decoder = DynamicDecode(
+ TransformerBeamSearchDecoder(
+ cell, bos_id, eos_id, beam_size, var_dim_in_state=2),
+ max_out_len,
+ is_test=True)
+
+ def forward(self, src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias):
+ enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
+ ## init states (caches) for transformer, need to be updated according to selected beam
+ caches = [{
+ "k": layers.fill_constant_batch_size_like(
+ input=enc_output,
+ shape=[-1, self.n_head, 0, self.d_key],
+ dtype=enc_output.dtype,
+ value=0),
+ "v": layers.fill_constant_batch_size_like(
+ input=enc_output,
+ shape=[-1, self.n_head, 0, self.d_value],
+ dtype=enc_output.dtype,
+ value=0),
+ } for i in range(self.n_layer)]
+ enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+ enc_output, self.beam_size)
+ trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+ trg_src_attn_bias, self.beam_size)
+ static_caches = self.decoder.decoder.prepare_static_cache(enc_output)
+ rs, _ = self.beam_search_decoder(
+ inits=caches,
+ enc_output=enc_output,
+ trg_src_attn_bias=trg_src_attn_bias,
+ static_caches=static_caches)
+ return rs
diff --git a/transformer/transformer.yaml b/transformer/transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb5ceb5634be15f0bb8233578e3aed78ef437008
--- /dev/null
+++ b/transformer/transformer.yaml
@@ -0,0 +1,112 @@
+# used for continuous evaluation
+enable_ce: False
+
+eager_run: True
+
+# The frequency to save trained models when training.
+save_step: 10000
+# The frequency to fetch and print output when training.
+print_step: 100
+# path of the checkpoint, to resume the previous training
+init_from_checkpoint: ""
+# path of the pretrain model, to better solve the current task
+init_from_pretrain_model: ""
+# path of trained parameter, to make prediction
+init_from_params: "trained_params/step_100000/"
+# the directory for saving model
+save_model: "trained_models"
+# the directory for saving inference model.
+inference_model_dir: "infer_model"
+# Set seed for CE or debug
+random_seed: None
+# The pattern to match training data files.
+training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
+# The pattern to match validation data files.
+validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
+# The pattern to match test data files.
+predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
+# The file to output the translation results of predict_file to.
+output_file: "predict.txt"
+# The path of vocabulary file of source language.
+src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
+# The path of vocabulary file of target language.
+trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000"
+# The <s>, <e> and <unk> tokens in the dictionary.
+special_token: ["<s>", "<e>", "<unk>"]
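+# (Illustrative note: gen_data.sh prepends these three tokens to
+# vocab_all.bpe.32000, so their indices are 0, 1 and 2, matching the
+# bos_idx/eos_idx/unk_idx values below.)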
+
+# whether to use cuda
+use_cuda: True
+
+# args for reader, see reader.py for details
+token_delimiter: " "
+use_token_batch: True
+pool_size: 200000
+sort_type: "pool"
+shuffle: True
+shuffle_batch: True
+batch_size: 4096
+
+# Hyperparams for training:
+# the number of epochs for training
+epoch: 30
+# the hyper parameters for Adam optimizer.
+# This static learning_rate will be multiplied by the LearningRateScheduler
+# derived learning rate to get the final learning rate.
+learning_rate: 2.0
+beta1: 0.9
+beta2: 0.997
+eps: 1e-9
+# the parameters for learning rate scheduling.
+warmup_steps: 8000
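+# (Illustrative note: combined with the Noam schedule, the effective rate is
+# learning_rate * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5),
+# which peaks at roughly 1e-3 around step 8000 for d_model=512.)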
+# the weight used to mix up the ground-truth distribution and the fixed
+# uniform distribution in label smoothing when training.
+# Set this as zero if label smoothing is not wanted.
+label_smooth_eps: 0.1
+
+# Hyperparams for generation:
+# the parameters for beam search.
+beam_size: 5
+max_out_len: 256
+# the number of decoded sentences to output.
+n_best: 1
+
+# Hyperparams for model:
+# The following five vocabulary-related configurations will be set
+# automatically according to the passed vocabulary path and special tokens.
+# size of source word dictionary.
+src_vocab_size: 10000
+# size of target word dictionary
+trg_vocab_size: 10000
+# index for <s> token
+bos_idx: 0
+# index for <e> token
+eos_idx: 1
+# index for <unk> token
+unk_idx: 2
+# max length of sequences deciding the size of position encoding table.
+max_length: 256
+# the dimension for word embeddings, which is also the last dimension of
+# the input and output of multi-head attention, position-wise feed-forward
+# networks, encoder and decoder.
+d_model: 512
+# size of the hidden layer in position-wise feed-forward networks.
+d_inner_hid: 2048
+# the dimension that keys are projected to for dot-product attention.
+d_key: 64
+# the dimension that values are projected to for dot-product attention.
+d_value: 64
+# number of head used in multi-head attention.
+n_head: 8
+# number of sub-layers to be stacked in the encoder and decoder.
+n_layer: 6
+# dropout rates of different modules.
+prepostprocess_dropout: 0.1
+attention_dropout: 0.1
+relu_dropout: 0.1
+# to process before each sub-layer
+preprocess_cmd: "n" # layer normalization
+# to process after each sub-layer
+postprocess_cmd: "da" # dropout + residual connection
+# the flag indicating whether to share embedding and softmax weights.
+# vocabularies in source and target should be same for weight sharing.
+weight_sharing: True
diff --git a/transformer/utils/__init__.py b/transformer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/transformer/utils/check.py b/transformer/utils/check.py
new file mode 100644
index 0000000000000000000000000000000000000000..305fa3705f5c313569986cbdb15c8afeda5a79c1
--- /dev/null
+++ b/transformer/utils/check.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle.fluid as fluid
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['check_gpu', 'check_version']
+
+
+def check_gpu(use_gpu):
+ """
+    Log an error and exit when use_gpu is set to true with the CPU
+    version of paddlepaddle.
+ """
+ err = "Config use_gpu cannot be set as true while you are " \
+ "using paddlepaddle cpu version ! \nPlease try: \n" \
+ "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+ "\t2. Set use_gpu as false in config file to run " \
+ "model on CPU"
+
+ try:
+ if use_gpu and not fluid.is_compiled_with_cuda():
+ logger.error(err)
+ sys.exit(1)
+ except Exception as e:
+ pass
+
+
+def check_version():
+ """
+ Log error and exit when the installed version of paddlepaddle is
+ not satisfied.
+ """
+ err = "PaddlePaddle version 1.6 or higher is required, " \
+ "or a suitable develop version is satisfied as well. \n" \
+ "Please make sure the version is good with your code." \
+
+ try:
+ fluid.require_version('1.6.0')
+ except Exception as e:
+ logger.error(err)
+ sys.exit(1)
diff --git a/transformer/utils/configure.py b/transformer/utils/configure.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e601282fee572518435eaed38a4ed8e26fc5f9
--- /dev/null
+++ b/transformer/utils/configure.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import argparse
+import json
+import yaml
+import six
+import logging
+
+logging_only_message = "%(message)s"
+logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
+
+
+class JsonConfig(object):
+ """
+ A high-level api for handling json configure file.
+ """
+
+ def __init__(self, config_path):
+ self._config_dict = self._parse(config_path)
+
+ def _parse(self, config_path):
+ try:
+ with open(config_path) as json_file:
+ config_dict = json.load(json_file)
+ except:
+ raise IOError("Error in parsing bert model config file '%s'" %
+ config_path)
+ else:
+ return config_dict
+
+ def __getitem__(self, key):
+ return self._config_dict[key]
+
+ def print_config(self):
+ for arg, value in sorted(six.iteritems(self._config_dict)):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+class ArgumentGroup(object):
+ def __init__(self, parser, title, des):
+ self._group = parser.add_argument_group(title=title, description=des)
+
+ def add_arg(self, name, type, default, help, **kwargs):
+ type = str2bool if type == bool else type
+ self._group.add_argument(
+ "--" + name,
+ default=default,
+ type=type,
+ help=help + ' Default: %(default)s.',
+ **kwargs)
+
+
+class ArgConfig(object):
+ """
+ A high-level api for handling argument configs.
+ """
+
+ def __init__(self):
+ parser = argparse.ArgumentParser()
+
+ train_g = ArgumentGroup(parser, "training", "training options.")
+ train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.")
+ train_g.add_arg("learning_rate", float, 5e-5,
+ "Learning rate used to train with warmup.")
+ train_g.add_arg(
+ "lr_scheduler",
+ str,
+ "linear_warmup_decay",
+ "scheduler of learning rate.",
+ choices=['linear_warmup_decay', 'noam_decay'])
+ train_g.add_arg("weight_decay", float, 0.01,
+ "Weight decay rate for L2 regularizer.")
+ train_g.add_arg(
+ "warmup_proportion", float, 0.1,
+ "Proportion of training steps to perform linear learning rate warmup for."
+ )
+ train_g.add_arg("save_steps", int, 1000,
+ "The steps interval to save checkpoints.")
+ train_g.add_arg("use_fp16", bool, False,
+ "Whether to use fp16 mixed precision training.")
+ train_g.add_arg(
+ "loss_scaling", float, 1.0,
+ "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
+ )
+ train_g.add_arg("pred_dir", str, None,
+ "Path to save the prediction results")
+
+ log_g = ArgumentGroup(parser, "logging", "logging related.")
+ log_g.add_arg("skip_steps", int, 10,
+ "The steps interval to print loss.")
+ log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
+
+ run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
+ run_type_g.add_arg("use_cuda", bool, True,
+ "If set, use GPU for training.")
+ run_type_g.add_arg(
+ "use_fast_executor", bool, False,
+ "If set, use fast parallel executor (in experiment).")
+ run_type_g.add_arg(
+ "num_iteration_per_drop_scope", int, 1,
+ "Ihe iteration intervals to clean up temporary variables.")
+ run_type_g.add_arg("do_train", bool, True,
+ "Whether to perform training.")
+ run_type_g.add_arg("do_predict", bool, True,
+ "Whether to perform prediction.")
+
+ custom_g = ArgumentGroup(parser, "customize", "customized options.")
+
+ self.custom_g = custom_g
+
+ self.parser = parser
+
+ def add_arg(self, name, dtype, default, descrip):
+ self.custom_g.add_arg(name, dtype, default, descrip)
+
+ def build_conf(self):
+ return self.parser.parse_args()
+
+
+def str2bool(v):
+    # because argparse does not support parsing "True"/"False" strings as
+    # python booleans directly
+ return v.lower() in ("true", "t", "1")
+
+
+def print_arguments(args, log=None):
+ if not log:
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(six.iteritems(vars(args))):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+ else:
+ log.info('----------- Configuration Arguments -----------')
+ for arg, value in sorted(six.iteritems(vars(args))):
+ log.info('%s: %s' % (arg, value))
+ log.info('------------------------------------------------')
+
+
+class PDConfig(object):
+ """
+ A high-level API for managing configuration files in PaddlePaddle.
+    Can jointly work with command-line arguments, json files and yaml files.
+ """
+
+ def __init__(self, json_file="", yaml_file="", fuse_args=True):
+ """
+        Init function for PDConfig.
+ json_file: the path to the json configure file.
+ yaml_file: the path to the yaml configure file.
+ fuse_args: if fuse the json/yaml configs with argparse.
+ """
+ assert isinstance(json_file, str)
+ assert isinstance(yaml_file, str)
+
+ if json_file != "" and yaml_file != "":
+ raise Warning(
+ "json_file and yaml_file can not co-exist for now. please only use one configure file type."
+ )
+ return
+
+ self.args = None
+ self.arg_config = {}
+ self.json_config = {}
+ self.yaml_config = {}
+
+ parser = argparse.ArgumentParser()
+
+ self.default_g = ArgumentGroup(parser, "default", "default options.")
+ self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.")
+ self.json_g = ArgumentGroup(parser, "json", "options from json.")
+ self.com_g = ArgumentGroup(parser, "custom", "customized options.")
+
+ self.default_g.add_arg("do_train", bool, False,
+ "Whether to perform training.")
+ self.default_g.add_arg("do_predict", bool, False,
+ "Whether to perform predicting.")
+ self.default_g.add_arg("do_eval", bool, False,
+ "Whether to perform evaluating.")
+ self.default_g.add_arg("do_save_inference_model", bool, False,
+ "Whether to perform model saving for inference.")
+
+ # NOTE: args for profiler
+ self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)")
+ self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)")
+ self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)")
+
+ self.parser = parser
+
+ if json_file != "":
+ self.load_json(json_file, fuse_args=fuse_args)
+
+ if yaml_file:
+ self.load_yaml(yaml_file, fuse_args=fuse_args)
+
+ def load_json(self, file_path, fuse_args=True):
+
+ if not os.path.exists(file_path):
+ raise Warning("the json file %s does not exist." % file_path)
+ return
+
+ with open(file_path, "r") as fin:
+ self.json_config = json.loads(fin.read())
+ fin.close()
+
+ if fuse_args:
+ for name in self.json_config:
+ if isinstance(self.json_config[name], list):
+ self.json_g.add_arg(
+ name,
+ type(self.json_config[name][0]),
+ self.json_config[name],
+ "This is from %s" % file_path,
+ nargs=len(self.json_config[name]))
+ continue
+ if not isinstance(self.json_config[name], int) \
+ and not isinstance(self.json_config[name], float) \
+ and not isinstance(self.json_config[name], str) \
+ and not isinstance(self.json_config[name], bool):
+
+ continue
+
+ self.json_g.add_arg(name,
+ type(self.json_config[name]),
+ self.json_config[name],
+ "This is from %s" % file_path)
+
+ def load_yaml(self, file_path, fuse_args=True):
+
+ if not os.path.exists(file_path):
+ raise Warning("the yaml file %s does not exist." % file_path)
+ return
+
+ with open(file_path, "r") as fin:
+ self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
+ fin.close()
+
+ if fuse_args:
+ for name in self.yaml_config:
+ if isinstance(self.yaml_config[name], list):
+ self.yaml_g.add_arg(
+ name,
+ type(self.yaml_config[name][0]),
+ self.yaml_config[name],
+ "This is from %s" % file_path,
+ nargs=len(self.yaml_config[name]))
+ continue
+
+ if not isinstance(self.yaml_config[name], int) \
+ and not isinstance(self.yaml_config[name], float) \
+ and not isinstance(self.yaml_config[name], str) \
+ and not isinstance(self.yaml_config[name], bool):
+
+ continue
+
+ self.yaml_g.add_arg(name,
+ type(self.yaml_config[name]),
+ self.yaml_config[name],
+ "This is from %s" % file_path)
+
+ def build(self):
+ self.args = self.parser.parse_args()
+ self.arg_config = vars(self.args)
+
+ def __add__(self, new_arg):
+ assert isinstance(new_arg, list) or isinstance(new_arg, tuple)
+ assert len(new_arg) >= 3
+ assert self.args is None
+
+ name = new_arg[0]
+ dtype = new_arg[1]
+ dvalue = new_arg[2]
+ desc = new_arg[3] if len(
+ new_arg) == 4 else "Description is not provided."
+
+ self.com_g.add_arg(name, dtype, dvalue, desc)
+
+ return self
+
+ def __getattr__(self, name):
+ if name in self.arg_config:
+ return self.arg_config[name]
+
+ if name in self.json_config:
+ return self.json_config[name]
+
+ if name in self.yaml_config:
+ return self.yaml_config[name]
+
+ raise Warning("The argument %s is not defined." % name)
+
+ def Print(self):
+
+ print("-" * 70)
+ for name in self.arg_config:
+ print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name])))
+
+ for name in self.json_config:
+ if name not in self.arg_config:
+ print("%s:\t\t\t\t%s" %
+ (str(name), str(self.json_config[name])))
+
+ for name in self.yaml_config:
+ if name not in self.arg_config:
+ print("%s:\t\t\t\t%s" %
+ (str(name), str(self.yaml_config[name])))
+
+ print("-" * 70)
+
+
+if __name__ == "__main__":
+ """
+ pd_config = PDConfig(json_file = "./test/bert_config.json")
+ pd_config.build()
+
+ print(pd_config.do_train)
+ print(pd_config.hidden_size)
+
+ pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
+ pd_config.build()
+
+ print(pd_config.do_train)
+ print(pd_config.hidden_size)
+ """
+
+ pd_config = PDConfig(yaml_file="./test/bert_config.yaml")
+ pd_config += ("my_age", int, 18, "I am forever 18.")
+ pd_config.build()
+
+ print(pd_config.do_train)
+ print(pd_config.hidden_size)
+ print(pd_config.my_age)
diff --git a/tsm/README.md b/tsm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6bc5c794432cbcef9559e18f0f8d9b977a98e386
--- /dev/null
+++ b/tsm/README.md
@@ -0,0 +1,147 @@
+# TSM Video Classification Model
+
+---
+
+## Contents
+
+- [Introduction](#introduction)
+- [Quick Start](#quick-start)
+- [Reference](#reference)
+
+
+## Introduction
+
+The Temporal Shift Module, proposed by Ji Lin, Chuang Gan and Song Han of MIT and the IBM Watson AI Lab, improves a network's video understanding ability through shifts along the temporal dimension. The shift operation is illustrated in the figure below.
+
+
+
+Temporal shift module
+
+
+In the figure above, the matrix represents the temporal and channel dimensions of a feature map. Part of the channels are shifted one step forward along the temporal dimension and another part one step backward, with the vacated positions padded with zeros. This introduces context interaction along the temporal dimension into the feature map and improves the model's ability to understand videos over time.
+
+The TSM model is built by inserting the Temporal Shift Module into a ResNet network. The version implemented in this repository uses ResNet-50 as the backbone.
+
+For details, please refer to the paper [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1).
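+
+The shift itself can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the code in this repository, and the shift ratio below is an assumed example value:
+
+```python
+import numpy as np
+
+def temporal_shift(x, shift_ratio=0.125):
+    """x: feature map of shape [N, T, C, H, W]; returns the shifted map."""
+    n, t, c, h, w = x.shape
+    fold = int(c * shift_ratio)
+    out = np.zeros_like(x)
+    out[:, :-1, :fold] = x[:, 1:, :fold]                  # these channels move one step back in time
+    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # these channels move one step forward in time
+    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # the remaining channels stay in place
+    return out                                            # vacated positions stay zero-padded
+```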
+
+## Quick Start
+
+### Installation
+
+#### Install PaddlePaddle
+
+    This project requires PaddlePaddle 1.7 or higher, or a suitable develop version. Please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
+
+#### Download the code and set environment variables
+
+    Clone the repository and set the `PYTHONPATH` environment variable:
+
+ ```bash
+ git clone https://github.com/PaddlePaddle/hapi
+ cd hapi
+ export PYTHONPATH=$PYTHONPATH:`pwd`
+ cd tsm
+ ```
+
+### Data Preparation
+
+TSM is trained on the Kinetics-400 action recognition dataset released by DeepMind. For data download and preparation, please refer to the [data instructions](./dataset/README.md).
+
+#### Verification on a small dataset
+
+To enable fast iteration, we use a smaller dataset for dygraph training verification: the 10 classes with labels 0, 2, 3, 4, 6, 7, 9, 12, 14 and 15 are selected from Kinetics-400 to verify model accuracy.
+
+### Model Training
+
+Once the data is ready, training and evaluation can be launched with the `main.py` script. It automatically alternates between training and evaluation every epoch, and checkpoints are saved to the `tsm_checkpoint` directory by default.
+
+The arguments of `main.py` can be listed with the following command:
+
+```bash
+python main.py --help
+```
+
+#### Static graph training
+
+Single-card training:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python main.py --data= --batch_size=16
+```
+
+Multi-card training:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data= --batch_size=8
+```
+
+#### Dygraph training
+
+Dygraph (dynamic graph) training only requires adding the `-d` flag when running the script.
+
+Single-card training:
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python main.py --data= --batch_size=16 -d
+```
+
+Multi-card training:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch main.py --data= --batch_size=8 -d
+```
+
+**Note:** For both static graph and dygraph modes, `--batch_size` in multi-card training is the batch size per card, i.e. the total batch size equals `--batch_size` multiplied by the number of cards.
+
+### Model Evaluation
+
+The model can be evaluated in either of the following two ways.
+
+1. Evaluate with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically
+
+```bash
+python main.py --data= --eval_only
+```
+
+2. Evaluate accuracy from a checkpoint
+
+```bash
+python main.py --data= --eval_only --weights=tsm_checkpoint/final
+```
+
+#### Evaluation accuracy
+
+The model weights trained on the 10-class small dataset are available at [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams). The evaluation accuracy is as follows:
+
+|Top-1|Top-5|
+|:-:|:-:|
+|76%|98%|
+
+### Model Inference
+
+Inference can be run in either of the following two ways.
+
+1. Run inference with the [TSM-ResNet50](https://paddlemodels.bj.bcebos.com/hapi/tsm_resnet50.pdparams) weights released by Paddle, which are downloaded automatically
+
+```bash
+python infer.py --data= --label_list= --infer_file=
+```
+
+2. Run inference from a checkpoint
+
+```bash
+python infer.py --data= --label_list= --infer_file= --weights=tsm_checkpoint/final
+```
+
+Inference results are output as logs in the following form:
+
+```text
+2020-04-03 07:37:16,321-INFO: Sample ./kineteics/val_10/data_batch_10-042_6 predict label: 6, ground truth label: 6
+```
+
+## Reference
+
+- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
+
diff --git a/tsm/check.py b/tsm/check.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c07568c7f1a0319f791cb39244494f3ddf9f12
--- /dev/null
+++ b/tsm/check.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle.fluid as fluid
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['check_gpu', 'check_version']
+
+
+def check_gpu(use_gpu):
+ """
+ Log error and exit when set use_gpu=true in paddlepaddle
+ cpu version.
+ """
+ err = "Config use_gpu cannot be set as true while you are " \
+ "using paddlepaddle cpu version ! \nPlease try: \n" \
+ "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+ "\t2. Set use_gpu as false in config file to run " \
+ "model on CPU"
+
+ try:
+ if use_gpu and not fluid.is_compiled_with_cuda():
+ logger.error(err)
+ sys.exit(1)
+    except Exception:
+        # is_compiled_with_cuda may be unavailable in some paddle versions;
+        # skip the check in that case
+        pass
+
+
+def check_version(version='1.7.0'):
+ """
+ Log error and exit when the installed version of paddlepaddle is
+ not satisfied.
+ """
+ err = "PaddlePaddle version {} or higher is required, " \
+ "or a suitable develop version is satisfied as well. \n" \
+ "Please make sure the version is good with your code." \
+ .format(version)
+
+ try:
+ fluid.require_version(version)
+ except Exception as e:
+ logger.error(err)
+ sys.exit(1)
diff --git a/tsm/dataset/README.md b/tsm/dataset/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..55613ef3cbaf9715ba89c231b85614a29d57a136
--- /dev/null
+++ b/tsm/dataset/README.md
@@ -0,0 +1,78 @@
+# Data Usage Instructions
+
+## Kinetics dataset
+
+Kinetics is a large-scale video action recognition dataset released by DeepMind, available in two versions, Kinetics-400 and Kinetics-600. Kinetics-400 is used here; the data preprocessing steps are described below.
+
+### Download the mp4 videos
+Create the folders under the Code\_Root directory:
+
+ cd $Code_Root/data/dataset && mkdir kinetics
+
+ cd kinetics && mkdir data_k400 && cd data_k400
+
+ mkdir train_mp4 && mkdir val_mp4
+
+ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the Kinetics-400 mp4 video collection. Download the Kinetics-400 training and validation sets to data/dataset/kinetics/data\_k400/train\_mp4 and data/dataset/kinetics/data\_k400/val\_mp4 respectively.
+
+### Preprocess the mp4 files
+
+To speed up data loading, the mp4 files are decoded into frames and packed into pickle files in advance, and the dataloader reads data from the per-video pkl files (this approach uses more storage). Each pkl file packs (video-id, label, [frame1, frame2, ..., frameN]).
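+
+A converted pkl file can be checked with a short Python 3 snippet (illustrative only; the file name below is just an example of a converted pkl):
+
+    import pickle
+
+    with open('data_batch_100-097.pkl', 'rb') as f:
+        # each pkl packs (video-id, label, [frame1, ..., frameN]); frames are JPEG-encoded bytes
+        vid, label, frames = pickle.load(f, encoding='bytes')
+    print(vid, label, len(frames))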
+
+Create the directories train\_pkl and val\_pkl under data/dataset/kinetics/data\_k400:
+
+ cd $Code_Root/data/dataset/kinetics/data_k400
+
+ mkdir train_pkl && mkdir val_pkl
+
+Enter the $Code\_Root/data/dataset/kinetics directory and use the video2pkl.py script to convert the data. First download the file lists of the [train](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_train.csv) and [validation](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_val.csv) sets.
+
+Then generate the dataset label file required for preprocessing:
+
+ python generate_label.py kinetics-400_train.csv kinetics400_label.txt
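+
+Each line of the generated kinetics400_label.txt maps a category name to an integer index, with categories sorted alphabetically. The entries below are illustrative examples:
+
+    abseiling 0
+    air drumming 1
+    ...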
+
+Then run the following command:
+
+    python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8  # e.g. with 8 processes
+
+- This script depends on `ffmpeg`; please install `ffmpeg` beforehand
+
+For the train data,
+
+ Source_dir = $Code_Root/data/dataset/kinetics/data_k400/train_mp4
+
+ Target_dir = $Code_Root/data/dataset/kinetics/data_k400/train_pkl
+
+For the val data,
+
+ Source_dir = $Code_Root/data/dataset/kinetics/data_k400/val_mp4
+
+ Target_dir = $Code_Root/data/dataset/kinetics/data_k400/val_pkl
+
+This decodes the mp4 files and saves them as pkl files.
+
+### Generate the training and validation lists
+ cd $Code_Root/data/dataset/kinetics
+
+ ls $Code_Root/data/dataset/kinetics/data_k400/train_pkl/* > train.list
+
+ ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > val.list
+
+ ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > test.list
+
+ ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > infer.list
+
+This generates the corresponding file lists. Each line of train.list and val.list is the absolute path of one pkl file, for example:
+
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-097
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-114
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-118
+ ...
+
+or
+
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-085
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-086
+ /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-090
+ ...
diff --git a/tsm/dataset/kinetics/generate_label.py b/tsm/dataset/kinetics/generate_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7608e86244c305bc31aa341d34320b71034c2e2
--- /dev/null
+++ b/tsm/dataset/kinetics/generate_label.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+# kinetics-400_train.csv should be downloaded first and set as sys.argv[1]
+# sys.argv[2] can be set as kinetics400_label.txt
+# python generate_label.py kinetics-400_train.csv kinetics400_label.txt
+
+num_classes = 400
+
+fname = sys.argv[1]
+outname = sys.argv[2]
+fl = open(fname).readlines()
+fl = fl[1:]
+outf = open(outname, 'w')
+
+label_list = []
+for line in fl:
+ label = line.strip().split(',')[0].strip('"')
+ if label in label_list:
+ continue
+ else:
+ label_list.append(label)
+
+assert len(label_list) == num_classes, \
+    "there should be {} labels in list, but got {}".format(
+        num_classes, len(label_list))
+
+label_list.sort()
+for i in range(num_classes):
+ outf.write('{} {}'.format(label_list[i], i) + '\n')
+
+outf.close()
diff --git a/tsm/dataset/kinetics/video2pkl.py b/tsm/dataset/kinetics/video2pkl.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d1b09b7bf6efb7f96535fa66bee2762bbccc5d
--- /dev/null
+++ b/tsm/dataset/kinetics/video2pkl.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import glob
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+from multiprocessing import Pool
+
+# example command line:
+#     python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8
+#
+# kinetics-400_train.csv is the training set file of the K400 official release;
+# each line contains label,youtube_id,time_start,time_end,split,is_cc
+
+assert (len(sys.argv) == 5)
+
+f = open(sys.argv[1])
+source_dir = sys.argv[2]
+target_dir = sys.argv[3]
+num_threads = int(sys.argv[4])
+all_video_entries = [x.strip().split(',') for x in f.readlines()]
+all_video_entries = all_video_entries[1:]
+f.close()
+
+category_label_map = {}
+f = open('kinetics400_label.txt')
+for line in f:
+ ens = line.strip().split(' ')
+ category = " ".join(ens[0:-1])
+ label = int(ens[-1])
+ category_label_map[category] = label
+f.close()
+
+
+def generate_pkl(entry):
+ mode = entry[4]
+ category = entry[0].strip('"')
+ category_dir = category
+ video_path = os.path.join(
+ './',
+ entry[1] + "_%06d" % int(entry[2]) + "_%06d" % int(entry[3]) + ".mp4")
+ video_path = os.path.join(source_dir, category_dir, video_path)
+ label = category_label_map[category]
+
+ vid = './' + video_path.split('/')[-1].split('.')[0]
+ if os.path.exists(video_path):
+ if not os.path.exists(vid):
+ os.makedirs(vid)
+ os.system('ffmpeg -i ' + video_path + ' -q 0 ' + vid + '/%06d.jpg')
+ else:
+ print("File not exists {}".format(video_path))
+ return
+
+ images = sorted(glob.glob(vid + '/*.jpg'))
+ ims = []
+ for img in images:
+ f = open(img, 'rb')
+ ims.append(f.read())
+ f.close()
+
+ output_pkl = vid + ".pkl"
+ output_pkl = os.path.join(target_dir, output_pkl)
+ f = open(output_pkl, 'wb')
+ pickle.dump((vid, label, ims), f, protocol=2)
+ f.close()
+
+ os.system('rm -rf %s' % vid)
+
+
+pool = Pool(processes=num_threads)
+pool.map(generate_pkl, all_video_entries)
+pool.close()
+pool.join()
diff --git a/tsm/images/temporal_shift.png b/tsm/images/temporal_shift.png
new file mode 100644
index 0000000000000000000000000000000000000000..7679c4459d2b0ee37134b99fe1e8177b1a69f8b0
Binary files /dev/null and b/tsm/images/temporal_shift.png differ
diff --git a/tsm/infer.py b/tsm/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..78dbe2cc6ab92dc2a85fee8f186b1b1ae8d74fdd
--- /dev/null
+++ b/tsm/infer.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import argparse
+import numpy as np
+
+from paddle import fluid
+
+from model import Input, set_device
+from models import tsm_resnet50
+
+from check import check_gpu, check_version
+from kinetics_dataset import KineticsDataset
+from transforms import *
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def main():
+ device = set_device(FLAGS.device)
+ fluid.enable_dygraph(device) if FLAGS.dynamic else None
+
+ transform = Compose([GroupScale(),
+ GroupCenterCrop(),
+ NormalizeImage()])
+ dataset = KineticsDataset(
+ pickle_file=FLAGS.infer_file,
+ label_list=FLAGS.label_list,
+ mode='test',
+ transform=transform)
+ labels = dataset.label_list
+
+ model = tsm_resnet50(num_classes=len(labels),
+ pretrained=FLAGS.weights is None)
+
+ inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
+
+ model.prepare(inputs=inputs, device=FLAGS.device)
+
+ if FLAGS.weights is not None:
+ model.load(FLAGS.weights, reset_optimizer=True)
+
+ imgs, label = dataset[0]
+ pred = model.test([imgs[np.newaxis, :]])
+ pred = labels[np.argmax(pred)]
+ logger.info("Sample {} predict label: {}, ground truth label: {}" \
+ .format(FLAGS.infer_file, pred, labels[int(label)]))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("CNN training on TSM")
+ parser.add_argument(
+ "--data", type=str, default='dataset/kinetics',
+ help="path to dataset root directory")
+ parser.add_argument(
+ "--device", type=str, default='gpu',
+ help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true',
+ help="enable dygraph mode")
+ parser.add_argument(
+ "--label_list", type=str, default=None,
+ help="path to category index label list file")
+ parser.add_argument(
+ "--infer_file", type=str, default=None,
+ help="path to pickle file for inference")
+ parser.add_argument(
+ "-w",
+ "--weights",
+ default=None,
+ type=str,
+ help="weights path for evaluation")
+ FLAGS = parser.parse_args()
+
+ check_gpu(str.lower(FLAGS.device) == 'gpu')
+ check_version()
+ main()
diff --git a/tsm/kinetics_dataset.py b/tsm/kinetics_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e07543f37392744a2bf82ecc9b038e78d2d5524
--- /dev/null
+++ b/tsm/kinetics_dataset.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import six
+import sys
+import random
+import numpy as np
+from PIL import Image, ImageEnhance
+
+try:
+ import cPickle as pickle
+ from cStringIO import StringIO
+except ImportError:
+ import pickle
+ from io import BytesIO
+
+from paddle.fluid.io import Dataset
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['KineticsDataset']
+
+KINETICS_CLASS_NUM = 400
+
+
+class KineticsDataset(Dataset):
+ """
+ Kinetics dataset
+
+    Args:
+        file_list (str): path to file list, required in 'train' and
+            'val' modes
+        pickle_dir (str): path to pickle file directory, required in
+            'train' and 'val' modes
+        pickle_file (str): path to a single pickle file, required in
+            'test' mode
+        label_list (str): path to label_list file, if set None, the
+            default class number 400 of the kinetics dataset will be
+            used. Default None
+        mode (str): 'train', 'val' or 'test' mode, frame sampling will
+            be different in these modes. Default 'train'
+ seg_num (int): segment number to sample from each video.
+ Default 8
+ seg_len (int): frame number of each segment. Default 1
+ transform (callable): transforms to perform on video samples,
+ None for no transforms. Default None.
+ """
+
+ def __init__(self,
+ file_list=None,
+ pickle_dir=None,
+ pickle_file=None,
+ label_list=None,
+ mode='train',
+ seg_num=8,
+ seg_len=1,
+ transform=None):
+ assert str.lower(mode) in ['train', 'val', 'test'], \
+ "mode can only be 'train' 'val' or 'test'"
+ self.mode = str.lower(mode)
+
+ if self.mode in ['train', 'val']:
+ assert os.path.isfile(file_list), \
+ "file_list {} not a file".format(file_list)
+ with open(file_list) as f:
+ self.pickle_paths = [l.strip() for l in f]
+
+ assert os.path.isdir(pickle_dir), \
+ "pickle_dir {} not a directory".format(pickle_dir)
+ self.pickle_dir = pickle_dir
+ else:
+ assert os.path.isfile(pickle_file), \
+ "pickle_file {} not a file".format(pickle_file)
+ self.pickle_dir = ''
+ self.pickle_paths = [pickle_file]
+
+ self.label_list = label_list
+ if self.label_list is not None:
+ assert os.path.isfile(self.label_list), \
+ "label_list {} not a file".format(self.label_list)
+ with open(self.label_list) as f:
+ self.label_list = [int(l.strip()) for l in f]
+
+ self.seg_num = seg_num
+ self.seg_len = seg_len
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.pickle_paths)
+
+ def __getitem__(self, idx):
+ pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])
+
+ try:
+ if six.PY2:
+ data = pickle.load(open(pickle_path, 'rb'))
+ else:
+ data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
+
+ vid, label, frames = data
+ if len(frames) < 1:
+ logger.error("{} contains no frame".format(pickle_path))
+ sys.exit(-1)
+ except Exception as e:
+ logger.error("Load {} failed: {}".format(pickle_path, e))
+ sys.exit(-1)
+
+ if self.label_list is not None:
+ label = self.label_list.index(label)
+ imgs = self._video_loader(frames)
+
+ if self.transform:
+ imgs, label = self.transform(imgs, label)
+ return imgs, np.array([label])
+
+ @property
+ def num_classes(self):
+ return KINETICS_CLASS_NUM if self.label_list is None \
+ else len(self.label_list)
+
+ def _video_loader(self, frames):
+ videolen = len(frames)
+ average_dur = int(videolen / self.seg_num)
+
+ imgs = []
+ for i in range(self.seg_num):
+ idx = 0
+ if self.mode == 'train':
+ if average_dur >= self.seg_len:
+ idx = random.randint(0, average_dur - self.seg_len)
+ idx += i * average_dur
+ elif average_dur >= 1:
+ idx += i * average_dur
+ else:
+ idx = i
+ else:
+ if average_dur >= self.seg_len:
+ idx = (average_dur - self.seg_len) // 2
+ idx += i * average_dur
+ elif average_dur >= 1:
+ idx += i * average_dur
+ else:
+ idx = i
+
+ for jj in range(idx, idx + self.seg_len):
+ imgbuf = frames[int(jj % videolen)]
+ img = self._imageloader(imgbuf)
+ imgs.append(img)
+
+ return imgs
+
+ def _imageloader(self, buf):
+ if isinstance(buf, str):
+ img = Image.open(StringIO(buf))
+ else:
+ img = Image.open(BytesIO(buf))
+
+ return img.convert('RGB')
+
diff --git a/tsm/main.py b/tsm/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..07868dbdc43565341b19ef6fe69c693f812c6258
--- /dev/null
+++ b/tsm/main.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import argparse
+import numpy as np
+
+from paddle import fluid
+from paddle.fluid.dygraph.parallel import ParallelEnv
+
+from model import Model, CrossEntropy, Input, set_device
+from metrics import Accuracy
+from models import tsm_resnet50
+
+from check import check_gpu, check_version
+from kinetics_dataset import KineticsDataset
+from transforms import *
+
+
+def make_optimizer(step_per_epoch, parameter_list=None):
+ boundaries = [e * step_per_epoch for e in [40, 60]]
+ values = [FLAGS.lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
+
+ learning_rate = fluid.layers.piecewise_decay(
+ boundaries=boundaries,
+ values=values)
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate,
+ regularization=fluid.regularizer.L2Decay(1e-4),
+ momentum=0.9,
+ parameter_list=parameter_list)
+
+ return optimizer
+
+
+def main():
+ device = set_device(FLAGS.device)
+ fluid.enable_dygraph(device) if FLAGS.dynamic else None
+
+ train_transform = Compose([GroupScale(),
+ GroupMultiScaleCrop(),
+ GroupRandomCrop(),
+ GroupRandomFlip(),
+ NormalizeImage()])
+ train_dataset = KineticsDataset(
+ file_list=os.path.join(FLAGS.data, 'train_10.list'),
+ pickle_dir=os.path.join(FLAGS.data, 'train_10'),
+ label_list=os.path.join(FLAGS.data, 'label_list'),
+ transform=train_transform)
+ val_transform = Compose([GroupScale(),
+ GroupCenterCrop(),
+ NormalizeImage()])
+ val_dataset = KineticsDataset(
+ file_list=os.path.join(FLAGS.data, 'val_10.list'),
+ pickle_dir=os.path.join(FLAGS.data, 'val_10'),
+ label_list=os.path.join(FLAGS.data, 'label_list'),
+ mode='val',
+ transform=val_transform)
+
+ pretrained = FLAGS.eval_only and FLAGS.weights is None
+ model = tsm_resnet50(num_classes=train_dataset.num_classes,
+ pretrained=pretrained)
+
+ step_per_epoch = int(len(train_dataset) / FLAGS.batch_size \
+ / ParallelEnv().nranks)
+ optim = make_optimizer(step_per_epoch, model.parameters())
+
+ inputs = [Input([None, 8, 3, 224, 224], 'float32', name='image')]
+ labels = [Input([None, 1], 'int64', name='label')]
+
+ model.prepare(
+ optim,
+ CrossEntropy(),
+ metrics=Accuracy(topk=(1, 5)),
+ inputs=inputs,
+ labels=labels,
+ device=FLAGS.device)
+
+ if FLAGS.eval_only:
+ if FLAGS.weights is not None:
+ model.load(FLAGS.weights, reset_optimizer=True)
+
+ model.evaluate(
+ val_dataset,
+ batch_size=FLAGS.batch_size,
+ num_workers=FLAGS.num_workers)
+ return
+
+ if FLAGS.resume is not None:
+ model.load(FLAGS.resume)
+
+ model.fit(train_data=train_dataset,
+ eval_data=val_dataset,
+ epochs=FLAGS.epoch,
+ batch_size=FLAGS.batch_size,
+ save_dir='tsm_checkpoint',
+ num_workers=FLAGS.num_workers,
+ drop_last=True,
+ shuffle=True)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("CNN training on TSM")
+ parser.add_argument(
+ "--data", type=str, default='dataset/kinetics',
+ help="path to dataset root directory")
+ parser.add_argument(
+ "--device", type=str, default='gpu', help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "--eval_only", action='store_true', help="run evaluation only")
+ parser.add_argument(
+ "-e", "--epoch", default=70, type=int, help="number of epoch")
+ parser.add_argument(
+ "-j", "--num_workers", default=4, type=int, help="read worker number")
+ parser.add_argument(
+ '--lr',
+ '--learning-rate',
+ default=1e-2,
+ type=float,
+ metavar='LR',
+ help='initial learning rate')
+ parser.add_argument(
+ "-b", "--batch_size", default=16, type=int, help="batch size")
+ parser.add_argument(
+ "-r",
+ "--resume",
+ default=None,
+ type=str,
+ help="checkpoint path to resume")
+ parser.add_argument(
+ "-w",
+ "--weights",
+ default=None,
+ type=str,
+ help="weights path for evaluation")
+ FLAGS = parser.parse_args()
+
+ check_gpu(str.lower(FLAGS.device) == 'gpu')
+ check_version()
+ main()
diff --git a/tsm/transforms.py b/tsm/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..230e2f8a332797002d76cf1317f425adea893da2
--- /dev/null
+++ b/tsm/transforms.py
@@ -0,0 +1,246 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import traceback
+import numpy as np
+from PIL import Image
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['GroupScale', 'GroupMultiScaleCrop', 'GroupRandomCrop',
+ 'GroupRandomFlip', 'GroupCenterCrop', 'NormalizeImage',
+ 'Compose']
+
+
+class Compose(object):
+ def __init__(self, transforms=[]):
+ self.transforms = transforms
+
+ def __call__(self, *data):
+ for f in self.transforms:
+ try:
+ data = f(*data)
+ except Exception as e:
+ stack_info = traceback.format_exc()
+ logger.info("fail to perform transform [{}] with error: "
+ "{} and stack:\n{}".format(f, e, str(stack_info)))
+ raise e
+ return data
+
+
+class GroupScale(object):
+ """
+ Group scale image
+
+ Args:
+ target_size (int): image resize target size
+ """
+ def __init__(self, target_size=224):
+ self.target_size = target_size
+
+ def __call__(self, imgs, label):
+ resized_imgs = []
+ for i in range(len(imgs)):
+ img = imgs[i]
+ w, h = img.size
+ if (w <= h and w == self.target_size) or \
+ (h <= w and h == self.target_size):
+ resized_imgs.append(img)
+ continue
+
+ if w < h:
+ ow = self.target_size
+ oh = int(self.target_size * 4.0 / 3.0)
+ resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+ else:
+ oh = self.target_size
+ ow = int(self.target_size * 4.0 / 3.0)
+ resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+
+ return resized_imgs, label
+
+
+class GroupMultiScaleCrop(object):
+    """
+    Randomly crop a group of images with scale and aspect-ratio jitter
+    (optionally from a fixed grid of offsets), then resize the crop to
+    short_size x short_size.
+    """
+ def __init__(self,
+ short_size=256,
+ scales=None,
+ max_distort=1,
+ fix_crop=True,
+ more_fix_crop=True):
+ self.short_size = short_size
+ self.scales = scales if scales is not None \
+ else [1, .875, .75, .66]
+ self.max_distort = max_distort
+ self.fix_crop = fix_crop
+ self.more_fix_crop = more_fix_crop
+
+ def __call__(self, imgs, label):
+ input_size = [self.short_size, self.short_size]
+ im_size = imgs[0].size
+
+ # get random crop offset
+ def _sample_crop_size(im_size):
+ image_w, image_h = im_size[0], im_size[1]
+
+ base_size = min(image_w, image_h)
+ crop_sizes = [int(base_size * x) for x in self.scales]
+ crop_h = [
+ input_size[1] if abs(x - input_size[1]) < 3 else x
+ for x in crop_sizes
+ ]
+ crop_w = [
+ input_size[0] if abs(x - input_size[0]) < 3 else x
+ for x in crop_sizes
+ ]
+
+ pairs = []
+ for i, h in enumerate(crop_h):
+ for j, w in enumerate(crop_w):
+ if abs(i - j) <= self.max_distort:
+ pairs.append((w, h))
+ crop_pair = random.choice(pairs)
+ if not self.fix_crop:
+ w_offset = np.random.randint(0, image_w - crop_pair[0])
+ h_offset = np.random.randint(0, image_h - crop_pair[1])
+ else:
+ w_step = (image_w - crop_pair[0]) / 4
+ h_step = (image_h - crop_pair[1]) / 4
+
+ ret = list()
+ ret.append((0, 0)) # upper left
+ if w_step != 0:
+ ret.append((4 * w_step, 0)) # upper right
+ if h_step != 0:
+ ret.append((0, 4 * h_step)) # lower left
+ if h_step != 0 and w_step != 0:
+ ret.append((4 * w_step, 4 * h_step)) # lower right
+ if h_step != 0 or w_step != 0:
+ ret.append((2 * w_step, 2 * h_step)) # center
+
+ if self.more_fix_crop:
+ ret.append((0, 2 * h_step)) # center left
+ ret.append((4 * w_step, 2 * h_step)) # center right
+ ret.append((2 * w_step, 4 * h_step)) # lower center
+ ret.append((2 * w_step, 0 * h_step)) # upper center
+
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
+                    ret.append((3 * w_step, 3 * h_step))  # lower right quarter
+
+ w_offset, h_offset = random.choice(ret)
+
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+ crop_w, crop_h, offset_w, offset_h = _sample_crop_size(im_size)
+ crop_imgs = [
+ img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
+ for img in imgs
+ ]
+ ret_imgs = [
+ img.resize((input_size[0], input_size[1]), Image.BILINEAR)
+ for img in crop_imgs
+ ]
+
+ return ret_imgs, label
+
+
+class GroupRandomCrop(object):
+ def __init__(self, target_size=224):
+ self.target_size = target_size
+
+ def __call__(self, imgs, label):
+ w, h = imgs[0].size
+ th, tw = self.target_size, self.target_size
+
+        assert (w >= self.target_size) and (h >= self.target_size), \
+            "image width({}) and height({}) should be larger than " \
+            "crop size({})".format(w, h, self.target_size)
+
+        out_images = []
+        # guard against randint(0, 0) when the image already matches the crop size
+        x1 = np.random.randint(0, w - tw) if w > tw else 0
+        y1 = np.random.randint(0, h - th) if h > th else 0
+
+ for img in imgs:
+ if w == tw and h == th:
+ out_images.append(img)
+ else:
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return out_images, label
+
+
+class GroupRandomFlip(object):
+ def __call__(self, imgs, label):
+ v = np.random.random()
+ if v < 0.5:
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
+ return ret, label
+ else:
+ return imgs, label
+
+
+class GroupCenterCrop(object):
+ def __init__(self, target_size=224):
+ self.target_size = target_size
+
+ def __call__(self, imgs, label):
+ crop_imgs = []
+ for img in imgs:
+ w, h = img.size
+ th, tw = self.target_size, self.target_size
+            assert (w >= self.target_size) and (h >= self.target_size), \
+                "image width({}) and height({}) should be larger " \
+                "than crop size({})".format(w, h, self.target_size)
+ x1 = int(round((w - tw) / 2.))
+ y1 = int(round((h - th) / 2.))
+ crop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+ return crop_imgs, label
+
+
+class NormalizeImage(object):
+ def __init__(self,
+ target_size=224,
+ img_mean=[0.485, 0.456, 0.406],
+ img_std=[0.229, 0.224, 0.225],
+ seg_num=8,
+ seg_len=1):
+ self.target_size = target_size
+ self.img_mean = np.array(img_mean).reshape((3, 1, 1)).astype('float32')
+ self.img_std = np.array(img_std).reshape((3, 1, 1)).astype('float32')
+ self.seg_num = seg_num
+ self.seg_len = seg_len
+
+ def __call__(self, imgs, label):
+ np_imgs = (np.array(imgs[0]).astype('float32').transpose(
+ (2, 0, 1))).reshape(1, 3, self.target_size,
+ self.target_size) / 255
+ for i in range(len(imgs) - 1):
+ img = (np.array(imgs[i + 1]).astype('float32').transpose(
+ (2, 0, 1))).reshape(1, 3, self.target_size,
+ self.target_size) / 255
+ np_imgs = np.concatenate((np_imgs, img))
+
+ np_imgs -= self.img_mean
+ np_imgs /= self.img_std
+ np_imgs = np.reshape(np_imgs, (self.seg_num, self.seg_len * 3,
+ self.target_size, self.target_size))
+
+ return np_imgs, label
diff --git a/yolov3.py b/yolov3.py
deleted file mode 100644
index 6c609f24dce60293ee42a599324d595a6875a0f6..0000000000000000000000000000000000000000
--- a/yolov3.py
+++ /dev/null
@@ -1,568 +0,0 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import contextlib
-import os
-import random
-import time
-
-from functools import partial
-
-import cv2
-import numpy as np
-from pycocotools.coco import COCO
-
-import paddle
-import paddle.fluid as fluid
-from paddle.fluid.dygraph.nn import Conv2D
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.regularizer import L2Decay
-
-from model import Model, Loss, Input
-from resnet import ResNet, ConvBNLayer
-
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
-
-
-# XXX transfer learning
-class ResNetBackBone(ResNet):
- def __init__(self, depth=50):
- super(ResNetBackBone, self).__init__(depth=depth)
- delattr(self, 'fc')
-
- def forward(self, inputs):
- x = self.conv(inputs)
- x = self.pool(x)
- outputs = []
- for layer in self.layers:
- x = layer(x)
- outputs.append(x)
- return outputs
-
-
-class YoloDetectionBlock(fluid.dygraph.Layer):
- def __init__(self, num_channels, num_filters):
- super(YoloDetectionBlock, self).__init__()
-
- assert num_filters % 2 == 0, \
- "num_filters {} cannot be divided by 2".format(num_filters)
-
- self.conv0 = ConvBNLayer(
- num_channels=num_channels,
- num_filters=num_filters,
- filter_size=1,
- act='leaky_relu')
- self.conv1 = ConvBNLayer(
- num_channels=num_filters,
- num_filters=num_filters * 2,
- filter_size=3,
- act='leaky_relu')
- self.conv2 = ConvBNLayer(
- num_channels=num_filters * 2,
- num_filters=num_filters,
- filter_size=1,
- act='leaky_relu')
- self.conv3 = ConvBNLayer(
- num_channels=num_filters,
- num_filters=num_filters * 2,
- filter_size=3,
- act='leaky_relu')
- self.route = ConvBNLayer(
- num_channels=num_filters * 2,
- num_filters=num_filters,
- filter_size=1,
- act='leaky_relu')
- self.tip = ConvBNLayer(
- num_channels=num_filters,
- num_filters=num_filters * 2,
- filter_size=3,
- act='leaky_relu')
-
- def forward(self, inputs):
- out = self.conv0(inputs)
- out = self.conv1(out)
- out = self.conv2(out)
- out = self.conv3(out)
- route = self.route(out)
- tip = self.tip(route)
- return route, tip
-
-
-class YOLOv3(Model):
- def __init__(self, num_classes=80):
- super(YOLOv3, self).__init__()
- self.num_classes = num_classes
- self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
- 59, 119, 116, 90, 156, 198, 373, 326]
- self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
- self.valid_thresh = 0.005
- self.nms_thresh = 0.45
- self.nms_topk = 400
- self.nms_posk = 100
- self.draw_thresh = 0.5
-
- self.backbone = ResNetBackBone()
- self.block_outputs = []
- self.yolo_blocks = []
- self.route_blocks = []
-
- for idx, num_chan in enumerate([2048, 1280, 640]):
- yolo_block = self.add_sublayer(
- "detecton_block_{}".format(idx),
- YoloDetectionBlock(num_chan, num_filters=512 // (2**idx)))
- self.yolo_blocks.append(yolo_block)
-
- num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
-
- block_out = self.add_sublayer(
- "block_out_{}".format(idx),
- Conv2D(num_channels=1024 // (2**idx),
- num_filters=num_filters,
- filter_size=1,
- param_attr=ParamAttr(
- initializer=fluid.initializer.Normal(0., 0.02)),
- bias_attr=ParamAttr(
- initializer=fluid.initializer.Constant(0.0),
- regularizer=L2Decay(0.))))
- self.block_outputs.append(block_out)
- if idx < 2:
- route = self.add_sublayer(
- "route_{}".format(idx),
- ConvBNLayer(num_channels=512 // (2**idx),
- num_filters=256 // (2**idx),
- filter_size=1,
- act='leaky_relu'))
- self.route_blocks.append(route)
-
- def forward(self, inputs, img_info):
- outputs = []
- boxes = []
- scores = []
- downsample = 32
-
- feats = self.backbone(inputs)
- feats = feats[::-1][:len(self.anchor_masks)]
- route = None
- for idx, feat in enumerate(feats):
- if idx > 0:
- feat = fluid.layers.concat(input=[route, feat], axis=1)
- route, tip = self.yolo_blocks[idx](feat)
- block_out = self.block_outputs[idx](tip)
- outputs.append(block_out)
-
- if idx < 2:
- route = self.route_blocks[idx](route)
- route = fluid.layers.resize_nearest(route, scale=2)
-
- if self.mode == 'test':
- anchor_mask = self.anchor_masks[idx]
- mask_anchors = []
- for m in anchor_mask:
- mask_anchors.append(self.anchors[2 * m])
- mask_anchors.append(self.anchors[2 * m + 1])
- img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
- img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
- b, s = fluid.layers.yolo_box(
- x=block_out,
- img_size=img_shape,
- anchors=mask_anchors,
- class_num=self.num_classes,
- conf_thresh=self.valid_thresh,
- downsample_ratio=downsample)
-
- boxes.append(b)
- scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
-
- downsample //= 2
-
- if self.mode != 'test':
- return outputs
-
- return [img_id, fluid.layers.multiclass_nms(
- bboxes=fluid.layers.concat(boxes, axis=1),
- scores=fluid.layers.concat(scores, axis=2),
- score_threshold=self.valid_thresh,
- nms_top_k=self.nms_topk,
- keep_top_k=self.nms_posk,
- nms_threshold=self.nms_thresh,
- background_label=-1)]
-
-
-class YoloLoss(Loss):
- def __init__(self, num_classes=80):
- super(YoloLoss, self).__init__()
- self.num_classes = num_classes
- self.ignore_thresh = 0.7
- self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
- 59, 119, 116, 90, 156, 198, 373, 326]
- self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
-
- def forward(self, outputs, labels):
- downsample = 32
- gt_box, gt_label, gt_score = labels
- losses = []
-
- for idx, out in enumerate(outputs):
- anchor_mask = self.anchor_masks[idx]
- loss = fluid.layers.yolov3_loss(
- x=out,
- gt_box=gt_box,
- gt_label=gt_label,
- gt_score=gt_score,
- anchor_mask=anchor_mask,
- downsample_ratio=downsample,
- anchors=self.anchors,
- class_num=self.num_classes,
- ignore_thresh=self.ignore_thresh,
- use_label_smooth=True)
- loss = fluid.layers.reduce_mean(loss)
- losses.append(loss)
- downsample //= 2
- return losses
-
-
-def make_optimizer(parameter_list=None):
- base_lr = FLAGS.lr
- warm_up_iter = 4000
- momentum = 0.9
- weight_decay = 5e-4
- boundaries = [400000, 450000]
- values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
- learning_rate = fluid.layers.piecewise_decay(
- boundaries=boundaries,
- values=values)
- learning_rate = fluid.layers.linear_lr_warmup(
- learning_rate=learning_rate,
- warmup_steps=warm_up_iter,
- start_lr=0.0,
- end_lr=base_lr)
- optimizer = fluid.optimizer.Momentum(
- learning_rate=learning_rate,
- regularization=fluid.regularizer.L2Decay(weight_decay),
- momentum=momentum,
- parameter_list=parameter_list)
- return optimizer
-
-
-def _iou_matrix(a, b):
- tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
- br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
- area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
- area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
- area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
- area_o = (area_a[:, np.newaxis] + area_b - area_i)
- return area_i / (area_o + 1e-10)
-
-
-def _crop_box_with_center_constraint(box, crop):
- cropped_box = box.copy()
- cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
- cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
- cropped_box[:, :2] -= crop[:2]
- cropped_box[:, 2:] -= crop[:2]
- centers = (box[:, :2] + box[:, 2:]) / 2
- valid = np.logical_and(
- crop[:2] <= centers, centers < crop[2:]).all(axis=1)
- valid = np.logical_and(
- valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
- return cropped_box, np.where(valid)[0]
-
-
-def random_crop(inputs):
- aspect_ratios = [.5, 2.]
- thresholds = [.0, .1, .3, .5, .7, .9]
- scaling = [.3, 1.]
-
- img, img_ids, gt_box, gt_label = inputs
- h, w = img.shape[:2]
-
- if len(gt_box) == 0:
- return inputs
-
- np.random.shuffle(thresholds)
- for thresh in thresholds:
- found = False
- for i in range(50):
- scale = np.random.uniform(*scaling)
- min_ar, max_ar = aspect_ratios
- ar = np.random.uniform(max(min_ar, scale**2),
- min(max_ar, scale**-2))
- crop_h = int(h * scale / np.sqrt(ar))
- crop_w = int(w * scale * np.sqrt(ar))
- crop_y = np.random.randint(0, h - crop_h)
- crop_x = np.random.randint(0, w - crop_w)
- crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
- iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32))
- if iou.max() < thresh:
- continue
-
- cropped_box, valid_ids = _crop_box_with_center_constraint(
- gt_box, np.array(crop_box, dtype=np.float32))
- if valid_ids.size > 0:
- found = True
- break
-
- if found:
- x1, y1, x2, y2 = crop_box
- img = img[y1:y2, x1:x2, :]
- gt_box = np.take(cropped_box, valid_ids, axis=0)
- gt_label = np.take(gt_label, valid_ids, axis=0)
- return img, img_ids, gt_box, gt_label
-
- return inputs
-
-
-# XXX mix up, color distort and random expand are skipped for simplicity
-def sample_transform(inputs, mode='train', num_max_boxes=50):
- if mode == 'train':
- img, img_id, gt_box, gt_label = random_crop(inputs)
- else:
- img, img_id, gt_box, gt_label = inputs
-
- h, w = img.shape[:2]
- # random flip
- if mode == 'train' and np.random.uniform(0., 1.) > .5:
- img = img[:, ::-1, :]
- if len(gt_box) > 0:
- swap = gt_box.copy()
- gt_box[:, 0] = w - swap[:, 2] - 1
- gt_box[:, 2] = w - swap[:, 0] - 1
-
- if len(gt_label) == 0:
- gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
- gt_label = np.zeros([num_max_boxes], dtype=np.int32)
- return img, gt_box, gt_label
-
- gt_box = gt_box[:num_max_boxes, :]
- gt_label = gt_label[:num_max_boxes, 0]
- # normalize boxes
- gt_box /= np.array([w, h] * 2, dtype=np.float32)
- gt_box[:, 2:] = gt_box[:, 2:] - gt_box[:, :2]
- gt_box[:, :2] = gt_box[:, :2] + gt_box[:, 2:] / 2.
-
- pad = num_max_boxes - gt_label.size
- gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
- gt_label = np.pad(gt_label, ((0, pad)), mode='constant')
-
- return img, img_id, gt_box, gt_label
-
-
-def batch_transform(batch, mode='train'):
- if mode == 'train':
- d = np.random.choice(
- [320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
- interp = np.random.choice(range(5))
- else:
- d = 608
- interp = cv2.INTER_CUBIC
- # transpose batch
- imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch))
- img_shapes = np.array([[im.shape[0], im.shape[1]] for im in imgs]).astype('int32')
- imgs = np.array([cv2.resize(
- img, (d, d), interpolation=interp) for img in imgs])
-
- # transpose, permute and normalize
- imgs = imgs.astype(np.float32)[..., ::-1]
- mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
- std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
- invstd = 1. / std
- imgs -= mean
- imgs *= invstd
- imgs = imgs.transpose((0, 3, 1, 2))
-
- img_ids = np.array(img_ids)
- img_info = np.concatenate([img_ids, img_shapes], axis=1)
- gt_boxes = np.array(gt_boxes)
- gt_labels = np.array(gt_labels)
- # XXX since mix up is not used, scores are all ones
- gt_scores = np.ones_like(gt_labels, dtype=np.float32)
- return [imgs, img_info], [gt_boxes, gt_labels, gt_scores]
-
-
-def coco2017(root_dir, mode='train'):
- json_path = os.path.join(
- root_dir, 'annotations/instances_{}2017.json'.format(mode))
- coco = COCO(json_path)
- img_ids = coco.getImgIds()
- imgs = coco.loadImgs(img_ids)
- class_map = {v: i + 1 for i, v in enumerate(coco.getCatIds())}
- samples = []
-
- for img in imgs:
- img_path = os.path.join(
- root_dir, '{}2017'.format(mode), img['file_name'])
- file_path = img_path
- width = img['width']
- height = img['height']
- ann_ids = coco.getAnnIds(imgIds=img['id'], iscrowd=False)
- anns = coco.loadAnns(ann_ids)
-
- gt_box = []
- gt_label = []
-
- for ann in anns:
- x1, y1, w, h = ann['bbox']
- x2 = x1 + w - 1
- y2 = y1 + h - 1
- x1 = np.clip(x1, 0, width - 1)
- x2 = np.clip(x2, 0, width - 1)
- y1 = np.clip(y1, 0, height - 1)
- y2 = np.clip(y2, 0, height - 1)
- if ann['area'] <= 0 or x2 < x1 or y2 < y1:
- continue
- gt_label.append(ann['category_id'])
- gt_box.append([x1, y1, x2, y2])
-
- gt_box = np.array(gt_box, dtype=np.float32)
- gt_label = np.array([class_map[cls] for cls in gt_label],
- dtype=np.int32)[:, np.newaxis]
- im_id = np.array([img['id']], dtype=np.int32)
-
- if gt_label.size == 0 and not mode == 'train':
- continue
- samples.append((file_path, im_id.copy(), gt_box.copy(), gt_label.copy()))
-
- def iterator():
- if mode == 'train':
- np.random.shuffle(samples)
- for file_path, im_id, gt_box, gt_label in samples:
- img = cv2.imread(file_path)
- yield img, im_id, gt_box, gt_label
-
- return iterator
-
-
-# XXX coco metrics not included for simplicity
-def run(model, loader, mode='train'):
- total_loss = 0.
- total_time = 0.
- device_ids = list(range(FLAGS.num_devices))
- start = time.time()
-
- for idx, batch in enumerate(loader()):
- losses = getattr(model, mode)(batch[0], batch[1])
-
- total_loss += np.sum(losses)
- if idx > 1: # skip first two steps
- total_time += time.time() - start
- if idx % 10 == 0:
- logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
- idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
- start = time.time()
-
-
-def main():
- @contextlib.contextmanager
- def null_guard():
- yield
-
- epoch = FLAGS.epoch
- batch_size = FLAGS.batch_size
- guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
-
- train_loader = fluid.io.xmap_readers(
- batch_transform,
- paddle.batch(
- fluid.io.xmap_readers(
- sample_transform,
- coco2017(FLAGS.data, 'train'),
- process_num=8,
- buffer_size=4 * batch_size),
- batch_size=batch_size,
- drop_last=True),
- process_num=2, buffer_size=4)
-
- val_sample_transform = partial(sample_transform, mode='val')
- val_batch_transform = partial(batch_transform, mode='val')
-
- val_loader = fluid.io.xmap_readers(
- val_batch_transform,
- paddle.batch(
- fluid.io.xmap_readers(
- val_sample_transform,
- coco2017(FLAGS.data, 'val'),
- process_num=8,
- buffer_size=4 * batch_size),
- batch_size=1),
- process_num=2, buffer_size=4)
-
- if not os.path.exists('yolo_checkpoints'):
- os.mkdir('yolo_checkpoints')
-
- with guard:
- NUM_CLASSES = 7
- NUM_MAX_BOXES = 50
- model = YOLOv3(num_classes=NUM_CLASSES)
- # XXX transfer learning
- if FLAGS.pretrain_weights is not None:
- model.backbone.load(FLAGS.pretrain_weights)
- if FLAGS.weights is not None:
- model.load(FLAGS.weights)
- optim = make_optimizer(parameter_list=model.parameters())
- anno_path = os.path.join(FLAGS.data, 'annotations', 'instances_val2017.json')
- inputs = [Input([None, 3, None, None], 'float32', name='image'),
- Input([None, 3], 'int32', name='img_info')]
- labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
- Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
- Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
- model.prepare(optim,
- YoloLoss(num_classes=NUM_CLASSES),
- # For YOLOv3, output variable in train/eval is different,
- # which is not supported by metric, add by callback later?
- # metrics=COCOMetric(anno_path, with_background=False)
- inputs=inputs,
- labels = labels)
-
- for e in range(epoch):
- logger.info("======== train epoch {} ========".format(e))
- run(model, train_loader)
- model.save('yolo_checkpoints/{:02d}'.format(e))
- logger.info("======== eval epoch {} ========".format(e))
- run(model, val_loader, mode='eval')
- # should be called in fit()
- for metric in model._metrics:
- metric.accumulate()
- metric.reset()
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser("Yolov3 Training on COCO")
- parser.add_argument('data', metavar='DIR', help='path to COCO dataset')
- parser.add_argument(
- "-d", "--dynamic", action='store_true', help="enable dygraph mode")
- parser.add_argument(
- "-e", "--epoch", default=300, type=int, help="number of epoch")
- parser.add_argument(
- '--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
- help='initial learning rate')
- parser.add_argument(
- "-b", "--batch_size", default=64, type=int, help="batch size")
- parser.add_argument(
- "-n", "--num_devices", default=8, type=int, help="number of devices")
- parser.add_argument(
- "-p", "--pretrain_weights", default=None, type=str,
- help="path to pretrained weights")
- parser.add_argument(
- "-w", "--weights", default=None, type=str,
- help="path to model weights")
- FLAGS = parser.parse_args()
- assert FLAGS.data, "error: must provide data path"
- main()
diff --git a/yolov3/README.md b/yolov3/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc6d302a544f9b8d9cd06fc363b90d053919a5e9
--- /dev/null
+++ b/yolov3/README.md
@@ -0,0 +1,203 @@
+# YOLOv3 Object Detection Model
+
+---
+
+## Contents
+
+- [Introduction](#introduction)
+- [Quick Start](#quick-start)
+- [Reference](#reference)
+
+
+## Introduction
+
+[YOLOv3](https://arxiv.org/abs/1804.02767) is a single-stage detector proposed by [Joseph Redmon](https://arxiv.org/search/cs?searchtype=author&query=Redmon%2C+J) and [Ali Farhadi](https://arxiv.org/search/cs?searchtype=author&query=Farhadi%2C+A). Compared with traditional object detection methods that reach the same accuracy, its inference speed is close to twice as fast.
+
+Traditional object detection methods work in two stages: the first generates candidate boxes, and the second classifies them and refines their coordinates. YOLO instead treats object detection as a single-stage regression over box locations and class probabilities, which gives it nearly twice the detection speed. On top of YOLO, YOLOv3 adds multi-scale prediction, which greatly improves detection accuracy for small objects.
+
+[YOLOv3](https://arxiv.org/abs/1804.02767) is a single-stage, end-to-end object detector. Its detection principle is illustrated below:
+
+
+YOLOv3 detection principle
+
+
+YOLOv3 divides the input image into S\*S grid cells. Each cell predicts B bounding boxes, and each bounding box predicts a location (x, y, w, h), a confidence score, and the probabilities of C classes, so the number of channels in the YOLOv3 output layer is B\*(5 + C). The YOLOv3 loss likewise consists of three parts: the location error, the confidence error, and the classification error.
+
+The YOLOv3 network architecture is shown below:
+
+
+YOLOv3 network architecture
+
+
+The YOLOv3 network consists of a base feature extraction network, multi-scale feature fusion layers, and output layers.
+
+1. Feature extraction network. YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extraction. DarkNet53 is essentially a fully convolutional network that replaces pooling layers with stride-2 convolutions and adds residual units to avoid vanishing gradients when the network becomes very deep.
+
+2. Feature fusion layers. To address the insensitivity of earlier YOLO versions to small objects, YOLOv3 detects objects on feature maps of three different scales, 13\*13, 26\*26 and 52\*52, for large, medium and small objects respectively. The fusion layers take the three scales of feature maps produced by DarkNet as input and, following the idea of FPN (feature pyramid networks), fuse the feature maps of the different scales through a series of convolution layers and upsampling.
+
+3. Output layers. These are also fully convolutional; the last convolution layer has 255 kernels: 3\*(80+4+1)=255, where 3 is the number of bounding boxes per grid cell, 4 is the number of box coordinates, 1 is the confidence score, and 80 is the number of class probabilities for the COCO dataset (see the short sketch below).
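+
+As a quick sanity check of this channel count, the output shapes of the three detection scales can be computed in a few lines (an illustrative sketch under the COCO setting, not code from this repository):
+
+```python
+num_classes = 80        # COCO
+boxes_per_cell = 3      # bounding boxes predicted by each grid cell at each scale
+channels = boxes_per_cell * (4 + 1 + num_classes)   # location + confidence + class probs = 255
+
+for grid in (13, 26, 52):          # scales for large, medium and small objects
+    print((channels, grid, grid))  # (255, 13, 13), (255, 26, 26), (255, 52, 52)
+```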
+
+
+## Quick Start
+
+### Installation
+
+#### Install PaddlePaddle
+
+    This project requires PaddlePaddle 1.7 or higher, or a suitable develop version. Please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
+
+#### Download the code and set environment variables
+
+    Clone the repository and set the `PYTHONPATH` environment variable:
+
+ ```bash
+ git clone https://github.com/PaddlePaddle/hapi
+ cd hapi
+ export PYTHONPATH=$PYTHONPATH:`pwd`
+    cd yolov3
+ ```
+
+#### Install the COCO API
+
+    Before training, download and install the [COCO API](https://github.com/cocodataset/cocoapi) first:
+
+ ```bash
+ git clone https://github.com/cocodataset/cocoapi.git
+ cd cocoapi/PythonAPI
+ # if cython is not installed
+ pip install Cython
+ # Install into global site-packages
+ make install
+ # Alternatively, if you do not have permissions or prefer
+ # not to install the COCO API into global site-packages
+ python setup.py install --user
+ ```
+
+### Data Preparation
+
+The model currently supports reading data and evaluating accuracy in the COCO dataset format. We also provide a Pascal VOC dataset converted to the COCO format, which can be downloaded with the following command.
+
+ ```bash
+ python dataset/download_voc.py
+ ```
+
+The data directory structure is as follows:
+
+ ```
+ dataset/voc/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ | ...
+ ├── train2017
+ │ ├── 1013.jpg
+ │ ├── 1014.jpg
+ | ...
+ ├── val2017
+ │ ├── 2551.jpg
+ │ ├── 2552.jpg
+ | ...
+ ```
+
+### Model Training
+
+Once the data is ready, training and evaluation can be launched with the `main.py` script. It automatically alternates between training and evaluation every epoch, and checkpoints are saved to the `yolo_checkpoint` directory by default.
+
+YOLOv3 is trained with a total batch size of 64. The instructions below use 4 Tesla P40 cards with a batch size of 16 per card. For both static graph and dygraph modes, `--batch_size` in multi-card training is the batch size per card, i.e. the total batch size equals `--batch_size` multiplied by the number of cards.
+
+
+The arguments of `main.py` can be listed with the following command:
+
+```bash
+python main.py --help
+```
+
+#### Static-graph training
+
+Launch multi-card training as follows:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=dataset/voc --batch_size=16
+```
+
+#### Dynamic-graph training
+
+Dynamic-graph training only requires adding the `-d` flag when running the script.
+
+Launch multi-card training as follows:
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --data=dataset/voc --batch_size=16 -d
+```
+
+
+### Model Evaluation
+
+The YOLOv3 model outputs a LoDTensor, so evaluation only supports a batch size of 1. The model can be evaluated in either of the following two ways.
+
+1. Evaluate with the Paddle-released [YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams) weights, which are downloaded automatically
+
+```bash
+python main.py --data=dataset/voc --eval_only
+```
+
+2. Load a checkpoint for accuracy evaluation
+
+```bash
+python main.py --data=dataset/voc --eval_only --weights=yolo_checkpoint/no_mixup/final
+```
+
+Evaluation in dynamic-graph mode can likewise be enabled by passing the `-d` flag.
+
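+Internally, `--eval_only` collects the predictions and feeds them to the `COCOMetric` helper in `coco_metric.py`. A minimal sketch of that flow, assuming a `model` and evaluation `loader` built as in `main.py`:
+
+```python
+from coco_metric import COCOMetric
+
+# predictions from the high-level API; the last two outputs are image ids and decoded bboxes
+_, _, _, img_ids, bboxes = model.predict(loader, stack_outputs=False)
+
+metric = COCOMetric(anno_path='dataset/voc/annotations/instances_val2017.json',
+                    with_background=False)
+for img_id, bbox in zip(img_ids, bboxes):
+    metric.update(img_id, bbox)   # batch size must be 1
+metric.accumulate()               # runs pycocotools evaluation and prints AP/AR
+```
+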
+#### Evaluation accuracy
+
+Model weights trained on the small 10-class dataset are available as [YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams); the evaluation accuracy is as follows:
+
+```bash
+Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.503
+Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.779
+Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.562
+Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.190
+Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.390
+Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.578
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.405
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.591
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.599
+Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.294
+Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.506
+Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.670
+```
+
+### Inference and Visualization
+
+Inference can be run in either of the following two ways.
+
+1. Run inference with the Paddle-released [YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams) weights, which are downloaded automatically
+
+```bash
+python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg
+```
+
+2. Run inference from a checkpoint
+
+```bash
+python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/no_mixup/final
+```
+
+The visualized inference results are saved in the folder specified by `--output_dir`, which defaults to the `./output` directory.
+
+Inference prints detection logs such as the following:
+
+```text
+2020-04-02 08:26:47,268-INFO: detect bicycle at [116.14993, 127.278336, 579.7716, 438.44214] score: 0.97
+2020-04-02 08:26:47,273-INFO: detect dog at [127.44086, 215.71997, 316.04276, 539.7584] score: 0.99
+2020-04-02 08:26:47,274-INFO: detect car at [475.42343, 80.007484, 687.16095, 171.27374] score: 0.98
+2020-04-02 08:26:47,274-INFO: Detection bbox results save in output/dog.jpg
+```
+
+## References
+
+- [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi.
+- [YOLOv3: An Incremental Improvement](https://arxiv.org/abs/1804.02767v1), Joseph Redmon, Ali Farhadi.
+- [Bag of Freebies for Training Object Detection Neural Networks](https://arxiv.org/abs/1902.04103v3), Zhi Zhang, Tong He, Hang Zhang, Zhongyue Zhang, Junyuan Xie, Mu Li.
+
diff --git a/yolov3/__init__.py b/yolov3/__init__.py
deleted file mode 100644
index 9118340d83fefa17d4a7e8fc577ee22a2d3a2656..0000000000000000000000000000000000000000
--- a/yolov3/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/yolov3/coco.py b/yolov3/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..34809246c1f90d3ad029842c19ae5f2c3eba08b0
--- /dev/null
+++ b/yolov3/coco.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import cv2
+import numpy as np
+from pycocotools.coco import COCO
+
+from paddle.fluid.io import Dataset
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['COCODataset']
+
+
+class COCODataset(Dataset):
+ """
+ Load dataset with MS-COCO format.
+
+ Args:
+ dataset_dir (str): root directory for dataset.
+ image_dir (str): directory for images.
+ anno_path (str): COCO annotation file path.
+ sample_num (int): number of samples to load, -1 means all.
+ with_background (bool): whether to load background as a class,
+ default True.
+ transform (callable): callable transform to perform on samples,
+ default None.
+ mixup (bool): whether return image mixup samples, default False.
+ alpha (float): alpha factor of beta distribution to generate
+ mixup score, used only when mixup is True, default 1.5
+ beta (float): beta factor of beta distribution to generate
+ mixup score, used only when mixup is True, default 1.5
+ """
+
+ def __init__(self,
+ dataset_dir='',
+ image_dir='',
+ anno_path='',
+ sample_num=-1,
+ with_background=True,
+ transform=None,
+ mixup=False,
+ alpha=1.5,
+ beta=1.5):
+ # roidbs is list of dict whose structure is:
+ # {
+ # 'im_file': im_fname, # image file name
+ # 'im_id': im_id, # image id
+ # 'h': im_h, # height of image
+ # 'w': im_w, # width
+ # 'is_crowd': is_crowd,
+ # 'gt_class': gt_class,
+ # 'gt_bbox': gt_bbox,
+ # 'gt_score': gt_score,
+ # 'difficult': difficult
+ # }
+
+ self._anno_path = os.path.join(dataset_dir, anno_path)
+ self._image_dir = os.path.join(dataset_dir, image_dir)
+ assert os.path.exists(self._anno_path), \
+ "anno_path {} not exists".format(anno_path)
+ assert os.path.exists(self._image_dir), \
+ "image_dir {} not exists".format(image_dir)
+
+ self._sample_num = sample_num
+ self._with_background = with_background
+ self._transform = transform
+ self._mixup = mixup
+ self._alpha = alpha
+ self._beta = beta
+
+ # load in dataset roidbs
+ self._load_roidb_and_cname2cid()
+
+ def _load_roidb_and_cname2cid(self):
+ assert self._anno_path.endswith('.json'), \
+ 'invalid coco annotation file: ' + self._anno_path
+ coco = COCO(self._anno_path)
+ img_ids = coco.getImgIds()
+ cat_ids = coco.getCatIds()
+ records = []
+ ct = 0
+
+ # when with_background = True, mapping category to classid, like:
+ # background:0, first_class:1, second_class:2, ...
+ catid2clsid = dict({
+ catid: i + int(self._with_background)
+ for i, catid in enumerate(cat_ids)
+ })
+ cname2cid = dict({
+ coco.loadCats(catid)[0]['name']: clsid
+ for catid, clsid in catid2clsid.items()
+ })
+
+ for img_id in img_ids:
+ img_anno = coco.loadImgs(img_id)[0]
+ im_fname = img_anno['file_name']
+ im_w = float(img_anno['width'])
+ im_h = float(img_anno['height'])
+
+ ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
+ instances = coco.loadAnns(ins_anno_ids)
+
+ bboxes = []
+ for inst in instances:
+ x, y, box_w, box_h = inst['bbox']
+ x1 = max(0, x)
+ y1 = max(0, y)
+ x2 = min(im_w - 1, x1 + max(0, box_w - 1))
+ y2 = min(im_h - 1, y1 + max(0, box_h - 1))
+ if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
+ inst['clean_bbox'] = [x1, y1, x2, y2]
+ bboxes.append(inst)
+ else:
+ logger.warning(
+ 'Found an invalid bbox in annotations: im_id: {}, '
+ 'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
+ img_id, float(inst['area']), x1, y1, x2, y2))
+ num_bbox = len(bboxes)
+
+ gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+ gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+ gt_score = np.ones((num_bbox, 1), dtype=np.float32)
+ is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
+ difficult = np.zeros((num_bbox, 1), dtype=np.int32)
+ gt_poly = [None] * num_bbox
+
+ for i, box in enumerate(bboxes):
+ catid = box['category_id']
+ gt_class[i][0] = catid2clsid[catid]
+ gt_bbox[i, :] = box['clean_bbox']
+ is_crowd[i][0] = box['iscrowd']
+ if 'segmentation' in box:
+ gt_poly[i] = box['segmentation']
+
+ im_fname = os.path.join(self._image_dir,
+ im_fname) if self._image_dir else im_fname
+ coco_rec = {
+ 'im_file': im_fname,
+ 'im_id': np.array([img_id]),
+ 'h': im_h,
+ 'w': im_w,
+ 'is_crowd': is_crowd,
+ 'gt_class': gt_class,
+ 'gt_bbox': gt_bbox,
+ 'gt_score': gt_score,
+ 'gt_poly': gt_poly,
+ }
+
+ records.append(coco_rec)
+ ct += 1
+ if self._sample_num > 0 and ct >= self._sample_num:
+ break
+ assert len(records) > 0, 'not found any coco record in %s' % (self._anno_path)
+ logger.info('{} samples in file {}'.format(ct, self._anno_path))
+ self._roidbs, self._cname2cid = records, cname2cid
+
+ @property
+ def num_classes(self):
+ return len(self._cname2cid)
+
+ def __len__(self):
+ return len(self._roidbs)
+
+ def _getitem_by_index(self, idx):
+ roidb = self._roidbs[idx]
+ with open(roidb['im_file'], 'rb') as f:
+ data = np.frombuffer(f.read(), dtype='uint8')
+ im = cv2.imdecode(data, 1)
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32')
+ gt_bbox = roidb['gt_bbox']
+ gt_class = roidb['gt_class']
+ gt_score = roidb['gt_score']
+ return im_info, im, gt_bbox, gt_class, gt_score
+
+ def __getitem__(self, idx):
+ im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
+
+ if self._mixup:
+ mixup_idx = idx + np.random.randint(1, self.__len__())
+ mixup_idx %= self.__len__()
+ _, mixup_im, mixup_bbox, mixup_class, _ = \
+ self._getitem_by_index(mixup_idx)
+
+ im, gt_bbox, gt_class, gt_score = \
+ self._mixup_image(im, gt_bbox, gt_class, mixup_im,
+ mixup_bbox, mixup_class)
+
+ if self._transform:
+ im_info, im, gt_bbox, gt_class, gt_score = \
+ self._transform(im_info, im, gt_bbox, gt_class, gt_score)
+
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
+ factor = np.random.beta(self._alpha, self._beta)
+ factor = max(0.0, min(1.0, factor))
+ if factor >= 1.0:
+ return img1, bbox1, class1, np.ones_like(class1, dtype="float32")
+ if factor <= 0.0:
+ return img2, bbox2, class2, np.ones_like(class2, dtype="float32")
+
+ h = max(img1.shape[0], img2.shape[0])
+ w = max(img1.shape[1], img2.shape[1])
+ img = np.zeros((h, w, img1.shape[2]), 'float32')
+ img[:img1.shape[0], :img1.shape[1], :] = \
+ img1.astype('float32') * factor
+ img[:img2.shape[0], :img2.shape[1], :] += \
+ img2.astype('float32') * (1.0 - factor)
+
+ gt_bbox = np.concatenate((bbox1, bbox2), axis=0)
+ gt_class = np.concatenate((class1, class2), axis=0)
+
+ score1 = np.ones_like(class1, dtype="float32") * factor
+ score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
+ gt_score = np.concatenate((score1, score2), axis=0)
+
+ return img, gt_bbox, gt_class, gt_score
+
+ @property
+ def mixup(self):
+ return self._mixup
+
+ @mixup.setter
+ def mixup(self, value):
+ if not isinstance(value, bool):
+ raise ValueError("mixup should be a boolean value")
+ logger.info("{} set mixup to {}".format(self, value))
+ self._mixup = value
+
+def pascalvoc_label(with_background=True):
+ labels_map = {
+ 'aeroplane': 1,
+ 'bicycle': 2,
+ 'bird': 3,
+ 'boat': 4,
+ 'bottle': 5,
+ 'bus': 6,
+ 'car': 7,
+ 'cat': 8,
+ 'chair': 9,
+ 'cow': 10,
+ 'diningtable': 11,
+ 'dog': 12,
+ 'horse': 13,
+ 'motorbike': 14,
+ 'person': 15,
+ 'pottedplant': 16,
+ 'sheep': 17,
+ 'sofa': 18,
+ 'train': 19,
+ 'tvmonitor': 20
+ }
+ if not with_background:
+ labels_map = {k: v - 1 for k, v in labels_map.items()}
+ return labels_map
diff --git a/yolov3/coco_metric.py b/yolov3/coco_metric.py
index ec7bcac24b3dde91d3ae85e39e7bf9e5151f43ec..2f2f9825b1f90c08afa7b6089641d5a4b28be51d 100644
--- a/yolov3/coco_metric.py
+++ b/yolov3/coco_metric.py
@@ -17,8 +17,6 @@ import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
-from metrics import Metric
-
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -26,12 +24,13 @@ logger = logging.getLogger(__name__)
__all__ = ['COCOMetric']
-
OUTFILE = './bbox.json'
-# considered to change to a callback later
-class COCOMetric(Metric):
+# COCOMetric behaves differently from the Metric defined in the high
+# level API: COCOMetric can only accumulate at the end of an epoch,
+# so we do not implement COCOMetric as a high level API Metric
+class COCOMetric():
"""
Metric for MS-COCO dataset, only supports update with batch
size as 1.
@@ -43,26 +42,24 @@ class COCOMetric(Metric):
"""
def __init__(self, anno_path, with_background=True, **kwargs):
- super(COCOMetric, self).__init__(**kwargs)
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
self.coco_gt = COCO(anno_path)
cat_ids = self.coco_gt.getCatIds()
- self.clsid2catid = dict(
- {i + int(with_background): catid
- for i, catid in enumerate(cat_ids)})
+ self.clsid2catid = dict(
+ {i + int(with_background): catid
+ for i, catid in enumerate(cat_ids)})
- def update(self, preds, *args, **kwargs):
- im_ids, bboxes = preds
- assert im_ids.shape[0] == 1, \
+ def update(self, img_id, bboxes):
+ assert img_id.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
- im_id = int(im_ids)
+ img_id = int(img_id)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
@@ -72,7 +69,7 @@ class COCOMetric(Metric):
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
- 'image_id': im_id,
+ 'image_id': img_id,
'category_id': catid,
'bbox': bbox,
'score': score
@@ -83,30 +80,30 @@ class COCOMetric(Metric):
self.bbox_results = []
def accumulate(self):
- if len(self.bbox_results) == 0:
- logger.warning("The number of valid bbox detected is zero.\n \
- Please use reasonable model and check input data.\n \
- stop COCOMetric accumulate!")
- return [0.0]
- with open(OUTFILE, 'w') as f:
- json.dump(self.bbox_results, f)
-
- map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
- # flush coco evaluation result
- sys.stdout.flush()
+ if len(self.bbox_results) == 0:
+ logger.warning("The number of valid bbox detected is zero.\n \
+ Please use reasonable model and check input data.\n \
+ stop COCOMetric accumulate!")
+ return [0.0]
+ with open(OUTFILE, 'w') as f:
+ json.dump(self.bbox_results, f)
+
+ map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
+ # flush coco evaluation result
+ sys.stdout.flush()
self.result = map_stats[0]
- return self.result
+ return [self.result]
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
- assert coco_gt != None or anno_file != None
-
- if coco_gt == None:
- coco_gt = COCO(anno_file)
- logger.info("Start evaluate...")
- coco_dt = coco_gt.loadRes(jsonfile)
- coco_eval = COCOeval(coco_gt, coco_dt, style)
- coco_eval.evaluate()
- coco_eval.accumulate()
- coco_eval.summarize()
- return coco_eval.stats
+ assert coco_gt != None or anno_file != None
+
+ if coco_gt == None:
+ coco_gt = COCO(anno_file)
+ logger.info("Start evaluate...")
+ coco_dt = coco_gt.loadRes(jsonfile)
+ coco_eval = COCOeval(coco_gt, coco_dt, style)
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ return coco_eval.stats
diff --git a/yolov3/dataset/download_voc.py b/yolov3/dataset/download_voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b064ed4034e5fa1471c8094a78266d531d9c111
--- /dev/null
+++ b/yolov3/dataset/download_voc.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import sys
+import tarfile
+
+from models.download import _download
+
+import logging
+logger = logging.getLogger(__name__)
+
+DATASETS = {
+ 'voc': [
+ ('https://paddlemodels.bj.bcebos.com/hapi/voc.tar',
+ '9faeb7fd997aeea843092fd608d5bcb4', ),
+ ],
+}
+
+def download_decompress_file(data_dir, url, md5):
+ logger.info("Downloading from {}".format(url))
+ tar_file = _download(url, data_dir, md5)
+ logger.info("Decompressing {}".format(tar_file))
+ with tarfile.open(tar_file) as tf:
+ tf.extractall(path=data_dir)
+ os.remove(tar_file)
+
+
+if __name__ == "__main__":
+ data_dir = osp.split(osp.realpath(sys.argv[0]))[0]
+ for name, infos in DATASETS.items():
+ for info in infos:
+ download_decompress_file(data_dir, *info)
+
diff --git a/yolov3/image/YOLOv3.jpg b/yolov3/image/YOLOv3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..06b81f545247c1d542fd661f947eb0cf3edc480e
Binary files /dev/null and b/yolov3/image/YOLOv3.jpg differ
diff --git a/yolov3/image/YOLOv3_structure.jpg b/yolov3/image/YOLOv3_structure.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..51bd2d1733e2f78945d3e871cb5b649aad95d633
Binary files /dev/null and b/yolov3/image/YOLOv3_structure.jpg differ
diff --git a/yolov3/image/dog.jpg b/yolov3/image/dog.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..77b0381222eaed50867643f4166092c781e56d5b
Binary files /dev/null and b/yolov3/image/dog.jpg differ
diff --git a/yolov3/infer.py b/yolov3/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f19e86615a0b1c8c57f3469f5a5bdcaa85535e9c
--- /dev/null
+++ b/yolov3/infer.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import argparse
+import numpy as np
+from PIL import Image
+
+from paddle import fluid
+from paddle.fluid.optimizer import Momentum
+from paddle.fluid.io import DataLoader
+
+from model import Model, Input, set_device
+from models import yolov3_darknet53, YoloLoss
+
+from coco import COCODataset
+from transforms import *
+from visualizer import draw_bbox
+
+import logging
+logger = logging.getLogger(__name__)
+
+IMAGE_MEAN = [0.485, 0.456, 0.406]
+IMAGE_STD = [0.229, 0.224, 0.225]
+
+
+def get_save_image_name(output_dir, image_path):
+ """
+ Get save image name from source image path.
+ """
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ image_name = os.path.split(image_path)[-1]
+ name, ext = os.path.splitext(image_name)
+ return os.path.join(output_dir, "{}".format(name)) + ext
+
+
+def load_labels(label_list, with_background=True):
+ idx = int(with_background)
+ cat2name = {}
+ with open(label_list) as f:
+ for line in f.readlines():
+ line = line.strip()
+ if line:
+ cat2name[idx] = line
+ idx += 1
+ return cat2name
+
+
+def main():
+ device = set_device(FLAGS.device)
+ fluid.enable_dygraph(device) if FLAGS.dynamic else None
+
+ inputs = [Input([None, 3], 'int32', name='img_info'),
+ Input([None, 3, None, None], 'float32', name='image')]
+
+ cat2name = load_labels(FLAGS.label_list, with_background=False)
+
+ model = yolov3_darknet53(num_classes=len(cat2name),
+ model_mode='test',
+ pretrained=FLAGS.weights is None)
+
+ model.prepare(inputs=inputs, device=FLAGS.device)
+
+ if FLAGS.weights is not None:
+ model.load(FLAGS.weights, reset_optimizer=True)
+
+ # image preprocess
+ orig_img = Image.open(FLAGS.infer_image).convert('RGB')
+ w, h = orig_img.size
+ img = orig_img.resize((608, 608), Image.BICUBIC)
+ img = np.array(img).astype('float32') / 255.0
+ img -= np.array(IMAGE_MEAN)
+ img /= np.array(IMAGE_STD)
+ img = img.transpose((2, 0, 1))[np.newaxis, :]
+ img_info = np.array([0, h, w]).astype('int32')[np.newaxis, :]
+
+ _, bboxes = model.test([img_info, img])
+
+ vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold)
+ save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image)
+ logger.info("Detection bbox results save in {}".format(save_name))
+ vis_img.save(save_name, quality=95)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("Yolov3 Inference")
+ parser.add_argument(
+ "--device", type=str, default='gpu', help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "--label_list", type=str, default=None,
+ help="path to category label list file")
+ parser.add_argument(
+ "-t", "--draw_threshold", type=float, default=0.5,
+ help="threshold to reserve the result for visualization")
+ parser.add_argument(
+ "-i", "--infer_image", type=str, default=None,
+ help="image path for inference")
+ parser.add_argument(
+ "-o", "--output_dir", type=str, default='output',
+ help="directory to save the visualized inference result")
+ parser.add_argument(
+ "-w", "--weights", default=None, type=str,
+ help="path to weights for inference")
+ FLAGS = parser.parse_args()
+ assert os.path.isfile(FLAGS.infer_image), \
+ "infer_image {} not a file".format(FLAGS.infer_image)
+ assert os.path.isfile(FLAGS.label_list), \
+ "label_list {} not a file".format(FLAGS.label_list)
+ main()
diff --git a/yolov3/main.py b/yolov3/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c24d196877586475f6aba1f949c3207665fcce
--- /dev/null
+++ b/yolov3/main.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import contextlib
+import os
+
+import numpy as np
+
+from paddle import fluid
+from paddle.fluid.optimizer import Momentum
+from paddle.fluid.io import DataLoader
+
+from model import Model, Input, set_device
+from distributed import DistributedBatchSampler
+from models import yolov3_darknet53, YoloLoss
+
+from coco_metric import COCOMetric
+from coco import COCODataset
+from transforms import *
+
+NUM_MAX_BOXES = 50
+
+
+def make_optimizer(step_per_epoch, parameter_list=None):
+ base_lr = FLAGS.lr
+ warm_up_iter = 1000
+ momentum = 0.9
+ weight_decay = 5e-4
+ boundaries = [step_per_epoch * e for e in [200, 250]]
+ values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
+ learning_rate = fluid.layers.piecewise_decay(
+ boundaries=boundaries,
+ values=values)
+ learning_rate = fluid.layers.linear_lr_warmup(
+ learning_rate=learning_rate,
+ warmup_steps=warm_up_iter,
+ start_lr=0.0,
+ end_lr=base_lr)
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate,
+ regularization=fluid.regularizer.L2Decay(weight_decay),
+ momentum=momentum,
+ parameter_list=parameter_list)
+ return optimizer
+
+
+def main():
+ device = set_device(FLAGS.device)
+ fluid.enable_dygraph(device) if FLAGS.dynamic else None
+
+ inputs = [Input([None, 3], 'int32', name='img_info'),
+ Input([None, 3, None, None], 'float32', name='image')]
+ labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
+ Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
+ Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
+
+ if not FLAGS.eval_only: # training mode
+ train_transform = Compose([ColorDistort(),
+ RandomExpand(),
+ RandomCrop(),
+ RandomFlip(),
+ NormalizeBox(),
+ PadBox(),
+ BboxXYXY2XYWH()])
+ train_collate_fn = BatchCompose([RandomShape(), NormalizeImage()])
+ dataset = COCODataset(dataset_dir=FLAGS.data,
+ anno_path='annotations/instances_train2017.json',
+ image_dir='train2017',
+ with_background=False,
+ mixup=True,
+ transform=train_transform)
+ batch_sampler = DistributedBatchSampler(dataset,
+ batch_size=FLAGS.batch_size,
+ shuffle=True,
+ drop_last=True)
+ loader = DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ num_workers=FLAGS.num_workers,
+ return_list=True,
+ collate_fn=train_collate_fn)
+ else: # evaluation mode
+ eval_transform = Compose([ResizeImage(target_size=608),
+ NormalizeBox(),
+ PadBox(),
+ BboxXYXY2XYWH()])
+ eval_collate_fn = BatchCompose([NormalizeImage()])
+ dataset = COCODataset(dataset_dir=FLAGS.data,
+ anno_path='annotations/instances_val2017.json',
+ image_dir='val2017',
+ with_background=False,
+ transform=eval_transform)
+ # batch_size can only be 1 in evaluation for YOLOv3
+ # prediction bbox is a LoDTensor
+ batch_sampler = DistributedBatchSampler(dataset,
+ batch_size=1,
+ shuffle=False,
+ drop_last=False)
+ loader = DataLoader(dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ num_workers=FLAGS.num_workers,
+ return_list=True,
+ collate_fn=eval_collate_fn)
+
+ pretrained = FLAGS.eval_only and FLAGS.weights is None
+ model = yolov3_darknet53(num_classes=dataset.num_classes,
+ model_mode='eval' if FLAGS.eval_only else 'train',
+ pretrained=pretrained)
+
+ if FLAGS.pretrain_weights is not None:
+ model.load(FLAGS.pretrain_weights, skip_mismatch=True, reset_optimizer=True)
+
+ optim = make_optimizer(len(batch_sampler), parameter_list=model.parameters())
+
+ model.prepare(optim,
+ YoloLoss(num_classes=dataset.num_classes),
+ inputs=inputs, labels=labels,
+ device=FLAGS.device)
+
+ # NOTE: we implement the COCO metric of the YOLOv3 model here, separately
+ # from the 'prepare' and 'fit' framework, for the following reasons:
+ # 1. the YOLOv3 network structure differs between 'train' and
+ # 'eval' mode; in 'eval' mode the output is the predicted bboxes, not the
+ # feature maps used for computing YoloLoss
+ # 2. the COCO metric behavior also differs from the defined Metric,
+ # as the COCO metric should not accumulate in each iteration
+ # but only at the end of an epoch
+ if FLAGS.eval_only:
+ if FLAGS.weights is not None:
+ model.load(FLAGS.weights, reset_optimizer=True)
+ preds = model.predict(loader, stack_outputs=False)
+ _, _, _, img_ids, bboxes = preds
+
+ anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json')
+ coco_metric = COCOMetric(anno_path=anno_path, with_background=False)
+ for img_id, bbox in zip(img_ids, bboxes):
+ coco_metric.update(img_id, bbox)
+ coco_metric.accumulate()
+ coco_metric.reset()
+ return
+
+ if FLAGS.resume is not None:
+ model.load(FLAGS.resume)
+
+ model.fit(train_data=loader,
+ epochs=FLAGS.epoch - FLAGS.no_mixup_epoch,
+ save_dir="yolo_checkpoint/mixup",
+ save_freq=10)
+
+ # do not use the image mixup transform in the last FLAGS.no_mixup_epoch epochs
+ dataset.mixup = False
+ model.fit(train_data=loader,
+ epochs=FLAGS.no_mixup_epoch,
+ save_dir="yolo_checkpoint/no_mixup",
+ save_freq=5)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("Yolov3 Training on VOC")
+ parser.add_argument(
+ "--data", type=str, default='dataset/voc',
+ help="path to dataset directory")
+ parser.add_argument(
+ "--device", type=str, default='gpu', help="device to use, gpu or cpu")
+ parser.add_argument(
+ "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+ parser.add_argument(
+ "--eval_only", action='store_true', help="run evaluation only")
+ parser.add_argument(
+ "-e", "--epoch", default=300, type=int, help="number of epoch")
+ parser.add_argument(
+ "--no_mixup_epoch", default=30, type=int,
+ help="number of the last N epoch without image mixup")
+ parser.add_argument(
+ '--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
+ help='initial learning rate')
+ parser.add_argument(
+ "-b", "--batch_size", default=8, type=int, help="batch size")
+ parser.add_argument(
+ "-j", "--num_workers", default=4, type=int, help="reader worker number")
+ parser.add_argument(
+ "-p", "--pretrain_weights", default=None, type=str,
+ help="path to pretrained weights")
+ parser.add_argument(
+ "-r", "--resume", default=None, type=str,
+ help="path to model weights")
+ parser.add_argument(
+ "-w", "--weights", default=None, type=str,
+ help="path to weights for evaluation")
+ FLAGS = parser.parse_args()
+ assert FLAGS.data, "error: must provide data path"
+ main()
diff --git a/yolov3/transforms.py b/yolov3/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5fbe46cbbfdb39efe3025a351b407b82dbf33c4
--- /dev/null
+++ b/yolov3/transforms.py
@@ -0,0 +1,620 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import traceback
+import numpy as np
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['ColorDistort', 'RandomExpand', 'RandomCrop', 'RandomFlip',
+ 'NormalizeBox', 'PadBox', 'RandomShape', 'NormalizeImage',
+ 'BboxXYXY2XYWH', 'ResizeImage', 'Compose', 'BatchCompose']
+
+
+class Compose(object):
+ def __init__(self, transforms=[]):
+ self.transforms = transforms
+
+ def __call__(self, *data):
+ for f in self.transforms:
+ try:
+ data = f(*data)
+ except Exception as e:
+ stack_info = traceback.format_exc()
+ logger.info("fail to perform transform [{}] with error: "
+ "{} and stack:\n{}".format(f, e, str(stack_info)))
+ raise e
+ return data
+
+
+class BatchCompose(object):
+ def __init__(self, transforms=[]):
+ self.transforms = transforms
+
+ def __call__(self, data):
+ for f in self.transforms:
+ try:
+ data = f(data)
+ except Exception as e:
+ stack_info = traceback.format_exc()
+ logger.info("fail to perform batch transform [{}] with error: "
+ "{} and stack:\n{}".format(f, e, str(stack_info)))
+ raise e
+
+ # sample list to batch data
+ batch = list(zip(*data))
+
+ return batch
+
+
+class ColorDistort(object):
+ """Random color distortion.
+
+ Args:
+ hue (list): hue settings.
+ in [lower, upper, probability] format.
+ saturation (list): saturation settings.
+ in [lower, upper, probability] format.
+ contrast (list): contrast settings.
+ in [lower, upper, probability] format.
+ brightness (list): brightness settings.
+ in [lower, upper, probability] format.
+ random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
+ order.
+ """
+
+ def __init__(self,
+ hue=[-18, 18, 0.5],
+ saturation=[0.5, 1.5, 0.5],
+ contrast=[0.5, 1.5, 0.5],
+ brightness=[0.5, 1.5, 0.5],
+ random_apply=True):
+ self.hue = hue
+ self.saturation = saturation
+ self.contrast = contrast
+ self.brightness = brightness
+ self.random_apply = random_apply
+
+ def apply_hue(self, img):
+ low, high, prob = self.hue
+ if np.random.uniform(0., 1.) < prob:
+ return img
+
+ img = img.astype(np.float32)
+
+ # XXX works, but result differ from HSV version
+ delta = np.random.uniform(low, high)
+ u = np.cos(delta * np.pi)
+ w = np.sin(delta * np.pi)
+ bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+ tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
+ [0.211, -0.523, 0.311]])
+ ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
+ [1.0, -1.107, 1.705]])
+ t = np.dot(np.dot(ityiq, bt), tyiq).T
+ img = np.dot(img, t)
+ return img
+
+ def apply_saturation(self, img):
+ low, high, prob = self.saturation
+ if np.random.uniform(0., 1.) < prob:
+ return img
+ delta = np.random.uniform(low, high)
+
+ img = img.astype(np.float32)
+ gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+ gray = gray.sum(axis=2, keepdims=True)
+ gray *= (1.0 - delta)
+ img *= delta
+ img += gray
+ return img
+
+ def apply_contrast(self, img):
+ low, high, prob = self.contrast
+ if np.random.uniform(0., 1.) < prob:
+ return img
+ delta = np.random.uniform(low, high)
+
+ img = img.astype(np.float32)
+ img *= delta
+ return img
+
+ def apply_brightness(self, img):
+ low, high, prob = self.brightness
+ if np.random.uniform(0., 1.) < prob:
+ return img
+ delta = np.random.uniform(low, high)
+
+ img = img.astype(np.float32)
+ img += delta
+ return img
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ if self.random_apply:
+ distortions = np.random.permutation([
+ self.apply_brightness, self.apply_contrast,
+ self.apply_saturation, self.apply_hue
+ ])
+ for func in distortions:
+ im = func(im)
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ im = self.apply_brightness(im)
+
+ if np.random.randint(0, 2):
+ im = self.apply_contrast(im)
+ im = self.apply_saturation(im)
+ im = self.apply_hue(im)
+ else:
+ im = self.apply_saturation(im)
+ im = self.apply_hue(im)
+ im = self.apply_contrast(im)
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class RandomExpand(object):
+ """Random expand the canvas.
+
+ Args:
+ ratio (float): maximum expansion ratio.
+ prob (float): probability to expand.
+ fill_value (list): color value used to fill the canvas. in RGB order.
+ """
+
+ def __init__(self, ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53]):
+ assert ratio > 1.01, "expand ratio must be larger than 1.01"
+ self.ratio = ratio
+ self.prob = prob
+ self.fill_value = fill_value
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ if np.random.uniform(0., 1.) < self.prob:
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ height, width, _ = im.shape
+ expand_ratio = np.random.uniform(1., self.ratio)
+ h = int(height * expand_ratio)
+ w = int(width * expand_ratio)
+ if not h > height or not w > width:
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+ y = np.random.randint(0, h - height)
+ x = np.random.randint(0, w - width)
+ canvas = np.ones((h, w, 3), dtype=np.uint8)
+ canvas *= np.array(self.fill_value, dtype=np.uint8)
+ canvas[y:y + height, x:x + width, :] = im.astype(np.uint8)
+
+ gt_bbox += np.array([x, y, x, y], dtype=np.float32)
+
+ return [im_info, canvas, gt_bbox, gt_class, gt_score]
+
+
+class RandomCrop():
+ """Random crop image and bboxes.
+
+ Args:
+ aspect_ratio (list): aspect ratio of cropped region.
+ in [min, max] format.
+ thresholds (list): iou thresholds for decide a valid bbox crop.
+ scaling (list): ratio between a cropped region and the original image.
+ in [min, max] format.
+ num_attempts (int): number of tries before giving up.
+ allow_no_crop (bool): allow return without actually cropping them.
+ cover_all_box (bool): ensure all bboxes are covered in the final crop.
+ """
+
+ def __init__(self,
+ aspect_ratio=[.5, 2.],
+ thresholds=[.0, .1, .3, .5, .7, .9],
+ scaling=[.3, 1.],
+ num_attempts=50,
+ allow_no_crop=True,
+ cover_all_box=False):
+ self.aspect_ratio = aspect_ratio
+ self.thresholds = thresholds
+ self.scaling = scaling
+ self.num_attempts = num_attempts
+ self.allow_no_crop = allow_no_crop
+ self.cover_all_box = cover_all_box
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ if len(gt_bbox) == 0:
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ # NOTE Original method attempts to generate one candidate for each
+ # threshold then randomly sample one from the resulting list.
+ # Here a short circuit approach is taken, i.e., randomly choose a
+ # threshold and attempt to find a valid crop, and simply return the
+ # first one found.
+ # The probability is not exactly the same, kinda resembling the
+ # "Monty Hall" problem. Actually carrying out the attempts will affect
+ # observability (just like opening doors in the "Monty Hall" game).
+ thresholds = list(self.thresholds)
+ if self.allow_no_crop:
+ thresholds.append('no_crop')
+ np.random.shuffle(thresholds)
+
+ for thresh in thresholds:
+ if thresh == 'no_crop':
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ h, w, _ = im.shape
+ found = False
+ for i in range(self.num_attempts):
+ scale = np.random.uniform(*self.scaling)
+ min_ar, max_ar = self.aspect_ratio
+ aspect_ratio = np.random.uniform(
+ max(min_ar, scale**2), min(max_ar, scale**-2))
+ crop_h = int(h * scale / np.sqrt(aspect_ratio))
+ crop_w = int(w * scale * np.sqrt(aspect_ratio))
+ crop_y = np.random.randint(0, h - crop_h)
+ crop_x = np.random.randint(0, w - crop_w)
+ crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
+ iou = self._iou_matrix(
+ gt_bbox, np.array(
+ [crop_box], dtype=np.float32))
+ if iou.max() < thresh:
+ continue
+
+ if self.cover_all_box and iou.min() < thresh:
+ continue
+
+ cropped_box, valid_ids = self._crop_box_with_center_constraint(
+ gt_bbox, np.array(
+ crop_box, dtype=np.float32))
+ if valid_ids.size > 0:
+ found = True
+ break
+
+ if found:
+ im = self._crop_image(im, crop_box)
+ gt_bbox = np.take(cropped_box, valid_ids, axis=0)
+ gt_class = np.take(gt_class, valid_ids, axis=0)
+ gt_score = np.take(gt_score, valid_ids, axis=0)
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+ def _iou_matrix(self, a, b):
+ tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ area_o = (area_a[:, np.newaxis] + area_b - area_i)
+ return area_i / (area_o + 1e-10)
+
+ def _crop_box_with_center_constraint(self, box, crop):
+ cropped_box = box.copy()
+
+ cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
+ cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
+ cropped_box[:, :2] -= crop[:2]
+ cropped_box[:, 2:] -= crop[:2]
+
+ centers = (box[:, :2] + box[:, 2:]) / 2
+ valid = np.logical_and(crop[:2] <= centers,
+ centers < crop[2:]).all(axis=1)
+ valid = np.logical_and(
+ valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+
+ return cropped_box, np.where(valid)[0]
+
+ def _crop_image(self, img, crop):
+ x1, y1, x2, y2 = crop
+ return img[y1:y2, x1:x2, :]
+
+
+class RandomFlip():
+ def __init__(self, prob=0.5, is_normalized=False):
+ """
+ Args:
+ prob (float): the probability of flipping image
+ is_normalized (bool): whether the bbox scale to [0,1]
+ """
+ self.prob = prob
+ self.is_normalized = is_normalized
+ if not (isinstance(self.prob, float) and
+ isinstance(self.is_normalized, bool)):
+ raise TypeError("{}: input type is invalid.".format(self))
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ """Filp the image and bounding box.
+ Operators:
+ 1. Flip the image numpy.
+ 2. Transform the bboxes' x coordinates.
+ (Must judge whether the coordinates are normalized!)
+ """
+
+ if not isinstance(im, np.ndarray):
+ raise TypeError("{}: image is not a numpy array.".format(self))
+ if len(im.shape) != 3:
+ raise ValueError("{}: image is not 3-dimensional.".format(self))
+ height, width, _ = im.shape
+ if np.random.uniform(0, 1) < self.prob:
+ im = im[:, ::-1, :]
+ if gt_bbox.shape[0] > 0:
+ oldx1 = gt_bbox[:, 0].copy()
+ oldx2 = gt_bbox[:, 2].copy()
+ if self.is_normalized:
+ gt_bbox[:, 0] = 1 - oldx2
+ gt_bbox[:, 2] = 1 - oldx1
+ else:
+ gt_bbox[:, 0] = width - oldx2 - 1
+ gt_bbox[:, 2] = width - oldx1 - 1
+ if gt_bbox.shape[0] != 0 and (
+ gt_bbox[:, 2] < gt_bbox[:, 0]).all():
+ m = "{}: invalid box, x2 should be greater than x1".format(
+ self)
+ raise ValueError(m)
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class NormalizeBox(object):
+ """Transform the bounding box's coordinates to [0,1]."""
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ height, width, _ = im.shape
+ for i in range(gt_bbox.shape[0]):
+ gt_bbox[i][0] = gt_bbox[i][0] / width
+ gt_bbox[i][1] = gt_bbox[i][1] / height
+ gt_bbox[i][2] = gt_bbox[i][2] / width
+ gt_bbox[i][3] = gt_bbox[i][3] / height
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class PadBox(object):
+ def __init__(self, num_max_boxes=50):
+ """
+ Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
+ Args:
+ num_max_boxes (int): the max number of bboxes
+ """
+ self.num_max_boxes = num_max_boxes
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ gt_num = min(self.num_max_boxes, len(gt_bbox))
+ num_max = self.num_max_boxes
+
+ pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
+ if gt_num > 0:
+ pad_bbox[:gt_num, :] = gt_bbox[:gt_num, :]
+ gt_bbox = pad_bbox
+
+ pad_class = np.zeros((num_max), dtype=np.int32)
+ if gt_num > 0:
+ pad_class[:gt_num] = gt_class[:gt_num, 0]
+ gt_class = pad_class
+
+ pad_score = np.zeros((num_max), dtype=np.float32)
+ if gt_num > 0:
+ pad_score[:gt_num] = gt_score[:gt_num, 0]
+ gt_score = pad_score
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class BboxXYXY2XYWH(object):
+ """
+ Convert bbox XYXY format to XYWH format.
+ """
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
+ gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class RandomShape(object):
+ """
+ Randomly reshape a batch. If random_inter is True, an interpolation
+ algorithm is also randomly selected from [cv2.INTER_NEAREST, cv2.INTER_LINEAR,
+ cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is
+ False, cv2.INTER_NEAREST is used.
+
+ Args:
+ sizes (list): list of int, a size is randomly chosen from these
+ random_inter (bool): whether to randomly select the interpolation method, default True.
+ """
+
+ def __init__(self,
+ sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
+ random_inter=True):
+ self.sizes = sizes
+ self.random_inter = random_inter
+ self.interps = [
+ cv2.INTER_NEAREST,
+ cv2.INTER_LINEAR,
+ cv2.INTER_AREA,
+ cv2.INTER_CUBIC,
+ cv2.INTER_LANCZOS4,
+ ] if random_inter else []
+
+ def __call__(self, samples):
+ shape = np.random.choice(self.sizes)
+ method = np.random.choice(self.interps) if self.random_inter \
+ else cv2.INTER_NEAREST
+ for i in range(len(samples)):
+ im = samples[i][1]
+ h, w = im.shape[:2]
+ scale_x = float(shape) / w
+ scale_y = float(shape) / h
+ im = cv2.resize(
+ im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
+ samples[i][1] = im
+ return samples
+
+
+class NormalizeImage(object):
+ def __init__(self,
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ scale=True,
+ channel_first=True):
+ """
+ Args:
+ mean (list): the pixel mean
+ std (list): the pixel standard deviation
+ scale (bool): whether to scale the image to [0, 1]
+ channel_first (bool): whether to change [h, w, c] to [c, h, w]
+ """
+ self.mean = mean
+ self.std = std
+ self.scale = scale
+ self.channel_first = channel_first
+ if not (isinstance(self.mean, list) and isinstance(self.std, list) and
+ isinstance(self.scale, bool)):
+ raise TypeError("{}: input type is invalid.".format(self))
+ from functools import reduce
+ if reduce(lambda x, y: x * y, self.std) == 0:
+ raise ValueError('{}: std is invalid!'.format(self))
+
+ def __call__(self, samples):
+ """Normalize the image.
+ Operators:
+ 1. (optional) Scale the image to [0,1]
+ 2. Each pixel minus mean and is divided by std
+ 3. (optional) permute channel
+ """
+ for i in range(len(samples)):
+ im = samples[i][1]
+ im = im.astype(np.float32, copy=False)
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ if self.scale:
+ im = im / 255.0
+ im -= mean
+ im /= std
+ if self.channel_first:
+ im = im.transpose((2, 0, 1))
+ samples[i][1] = im
+ return samples
+
+
+def _iou_matrix(a, b):
+ tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+ area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ area_o = (area_a[:, np.newaxis] + area_b - area_i)
+ return area_i / (area_o + 1e-10)
+
+
+def _crop_box_with_center_constraint(box, crop):
+ cropped_box = box.copy()
+ cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
+ cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
+ cropped_box[:, :2] -= crop[:2]
+ cropped_box[:, 2:] -= crop[:2]
+ centers = (box[:, :2] + box[:, 2:]) / 2
+ valid = np.logical_and(
+ crop[:2] <= centers, centers < crop[2:]).all(axis=1)
+ valid = np.logical_and(
+ valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+ return cropped_box, np.where(valid)[0]
+
+
+def random_crop(inputs):
+ aspect_ratios = [.5, 2.]
+ thresholds = [.0, .1, .3, .5, .7, .9]
+ scaling = [.3, 1.]
+
+ img, img_ids, gt_box, gt_label = inputs
+ h, w = img.shape[:2]
+
+ if len(gt_box) == 0:
+ return inputs
+
+ np.random.shuffle(thresholds)
+ for thresh in thresholds:
+ found = False
+ for i in range(50):
+ scale = np.random.uniform(*scaling)
+ min_ar, max_ar = aspect_ratios
+ ar = np.random.uniform(max(min_ar, scale**2),
+ min(max_ar, scale**-2))
+ crop_h = int(h * scale / np.sqrt(ar))
+ crop_w = int(w * scale * np.sqrt(ar))
+ crop_y = np.random.randint(0, h - crop_h)
+ crop_x = np.random.randint(0, w - crop_w)
+ crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
+ iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32))
+ if iou.max() < thresh:
+ continue
+
+ cropped_box, valid_ids = _crop_box_with_center_constraint(
+ gt_box, np.array(crop_box, dtype=np.float32))
+ if valid_ids.size > 0:
+ found = True
+ break
+
+ if found:
+ x1, y1, x2, y2 = crop_box
+ img = img[y1:y2, x1:x2, :]
+ gt_box = np.take(cropped_box, valid_ids, axis=0)
+ gt_label = np.take(gt_label, valid_ids, axis=0)
+ return img, img_ids, gt_box, gt_label
+
+ return inputs
+
+
+class ResizeImage(object):
+ def __init__(self,
+ target_size=0,
+ interp=cv2.INTER_CUBIC):
+ """
+ Rescale image to the specified target size.
+ If target_size is a list, a scale is randomly selected as the
+ target size.
+
+ Args:
+ target_size (int|list): the target size of image's short side,
+ multi-scale training is adopted when type is list.
+ interp (int): the interpolation method
+ """
+ self.interp = int(interp)
+ if not (isinstance(target_size, int) or isinstance(target_size, list)):
+ raise TypeError(
+ "Type of target_size is invalid. Must be Integer or List, now is {}".
+ format(type(target_size)))
+ self.target_size = target_size
+
+ def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+ """ Resize the image numpy.
+ """
+ if not isinstance(im, np.ndarray):
+ raise TypeError("{}: image type is not numpy.".format(self))
+ if len(im.shape) != 3:
+ raise ValueError('{}: image is not 3-dimensional.'.format(self))
+ im_shape = im.shape
+ im_scale_x = float(self.target_size) / float(im_shape[1])
+ im_scale_y = float(self.target_size) / float(im_shape[0])
+ resize_w = self.target_size
+ resize_h = self.target_size
+
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+
+ return [im_info, im, gt_bbox, gt_class, gt_score]
+
diff --git a/yolov3/visualizer.py b/yolov3/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4433df8606ec140fe08f197e445aea6df89bf445
--- /dev/null
+++ b/yolov3/visualizer.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image, ImageDraw
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['draw_bbox']
+
+
+def color_map(num_classes):
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = np.array(color_map).reshape(-1, 3)
+ return color_map
+
+
+def draw_bbox(image, catid2name, bboxes, threshold):
+ """
+ Draw bbox on image
+ """
+ bboxes = np.array(bboxes)
+ if bboxes.shape[1] != 6:
+ logger.info("No bbox detect")
+ return image
+
+ draw = ImageDraw.Draw(image)
+
+ catid2color = {}
+ color_list = color_map(len(catid2name))
+ for bbox in bboxes:
+ catid, score, xmin, ymin, xmax, ymax = bbox
+
+ if score < threshold:
+ continue
+
+ if catid not in catid2color:
+ idx = np.random.randint(len(color_list))
+ catid2color[catid] = color_list[idx]
+ color = tuple(catid2color[catid])
+
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=2,
+ fill=color)
+ logger.info("detect {} at {} score: {:.2f}".format(
+ catid2name[int(catid)], [xmin, ymin, xmax, ymax], score))
+
+ # draw label
+ text = "{} {:.2f}".format(catid2name[catid], score)
+ tw, th = draw.textsize(text)
+ draw.rectangle(
+ [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
+ draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+
+ return image