提交 fe885ccc 编写于 作者: X xuezhong

Merge branch 'develop' of https://github.com/PaddlePaddle/models into add_ce

[submodule "fluid/SimNet"]
path = fluid/SimNet
[submodule "fluid/PaddleNLP/SimNet"]
path = fluid/PaddleNLP/SimNet
url = https://github.com/baidu/AnyQ.git
[submodule "fluid/LAC"]
path = fluid/LAC
url = https://github.com/baidu/lac
[submodule "fluid/Senta"]
path = fluid/Senta
url = https://github.com/baidu/Senta
[submodule "fluid/PaddleNLP/LAC"]
path = fluid/PaddleNLP/LAC
url = https://github.com/baidu/lac.git
[submodule "fluid/PaddleNLP/Senta"]
path = fluid/PaddleNLP/Senta
url = https://github.com/baidu/Senta.git
Subproject commit 66660503bb6e8f34adc4715ccf42cad77ed46ded
运行本目录下的程序示例需要使用 PaddlePaddle 最新的 develop branch 版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新 PaddlePaddle 安装版本。
---
## Pyramidbox 人脸检测
## Table of Contents
......
......@@ -7,7 +7,6 @@
- [Introduction](#introduction)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Finetuning](#finetuning)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
- [Appendix](#appendix)
......@@ -24,10 +23,10 @@ Running sample code in this directory requires PaddelPaddle Fluid v.1.0.0 and la
Faster RCNN model
</p>
1. Base conv layerAs a CNN objective dection, Faster RCNN extract feature maps using a basic convolutional network. The feature maps then can be shared by RPN and fc layers. This sampel uses [ResNet-50](https://arxiv.org/abs/1512.03385) as base conv layer.
2. Region Proposal Network (RPN)RPN generates proposals for detection。This block generates anchors by a set of size and ratio and classifies anchors into fore-ground and back-ground by softmax. Then refine anchors to obtain more precise proposals using box regression.
3. RoI pooling。This layer takes feature maps and proposals as input. The proposals are mapped to feature maps and pooled to the same size. The output are sent to fc layers for classification and regression.
4. Detection layerUsing the output of roi pooling to compute the class and locatoin of each proposal in two fc layers.
1. Base conv layer. As a CNN objective dection, Faster RCNN extract feature maps using a basic convolutional network. The feature maps then can be shared by RPN and fc layers. This sampel uses [ResNet-50](https://arxiv.org/abs/1512.03385) as base conv layer.
2. Region Proposal Network (RPN). RPN generates proposals for detection。This block generates anchors by a set of size and ratio and classifies anchors into fore-ground and back-ground by softmax. Then refine anchors to obtain more precise proposals using box regression.
3. RoI Align. This layer takes feature maps and proposals as input. The proposals are mapped to feature maps and pooled to the same size. The output are sent to fc layers for classification and regression. RoIPool and RoIAlign are used separately to this layer and it can be set in roi\_func in config.py.
4. Detection layer. Using the output of roi pooling to compute the class and locatoin of each proposal in two fc layers.
## Data preparation
......@@ -42,10 +41,9 @@ Train the model on [MS-COCO dataset](http://cocodataset.org/#download), download
After data preparation, one can start the training step by:
python train.py \
--max_size=1333 \
--scales=[800] \
--batch_size=8 \
--model_save_dir=output/
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model}
--data_dir=${path_to_data}
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specifiy 8 GPU to train.
- For more help on arguments:
......@@ -83,7 +81,7 @@ To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed.
**model configuration:**
* Use RoIPooling.
* Use RoIAlign and RoIPool separately.
* NMS threshold=0.7. During training, pre\_nms=12000, post\_nms=2000; during test, pre\_nms=6000, post\_nms=1000.
* In generating proposal lables, fg\_fraction=0.25, fg\_thresh=0.5, bg\_thresh_hi=0.5, bg\_thresh\_lo=0.0.
* In rpn target assignment, rpn\_fg\_fraction=0.5, rpn\_positive\_overlap=0.7, rpn\_negative\_overlap=0.3.
......@@ -102,20 +100,10 @@ Training result is shown as below:
<img src="image/train_loss.jpg" height=500 width=650 hspace='10'/> <br />
Faster RCNN train loss
</p>
* Fluid all padding: Each image padding to 1333\*1333.
* Fluid minibatch padding: Images in one batch padding to the same size. This method is same as detectron.
* Fluid no padding: Images without padding.
## Finetuning
Finetuning is to finetune model weights in a specific task by loading pretrained weights. After initializing ```pretrained_model```, one can finetune a model as:
python train.py
--max_size=1333 \
--scales=800 \
--pretrained_model=${path_to_pretrain_model} \
--batch_size= 8\
--model_save_dir=output/
* Fluid RoIPool minibatch padding: Use RoIPool. Images in one batch padding to the same size. This method is same as detectron.
* Fluid RoIpool no padding: Use RoIPool. Images without padding.
* Fluid RoIAlign no padding: Use RoIAlign. Images without padding.
## Evaluation
......@@ -125,10 +113,9 @@ Evaluation is to evaluate the performance of a trained model. This sample provid
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_mode=${path_to_pretrain_model} \
--batch_size=1 \
--nms_threshold=0.5 \
--score_threshold=0.05
--pretrained_model=${path_to_pretrain_model} \
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specifiy one GPU to eval.
Evalutaion result is shown as below:
<p align="center">
......@@ -136,16 +123,17 @@ Evalutaion result is shown as below:
Faster RCNN mAP
</p>
| Model | Batch size | Max iteration | mAP |
| :------------------------------ | :------------: | :-------------------:|------: |
| Detectron | 8 | 180000 | 0.315 |
| Fluid minibatch padding | 8 | 180000 | 0.314 |
| Fluid all padding | 8 | 180000 | 0.308 |
| Fluid no padding |8 | 180000 | 0.316 |
* Fluid all padding: Each image padding to 1333\*1333.
* Fluid minibatch padding: Images in one batch padding to the same size. This method is same as detectron.
* Fluid no padding: Images without padding.
| Model | RoI function | Batch size | Max iteration | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| Detectron_RoIPool | RoIPool | 8 | 180000 | 0.315 |
| Fluid RoIPool minibatch padding | RoIPool | 8 | 180000 | 0.314 |
| Fluid RoIPool no padding | RoIPool | 8 | 180000 | 0.316 |
| Detectron_RoIAlign | RoIAlign | 8 | 180000 | 0.346 |
| Fluid RoIAlign no padding | RoIAlign | 8 | 180000 | 0.345 |
* Fluid RoIPool minibatch padding: Use RoIPool. Images in one batch padding to the same size. This method is same as detectron.
* Fluid RoIPool no padding: Images without padding.
* Fluid RoIAlign no padding: Images without padding.
## Inference and Visualization
......
......@@ -7,7 +7,6 @@
- [简介](#简介)
- [数据准备](#数据准备)
- [模型训练](#模型训练)
- [参数微调](#参数微调)
- [模型评估](#模型评估)
- [模型推断及可视化](#模型推断及可视化)
- [附录](#附录)
......@@ -26,7 +25,7 @@ Faster RCNN 目标检测模型
1. 基础卷积层。作为一种卷积神经网络目标检测方法,Faster RCNN首先使用一组基础的卷积网络提取图像的特征图。特征图被后续RPN层和全连接层共享。本示例采用[ResNet-50](https://arxiv.org/abs/1512.03385)作为基础卷积层。
2. 区域生成网络(RPN)。RPN网络用于生成候选区域(proposals)。该层通过一组固定的尺寸和比例得到一组锚点(anchors), 通过softmax判断锚点属于前景或者背景,再利用区域回归修正锚点从而获得精确的候选区域。
3. RoI池化。该层收集输入的特征图和候选区域,将候选区域映射到特征图中并池化为统一大小的区域特征图,送入全连接层判定目标类别
3. RoI Align。该层收集输入的特征图和候选区域,将候选区域映射到特征图中并池化为统一大小的区域特征图,送入全连接层判定目标类别, 该层可选用RoIPool和RoIAlign两种方式,在config.py中设置roi\_func
4. 检测层。利用区域特征图计算候选区域的类别,同时再次通过区域回归获得检测框最终的精确位置。
## 数据准备
......@@ -41,11 +40,9 @@ Faster RCNN 目标检测模型
数据准备完毕后,可以通过如下的方式启动训练:
python train.py \
--max_size=1333 \
--scales=[800] \
--batch_size=8 \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model}
--data_dir=${path_to_data}
- 通过设置export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7指定8卡GPU训练。
- 可选参数见:
......@@ -74,11 +71,11 @@ Faster RCNN 目标检测模型
# not to install the COCO API into global site-packages
python2 setup.py install --user
**数据读取器说明:** 数据读取器定义在reader.py中。所有图像将短边等比例缩放至`scales`,若长边大于`max_size`, 则再次将长边等比例缩放至`max_iter`。在训练阶段,对图像采用水平翻转。支持将同一个batch内的图像padding为相同尺寸。
**数据读取器说明:** 数据读取器定义在reader.py中。所有图像将短边等比例缩放至`scales`,若长边大于`max_size`, 则再次将长边等比例缩放至`max_size`。在训练阶段,对图像采用水平翻转。支持将同一个batch内的图像padding为相同尺寸。
**模型设置:**
* 使用RoIPooling
* 分别使用RoIAlign和RoIPool两种方法
* 训练过程pre\_nms=12000, post\_nms=2000,测试过程pre\_nms=6000, post\_nms=1000。nms阈值为0.7。
* RPN网络得到labels的过程中,fg\_fraction=0.25,fg\_thresh=0.5,bg\_thresh_hi=0.5,bg\_thresh\_lo=0.0
* RPN选择anchor时,rpn\_fg\_fraction=0.5,rpn\_positive\_overlap=0.7,rpn\_negative\_overlap=0.3
......@@ -89,9 +86,10 @@ Faster RCNN 目标检测模型
<img src="image/train_loss.jpg" height=500 width=650 hspace='10'/> <br />
Faster RCNN 训练loss
</p>
* Fluid all padding: 每张图像填充为1333\*1333大小。
* Fluid minibatch padding: 同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同。
* Fluid no padding: 不对图像做填充处理。
* Fluid RoIPool minibatch padding: 使用RoIPool,同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同。
* Fluid RoIPool no padding: 使用RoIPool,不对图像做填充处理。
* Fluid RoIAlign no padding: 使用RoIAlign,不对图像做填充处理。
**训练策略:**
......@@ -109,10 +107,9 @@ Faster RCNN 训练loss
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_mode=${path_to_pretrain_model} \
--batch_size=1 \
--nms_threshold=0.5 \
--score_threshold=0.05
--pretrained_model=${path_to_pretrain_model} \
- 通过设置export CUDA\_VISIBLE\_DEVICES=0指定单卡GPU评估。
下图为模型评估结果:
<p align="center">
......@@ -120,16 +117,20 @@ Faster RCNN 训练loss
Faster RCNN mAP
</p>
| 模型 | 批量大小 | 迭代次数 | mAP |
| :------------------------------ | :------------: | :------------------: |------: |
| Detectron | 8 | 180000 | 0.315 |
| Fluid minibatch padding | 8 | 180000 | 0.314 |
| Fluid all padding | 8 | 180000 | 0.308 |
| Fluid no padding |8 | 180000 | 0.316 |
| 模型 | RoI处理方式 | 批量大小 | 迭代次数 | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| Detectron RoIPool | RoIPool | 8 | 180000 | 0.315 |
| Fluid RoIPool minibatch padding | RoIPool | 8 | 180000 | 0.314 |
| Fluid RoIPool no padding | RoIPool | 8 | 180000 | 0.316 |
| Detectron RoIAlign | RoIAlign | 8 | 180000 | 0.346 |
| Fluid RoIAlign no padding | RoIAlign | 8 | 180000 | 0.345 |
* Fluid all padding: 每张图像填充为1333\*1333大小
* Fluid minibatch padding: 同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同
* Fluid no padding: 不对图像做填充处理。
* Fluid RoIPool minibatch padding: 使用RoIPool,同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同
* Fluid RoIPool no padding: 使用RoIPool,不对图像做填充处理
* Fluid RoIAlign no padding: 使用RoIAlign,不对图像做填充处理。
## 模型推断及可视化
......
......@@ -52,7 +52,7 @@ def train():
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
step_num = len(lr_steps)
step_num = len(cfg.lr_steps)
values = [learning_rate * (gamma**i) for i in range(step_num + 1)]
optimizer = fluid.optimizer.Momentum(
......
......@@ -102,6 +102,7 @@ def coco(mode,
roidb_perm.rotate(-1)
if roidb_cur >= len(roidbs):
roidb_perm = deque(np.random.permutation(roidbs))
roidb_cur = 0
im, gt_boxes, gt_classes, is_crowd, im_info, im_id = roidb_reader(
roidb, mode)
if gt_boxes.shape[0] == 0:
......
......@@ -9,7 +9,7 @@ Before getting started, please make sure you have go throught the imagenet [Data
1. The entrypoint file is `dist_train.py`, some important flags are as follows:
- `model`, the model to run with, such as `ResNet50`, `ResNet101` and etc..
- `model`, the model to run with, default is the fine tune model `DistResnet`.
- `batch_size`, the batch_size per device.
- `update_method`, specify the update method, can choose from local, pserver or nccl2.
- `device`, use CPU or GPU device.
......@@ -35,14 +35,14 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
1. launch parameter server process
``` python
``` bash
PADDLE_TRAINING_ROLE=PSERVER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_CURRENT_IP=192.168.0.100 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
--model=ResNet50 \
--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=CPU \
......@@ -51,34 +51,33 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
1. launch trainer process
``` python
``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
--model=ResNet50 \
--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=GPU \
--data_dir=../data/ILSVRC2012
```
### NCCL2 Collective Mode
1. launch trainer process
``` python
``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
python dist_train.py \
--model=ResNet50 \
--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--update_method=nccl2 \
--device=GPU \
--data_dir=../data/ILSVRC2012
```
......@@ -101,13 +100,37 @@ Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
```
The training accucacys top1 of local training, distributed training with NCCL2 and parameter server architecture on the ResNet50 model are shown in the below figure:
The below figure shows top 1 train accuracy for local training with 8 GPUs and distributed training
with 32 GPUs, and also distributed training with batch merge feature turned on. Note that the
red curve is trained with origin model configuration, which does not have the warmup and some detailed
modifications.
For distributed training with 32GPUs using `--model DistResnet` we can achieve test accuracy 75.5% after
90 passes of training (the test accuracy is not shown in below figure). We can also achieve this result
using "batch merge" feature by setting `--multi_batch_repeat 4` and with higher throughput.
<p align="center">
<img src="../images/resnet50_32gpus-acc1.png" height=300 width=528 > <br/>
Training acc1 curves
Training top-1 accuracy curves
</p>
### Finetuning for Distributed Training
The default resnet50 distributed training config is based on this paper: https://arxiv.org/pdf/1706.02677.pdf
- use `--model DistResnet`
- we use 32 P40 GPUs with 4 Nodes, each has 8 GPUs
- we set `batch_size=32` for each GPU, in `batch_merge=on` case, we repeat 4 times before communicating with pserver.
- learning rate starts from 0.1 and warm up to 0.4 in 5 passes(because we already have gradient merging,
so we only need to linear scale up to trainer count) using 4 nodes.
- using batch_merge (`--multi_batch_repeat 4`) can make better use of GPU computing power and increase the
total training throughput. Because in the fine-tune configuration, we have to use `batch_size=32` per GPU,
and recent GPU is so fast that the communication between nodes may delay the total speed. In batch_merge mode
we run several batches forward and backward computation, then merge the gradients and send to pserver for
optimization, we use different batch norm mean and variance variable in each repeat so that adding repeats
behaves the same as adding more GPUs.
### Performance
TBD
......@@ -26,9 +26,84 @@ import six
import sys
sys.path.append("..")
import models
from args import *
from reader import train, val
def parse_args():
parser = argparse.ArgumentParser('Distributed Image Classification Training.')
parser.add_argument(
'--model',
type=str,
default='DistResNet',
help='The model to run.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size per device.')
parser.add_argument(
'--multi_batch_repeat', type=int, default=1, help='Batch merge repeats.')
parser.add_argument(
'--learning_rate', type=float, default=0.1, help='The learning rate.')
parser.add_argument(
'--pass_num', type=int, default=90, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--gpus',
type=int,
default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
parser.add_argument(
'--cpus',
type=int,
default=1,
help='If cpus > 1, will set ParallelExecutor to use multiple threads.')
parser.add_argument(
'--no_test',
action='store_true',
help='If set, do not test the testset during training.')
parser.add_argument(
'--memory_optimize',
action='store_true',
help='If set, optimize runtime memory before start.')
parser.add_argument(
'--update_method',
type=str,
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--reduce_strategy',
type=str,
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--data_dir',
type=str,
default="../data/ILSVRC2012",
help="The ImageNet dataset root dir."
)
args = parser.parse_args()
return args
def get_model(args, is_train, main_prog, startup_prog):
pyreader = None
class_dim = 1000
......@@ -51,7 +126,7 @@ def get_model(args, is_train, main_prog, startup_prog):
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
input, label = fluid.layers.read_file(pyreader)
model_def = models.__dict__[args.model]()
model_def = models.__dict__[args.model](layers=50, is_train=is_train)
predict = model_def.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label)
......@@ -60,89 +135,64 @@ def get_model(args, is_train, main_prog, startup_prog):
batch_acc1 = fluid.layers.accuracy(input=predict, label=label, k=1)
batch_acc5 = fluid.layers.accuracy(input=predict, label=label, k=5)
# configure optimize
optimizer = None
if is_train:
start_lr = args.learning_rate
# n * worker * repeat
end_lr = args.learning_rate * trainer_count * args.multi_batch_repeat
total_images = 1281167 / trainer_count
step = int(total_images / (args.batch_size * args.gpus) + 1)
epochs = [30, 60, 90]
step = int(total_images / (args.batch_size * args.gpus * args.multi_batch_repeat) + 1)
warmup_steps = step * 5 # warmup 5 passes
epochs = [30, 60, 80]
bd = [step * e for e in epochs]
base_lr = args.learning_rate
base_lr = end_lr
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
learning_rate=models.learning_rate.lr_warmup(
fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
warmup_steps, start_lr, end_lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
if args.memory_optimize:
fluid.memory_optimize(main_prog)
batched_reader = None
pyreader.decorate_paddle_reader(
paddle.batch(
reader if args.no_random else paddle.reader.shuffle(
reader, buf_size=5120),
reader,
batch_size=args.batch_size))
return avg_cost, optimizer, [batch_acc1,
batch_acc5], batched_reader, pyreader
def append_nccl2_prepare(trainer_id, startup_prog):
if trainer_id >= 0:
# append gen_nccl_id at the end of startup program
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
port = os.getenv("PADDLE_PSERVER_PORT")
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
num_trainers = len(worker_endpoints)
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
worker_endpoints.remove(current_endpoint)
nccl_id_var = startup_prog.global_block().create_var(
name="NCCLID",
persistable=True,
type=fluid.core.VarDesc.VarType.RAW)
startup_prog.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs={
"endpoint": current_endpoint,
"endpoint_list": worker_endpoints,
"trainer_id": trainer_id
})
return nccl_id_var, num_trainers, trainer_id
else:
raise Exception("must set positive PADDLE_TRAINER_ID env variables for "
"nccl-based dist train.")
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
port = os.getenv("PADDLE_PSERVER_PORT")
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, trainers=','.join(worker_endpoints),
current_endpoint=current_endpoint,
startup_program=startup_prog)
def dist_transpile(trainer_id, args, train_prog, startup_prog):
if trainer_id < 0:
return None, None
# the port of all pservers, needed by both trainer and pserver
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers = int(os.getenv("PADDLE_TRAINERS"))
# the IP of the local machine, needed by pserver only
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
config = fluid.DistributeTranspilerConfig()
......@@ -150,8 +200,6 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
# NOTE: *MUST* use train_prog, for we are using with guard to
# generate different program for train and test.
program=train_prog,
pservers=pserver_endpoints,
trainers=trainers,
......@@ -170,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def test_parallel(exe, test_args, args, test_prog, feeder):
def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
repeat_vars = set()
for op in main_prog.global_block().ops:
if op.type == "batch_norm":
repeat_vars.add(op.input("Mean")[0])
repeat_vars.add(op.input("Variance")[0])
for i in range(num_repeats):
for op in startup_prog.global_block().ops:
if op.type == "fill_constant":
for oname in op.output_arg_names:
if oname in repeat_vars:
var = startup_prog.global_block().var(oname)
repeat_var_name = "%s.repeat.%d" % (oname, i)
repeat_var = startup_prog.global_block().create_var(
name=repeat_var_name,
type=var.type,
dtype=var.dtype,
shape=var.shape,
persistable=var.persistable
)
main_prog.global_block()._clone_variable(repeat_var)
startup_prog.global_block().append_op(
type="fill_constant",
inputs={},
outputs={"Out": repeat_var},
attrs=op.all_attrs()
)
def copyback_repeat_bn_params(main_prog):
repeat_vars = set()
for op in main_prog.global_block().ops:
if op.type == "batch_norm":
repeat_vars.add(op.input("Mean")[0])
repeat_vars.add(op.input("Variance")[0])
for vname in repeat_vars:
real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
orig_var = fluid.global_scope().find_var(vname).get_tensor()
orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
def test_single(exe, test_args, args, test_prog):
acc_evaluators = []
for i in six.moves.xrange(len(test_args[2])):
for i in xrange(len(test_args[2])):
acc_evaluators.append(fluid.metrics.Accuracy())
to_fetch = [v.name for v in test_args[2]]
test_args[4].start()
while True:
try:
acc_rets = exe.run(fetch_list=to_fetch)
acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
for i, e in enumerate(acc_evaluators):
e.update(
value=np.array(acc_rets[i]), weight=args.batch_size)
......@@ -191,23 +280,28 @@ def test_parallel(exe, test_args, args, test_prog, feeder):
return [e.eval() for e in acc_evaluators]
# NOTE: only need to benchmark using parallelexe
def train_parallel(train_args, test_args, args, train_prog, test_prog,
startup_prog, nccl_id_var, num_trainers, trainer_id):
over_all_start = time.time()
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
feeder = None
if nccl_id_var and trainer_id == 0:
#FIXME(wuyi): wait other trainer to start listening
time.sleep(30)
startup_exe = fluid.Executor(place)
if args.multi_batch_repeat > 1:
append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
startup_exe.run(startup_prog)
strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.cpus
strategy.allow_op_delay = False
build_strategy = fluid.BuildStrategy()
if args.multi_batch_repeat > 1:
pass_builder = build_strategy._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.multi_batch_repeat)
if args.reduce_strategy == "reduce":
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.Reduce
......@@ -233,35 +327,21 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
num_trainers=num_trainers,
trainer_id=trainer_id)
if not args.no_test:
if args.update_method == "pserver":
test_scope = None
else:
# NOTE: use an empty scope to avoid test exe using NCCLID
test_scope = fluid.Scope()
test_exe = fluid.ParallelExecutor(
True, main_program=test_prog, share_vars_from=exe)
pyreader = train_args[4]
for pass_id in range(args.pass_num):
num_samples = 0
iters = 0
start_time = time.time()
batch_id = 0
pyreader.start()
while True:
if iters == args.iterations:
break
if iters == args.skip_batch_num:
start_time = time.time()
num_samples = 0
fetch_list = [avg_loss.name]
acc_name_list = [v.name for v in train_args[2]]
fetch_list.extend(acc_name_list)
try:
fetch_ret = exe.run(fetch_list)
if batch_id % 30 == 0:
fetch_ret = exe.run(fetch_list)
else:
fetch_ret = exe.run([])
except fluid.core.EOFException as eof:
break
except fluid.core.EnforceNotMet as ex:
......@@ -269,20 +349,19 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
break
num_samples += args.batch_size * args.gpus
iters += 1
if batch_id % 1 == 0:
if batch_id % 30 == 0:
fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
print("Pass %d, batch %d, loss %s, accucacys: %s" %
(pass_id, batch_id, fetched_data[0], fetched_data[1:]))
batch_id += 1
print_train_time(start_time, time.time(), num_samples)
pyreader.reset() # reset reader handle
pyreader.reset()
if not args.no_test and test_args[2]:
test_feeder = None
test_ret = test_parallel(test_exe, test_args, args, test_prog,
test_feeder)
if args.multi_batch_repeat > 1:
copyback_repeat_bn_params(train_prog)
test_ret = test_single(startup_exe, test_args, args, test_prog)
print("Pass: %d, Test Accuracy: %s\n" %
(pass_id, [np.mean(np.array(v)) for v in test_ret]))
......@@ -316,8 +395,6 @@ def main():
args = parse_args()
print_arguments(args)
print_paddle_envs()
if args.no_random:
fluid.default_startup_program().random_seed = 1
# the unique trainer id, starting from 0, needed by trainer
# only
......@@ -340,7 +417,7 @@ def main():
raise Exception(
"Must configure correct environments to run dist train.")
all_args.extend([train_prog, test_prog, startup_prog])
if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
all_args.extend([nccl_id_var, num_trainers, trainer_id])
train_parallel(*all_args)
elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
......
......@@ -3,6 +3,8 @@ from .mobilenet import MobileNet
from .googlenet import GoogleNet
from .vgg import VGG11, VGG13, VGG16, VGG19
from .resnet import ResNet50, ResNet101, ResNet152
from .resnet_dist import DistResNet
from .inception_v4 import InceptionV4
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_32x4d
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
import learning_rate
......@@ -20,3 +20,31 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
decayed_lr = learning_rate * \
(ops.cos(epoch * (math.pi / epochs)) + 1)/2
return decayed_lr
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
""" Applies linear learning rate warmup for distributed training
Argument learning_rate can be float or a Variable
lr = lr + (warmup_rate * step / warmup_steps)
"""
assert(isinstance(end_lr, float))
assert(isinstance(start_lr, float))
linear_step = end_lr - start_lr
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate_warmup")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (global_step / warmup_steps)
fluid.layers.tensor.assign(decayed_lr, lr)
with switch.default():
fluid.layers.tensor.assign(learning_rate, lr)
return lr
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
__all__ = ["DistResNet"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class DistResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
self.is_train = is_train
self.weight_decay = 1e-4
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv),
regularizer=fluid.regularizer.L2Decay(self.weight_decay)),
bias_attr=fluid.ParamAttr(
regularizer=fluid.regularizer.L2Decay(self.weight_decay))
)
return out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
bn_init_value=1.0):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False,
param_attr=fluid.ParamAttr(regularizer=fluid.regularizer.L2Decay(self.weight_decay)))
return fluid.layers.batch_norm(
input=conv, act=act, is_test=not self.is_train,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(bn_init_value),
regularizer=None))
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
# NOTE: default bias is 0.0 already
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, bn_init_value=0.0)
short = self.shortcut(input, num_filters * 4, stride)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
......@@ -21,3 +21,4 @@ data/pascalvoc/trainval.txt
log*
*.log
ssd_mobilenet_v1_pascalvoc*
quant_model
The minimum PaddlePaddle version needed for the code sample in this directory is the latest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
## SSD Object Detection
## Table of Contents
......
运行本目录下的程序示例需要使用 PaddlePaddle 最新的 develop branch 版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新 PaddlePaddle 安装版本。
---
## SSD 目标检测
## Table of Contents
......
## Quantization-aware training for SSD
### Introduction
The quantization-aware training used in this experiments is introduced in [fixed-point quantization desigin](https://gthub.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/quantization/fixed_point_quantization.md). Since quantization-aware training is still an active area of research and experimentation,
here, we just give an simple quantization training usage in Fluid based on MobileNet-SSD model, and more other exeperiments are still needed, like how to quantization traning by considering fusing batch normalization and convolution/fully-connected layers, channel-wise quantization of weights and so on.
A Python transpiler is used to rewrite Fluid training program or evaluation program for quantization-aware training:
```python
#startup_prog = fluid.Program()
#train_prog = fluid.Program()
#loss = build_program(
# main_prog=train_prog,
# startup_prog=startup_prog,
# is_train=True)
#build_program(
# main_prog=test_prog,
# startup_prog=startup_prog,
# is_train=False)
#test_prog = test_prog.clone(for_test=True)
# above is an pseudo code
transpiler = fluid.contrib.QuantizeTranspiler(
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max', # or 'range_abs_max'
weight_quantize_type='abs_max')
# note, transpiler.training_transpile will rewrite train_prog
# startup_prog is needed since it needs to insert and initialize
# some state variable
transpiler.training_transpile(train_prog, startup_prog)
transpiler.training_transpile(test_prog, startup_prog)
```
According to above design, this transpiler inserts fake quantization and de-quantization operation for each convolution operation (including depthwise convolution operation) and fully-connected operation. These quantizations take affect on weights and activations.
In the design, we introduce dynamic quantization and static quantization strategies for different activation quantization methods. In the expriments, when set `activation_quantize_type` to `abs_max`, it is dynamic quantization. That is to say, the quantization scale (maximum of absolute value) of activation will be calculated each mini-batch during inference. When set `activation_quantize_type` to `range_abs_max`, a quantization scale for inference period will be calculated during training. Following part will introduce how to train.
### Quantization-aware training
The training is fine-tuned on the well-trained MobileNet-SSD model. So download model at first:
```
wget http://paddlemodels.bj.bcebos.com/ssd_mobilenet_v1_pascalvoc.tar.gz
```
- dynamic quantization:
```python
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='train' \
--init_model=ssd_mobilenet_v1_pascalvoc \
--act_quant_type='abs_max' \
--epoc_num=20 \
--learning_rate=0.0001 \
--batch_size=64 \
--model_save_dir=$OUTPUT_DIR$
```
Since fine-tuned on a well-trained model, we use a small start learnng rate 0.0001, and train 20 epocs.
- static quantization:
```python
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='train' \
--init_model=ssd_mobilenet_v1_pascalvoc \
--act_quant_type='range_abs_max' \
--epoc_num=80 \
--learning_rate=0.001 \
--lr_epochs=30,60 \
--lr_decay_rates=1,0.1,0.01 \
--batch_size=64 \
--model_save_dir=$OUTPUT_DIR$
```
Here, train 80 epocs, learning rate decays at 30 and 60 epocs by 0.1 every time. Users can adjust these hype-parameters.
### Convert to inference model
As described in the design documentation, the inference graph is a little different from training, the difference is the de-quantization operation is before or after conv/fc. This is equivalent in training due to linear operation of conv/fc and de-quantization and functions' commutative law. But for inference, it needs to convert the graph, `fluid.contrib.QuantizeTranspiler.freeze_program` is used to do this:
```python
#startup_prog = fluid.Program()
#test_prog = fluid.Program()
#test_py_reader, map_eval, nmsed_out, image = build_program(
# main_prog=test_prog,
# startup_prog=startup_prog,
# train_params=configs,
# is_train=False)
#test_prog = test_prog.clone(for_test=True)
#transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
# activation_bits=8,
# activation_quantize_type=act_quant_type,
# weight_quantize_type='abs_max')
#transpiler.training_transpile(test_prog, startup_prog)
#place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
#exe = fluid.Executor(place)
#exe.run(startup_prog)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=test_prog,
predicate=if_exist)
# freeze the rewrited training program
# freeze after load parameters, it will quantized weights
transpiler.freeze_program(test_prog, place)
```
Users can evaluate the converted model by:
```
python main_quant.py \
--data_dir=$PascalVOC_DIR$ \
--mode='test' \
--init_model=$MODLE_DIR$ \
--model_save_dir=$MobileNet_SSD_8BIT_MODEL$
```
You also can check the 8-bit model by the inference scripts
```
python main_quant.py \
--mode='infer' \
--init_model=$MobileNet_SSD_8BIT_MODEL$ \
--confs_threshold=0.5 \
--image_path='/data/PascalVOC/VOCdevkit/VOC2007/JPEGImages/002271.jpg'
```
See 002271.jpg for the visualized image with bbouding boxes.
### Results
Results of MobileNet-v1-SSD 300x300 model on PascalVOC dataset.
| Model | mAP |
|:---------------------------------------:|:------------------:|
|Floating point: 32bit | 73.32% |
|Fixed point: 8bit, dynamic quantization | 72.77% |
|Fixed point: 8bit, static quantization | 72.45% |
As mentioned above, other experiments, like how to quantization traning by considering fusing batch normalization and convolution/fully-connected layers, channel-wise quantization of weights, quantizated weights type with uint8 instead of int8 and so on.
......@@ -77,12 +77,13 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# yapf: disable
if model_dir:
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, main_program=test_prog, predicate=if_exist)
# yapf: enable
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(
exe, model_dir, main_program=test_prog, predicate=if_exist)
test_reader = reader.test(data_args, test_list, batch_size=batch_size)
test_py_reader.decorate_paddle_reader(test_reader)
......@@ -96,7 +97,7 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):
if batch_id % 10 == 0:
print("Batch {0}, map {1}".format(batch_id, test_map))
batch_id += 1
except fluid.core.EOFException:
except (fluid.core.EOFException, StopIteration):
test_py_reader.reset()
print("Test model {0}, map {1}".format(model_dir, test_map))
......
......@@ -85,11 +85,11 @@ def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
im_width, im_height = image.size
for dt in nms_out:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
if score < confs_threshold:
if dt[1] < confs_threshold:
continue
category_id = dt[0]
bbox = dt[2:]
xmin, ymin, xmax, ymax = bbox
xmin, ymin, xmax, ymax = clip_bbox(dt[2:])
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
......@@ -104,6 +104,14 @@ def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
image.save(image_name)
def clip_bbox(bbox):
xmin = max(min(bbox[0], 1.), 0.)
ymin = max(min(bbox[1], 1.), 0.)
xmax = max(min(bbox[2], 1.), 0.)
ymax = max(min(bbox[3], 1.), 0.)
return xmin, ymin, xmax, ymax
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
......
import os
import time
import numpy as np
import argparse
import functools
import shutil
import math
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
from train import build_program
from train import train_parameters
from infer import draw_bounding_box_on_image
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.0001, "Learning rate.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('epoc_num', int, 20, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.")
add_arg('model_save_dir', str, 'quant_model', "The path to save model.")
add_arg('init_model', str, 'ssd_mobilenet_v1_pascalvoc', "The init model path.")
add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.")
add_arg('image_shape', str, '3,300,300', "Input image shape.")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('lr_epochs', str, '30,60', "The learning decay steps.")
add_arg('lr_decay_rates', str, '1,0.1,0.01', "The learning decay rates for each step.")
add_arg('data_dir', str, 'data/pascalvoc', "Data directory")
add_arg('act_quant_type', str, 'abs_max', "Quantize type of activation, whicn can be abs_max or range_abs_max")
add_arg('image_path', str, '', "The image used to inference and visualize.")
add_arg('confs_threshold', float, 0.5, "Confidence threshold to draw bbox.")
add_arg('mode', str, 'train', "Job mode can be one of ['train', 'test', 'infer'].")
#yapf: enable
def test(exe, test_prog, map_eval, test_py_reader):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_py_reader.start()
try:
batch = 0
while True:
test_map, = exe.run(test_prog, fetch_list=[accum_map])
if batch % 10 == 0:
print("Batch {0}, map {1}".format(batch, test_map))
batch += 1
except fluid.core.EOFException:
test_py_reader.reset()
finally:
test_py_reader.reset()
print("Test map {0}".format(test_map))
return test_map
def save_model(exe, main_prog, model_save_dir, postfix):
model_path = os.path.join(model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path, main_program=main_prog)
def train(args,
data_args,
train_params,
train_file_list,
val_file_list):
model_save_dir = args.model_save_dir
init_model = args.init_model
epoc_num = args.epoc_num
use_gpu = args.use_gpu
parallel = args.parallel
is_shuffle = True
act_quant_type = args.act_quant_type
if use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
batch_size = train_params['batch_size']
batch_size_per_device = batch_size // devices_num
iters_per_epoc = train_params["train_images"] // batch_size
num_workers = 4
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_prog = fluid.Program()
train_py_reader, loss = build_program(
main_prog=train_prog,
startup_prog=startup_prog,
train_params=train_params,
is_train=True)
test_py_reader, map_eval, _, _ = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=train_params,
is_train=False)
test_prog = test_prog.clone(for_test=True)
transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
activation_bits=8,
activation_quantize_type=act_quant_type,
weight_quantize_type='abs_max')
transpiler.training_transpile(train_prog, startup_prog)
transpiler.training_transpile(test_prog, startup_prog)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
if init_model:
print('Load init model %s.' % init_model)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=train_prog,
predicate=if_exist)
else:
print('There is no init model.')
if parallel:
train_exe = fluid.ParallelExecutor(main_program=train_prog,
use_cuda=use_gpu, loss_name=loss.name)
train_reader = reader.train(data_args,
train_file_list,
batch_size_per_device,
shuffle=is_shuffle,
use_multiprocessing=True,
num_workers=num_workers,
max_queue=24)
test_reader = reader.test(data_args, val_file_list, batch_size)
train_py_reader.decorate_paddle_reader(train_reader)
test_py_reader.decorate_paddle_reader(test_reader)
train_py_reader.start()
best_map = 0.
try:
for epoc in range(epoc_num):
if epoc == 0:
# test quantized model without quantization-aware training.
test_map = test(exe, test_prog, map_eval, test_py_reader)
# train
for batch in range(iters_per_epoc):
start_time = time.time()
if parallel:
outs = train_exe.run(fetch_list=[loss.name])
else:
outs = exe.run(train_prog, fetch_list=[loss])
end_time = time.time()
avg_loss = np.mean(np.array(outs[0]))
if batch % 20 == 0:
print("Epoc {:d}, batch {:d}, loss {:.6f}, time {:.5f}".format(
epoc , batch, avg_loss, end_time - start_time))
end_time = time.time()
test_map = test(exe, test_prog, map_eval, test_py_reader)
save_model(exe, train_prog, model_save_dir, str(epoc))
if test_map > best_map:
best_map = test_map
save_model(exe, train_prog, model_save_dir, 'best_map')
print("Best test map {0}".format(best_map))
except (fluid.core.EOFException, StopIteration):
train_py_reader.reset()
def eval(args, data_args, configs, val_file_list):
init_model = args.init_model
use_gpu = args.use_gpu
act_quant_type = args.act_quant_type
model_save_dir = args.model_save_dir
batch_size = configs['batch_size']
batch_size_per_device = batch_size
startup_prog = fluid.Program()
test_prog = fluid.Program()
test_py_reader, map_eval, nmsed_out, image = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=configs,
is_train=False)
test_prog = test_prog.clone(for_test=True)
transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8,
activation_bits=8,
activation_quantize_type=act_quant_type,
weight_quantize_type='abs_max')
transpiler.training_transpile(test_prog, startup_prog)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
def if_exist(var):
return os.path.exists(os.path.join(init_model, var.name))
fluid.io.load_vars(exe, init_model, main_program=test_prog,
predicate=if_exist)
# freeze after load parameters
transpiler.freeze_program(test_prog, place)
test_reader = reader.test(data_args, val_file_list, batch_size)
test_py_reader.decorate_paddle_reader(test_reader)
test_map = test(exe, test_prog, map_eval, test_py_reader)
print("Test model {0}, map {1}".format(init_model, test_map))
fluid.io.save_inference_model(model_save_dir, [image.name],
[nmsed_out], exe, test_prog)
def infer(args, data_args):
model_dir = args.init_model
image_path = args.image_path
confs_threshold = args.confs_threshold
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed , fetch] = fluid.io.load_inference_model(
dirname=model_dir,
executor=exe,
model_filename='__model__')
#print(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor()))
#print(np.max(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor())))
infer_reader = reader.infer(data_args, image_path)
data = infer_reader()
data = data.reshape((1,) + data.shape)
outs = exe.run(inference_program,
feed={feed[0]: data},
fetch_list=fetch,
return_numpy=False)
out = np.array(outs[0])
draw_bounding_box_on_image(image_path, out, confs_threshold,
data_args.label_list)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
# for pascalvoc
label_file = 'label_list'
train_list = 'trainval.txt'
val_list = 'test.txt'
dataset = 'pascalvoc'
mean_BGR = [float(m) for m in args.mean_BGR.split(",")]
image_shape = [int(m) for m in args.image_shape.split(",")]
lr_epochs = [int(m) for m in args.lr_epochs.split(",")]
lr_rates = [float(m) for m in args.lr_decay_rates.split(",")]
train_parameters[dataset]['image_shape'] = image_shape
train_parameters[dataset]['batch_size'] = args.batch_size
train_parameters[dataset]['lr'] = args.learning_rate
train_parameters[dataset]['epoc_num'] = args.epoc_num
train_parameters[dataset]['ap_version'] = args.ap_version
train_parameters[dataset]['lr_epochs'] = lr_epochs
train_parameters[dataset]['lr_decay'] = lr_rates
data_args = reader.Settings(
dataset=dataset,
data_dir=args.data_dir,
label_file=label_file,
resize_h=image_shape[1],
resize_w=image_shape[2],
mean_value=mean_BGR,
apply_distort=True,
apply_expand=True,
ap_version = args.ap_version)
if args.mode == 'train':
train(args, data_args, train_parameters[dataset], train_list, val_list)
elif args.mode == 'test':
eval(args, data_args, train_parameters[dataset], val_list)
else:
infer(args, data_args)
......@@ -5,6 +5,7 @@ import argparse
import functools
import shutil
import math
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -16,18 +17,18 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('batch_size', int, 64, "Minibatch size of all devices.")
add_arg('epoc_num', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('parallel', bool, True, "Parallel.")
add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.")
add_arg('dataset', str, 'pascalvoc', "dataset can be coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('ap_version', str, '11point', "Integral, 11point.")
add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.")
add_arg('image_shape', str, '3,300,300', "Input image shape.")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('data_dir', str, 'data/pascalvoc', "data directory")
add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model")
add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.")
add_arg('data_dir', str, 'data/pascalvoc', "Data directory.")
add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model.")
#yapf: enable
train_parameters = {
......@@ -81,6 +82,7 @@ def build_program(main_prog, startup_prog, train_params, is_train):
image_shape = train_params['image_shape']
class_num = train_params['class_num']
ap_version = train_params['ap_version']
outs = []
with fluid.program_guard(main_prog, startup_prog):
py_reader = fluid.layers.py_reader(
capacity=64,
......@@ -98,11 +100,12 @@ def build_program(main_prog, startup_prog, train_params, is_train):
loss = fluid.layers.reduce_sum(loss)
optimizer = optimizer_setting(train_params)
optimizer.minimize(loss)
outs = [py_reader, loss]
else:
with fluid.unique_name.guard("inference"):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.evaluator.DetectionMAP(
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
......@@ -111,7 +114,9 @@ def build_program(main_prog, startup_prog, train_params, is_train):
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=ap_version)
return py_reader, loss
# nmsed_out and image is used to save mode for inference
outs = [py_reader, map_eval, nmsed_out, image]
return outs
def train(args,
......@@ -127,8 +132,12 @@ def train(args,
enable_ce = args.enable_ce
is_shuffle = True
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
if not use_gpu:
devices_num = int(os.environ.get('CPU_NUM',
multiprocessing.cpu_count()))
else:
devices_num = fluid.core.get_cuda_device_count()
batch_size = train_params['batch_size']
epoc_num = train_params['epoc_num']
batch_size_per_device = batch_size // devices_num
......@@ -153,7 +162,7 @@ def train(args,
startup_prog=startup_prog,
train_params=train_params,
is_train=True)
test_py_reader, map_eval = build_program(
test_py_reader, map_eval, _, _ = build_program(
main_prog=test_prog,
startup_prog=startup_prog,
train_params=train_params,
......@@ -258,11 +267,9 @@ def train(args,
print("kpis train_speed_card%s %f" %
(devices_num, total_time / epoch_idx))
except fluid.core.EOFException:
train_py_reader.reset()
except StopIteration:
except (fluid.core.EOFException, StopIteration):
train_reader().close()
train_py_reader.reset()
train_py_reader.reset()
if __name__ == '__main__':
......@@ -291,6 +298,7 @@ if __name__ == '__main__':
train_parameters[dataset]['batch_size'] = args.batch_size
train_parameters[dataset]['lr'] = args.learning_rate
train_parameters[dataset]['epoc_num'] = args.epoc_num
train_parameters[dataset]['ap_version'] = args.ap_version
data_args = reader.Settings(
dataset=args.dataset,
......
Subproject commit d2fc9e0b45b4e6cfc93e73054026fc5a8abfbfb9
......@@ -5,6 +5,7 @@ python -u ../train_and_evaluate.py --use_cuda \
--ext_eval \
--word_emb_init ./data/word_embedding.pkl \
--save_path ./models \
--use_pyreader \
--batch_size 256 \
--vocab_size 172130 \
--channel1_num 16 \
......
......@@ -15,45 +15,85 @@ class Net(object):
self._stack_num = stack_num
self._channel1_num = channel1_num
self._channel2_num = channel2_num
self._feed_names = []
self.word_emb_name = "shared_word_emb"
self.use_stack_op = True
self.use_mask_cache = True
self.use_sparse_embedding = True
def set_word_embedding(self, word_emb, place):
word_emb_param = fluid.global_scope().find_var(
self.word_emb_name).get_tensor()
word_emb_param.set(word_emb, place)
def create_network(self):
mask_cache = dict() if self.use_mask_cache else None
turns_data = []
def create_py_reader(self, capacity, name):
# turns ids
shapes = [[-1, self._max_turn_len, 1]
for i in six.moves.xrange(self._max_turn_num)]
dtypes = ["int32" for i in six.moves.xrange(self._max_turn_num)]
# turns mask
shapes += [[-1, self._max_turn_len, 1]
for i in six.moves.xrange(self._max_turn_num)]
dtypes += ["float32" for i in six.moves.xrange(self._max_turn_num)]
# response ids, response mask, label
shapes += [[-1, self._max_turn_len, 1], [-1, self._max_turn_len, 1],
[-1, 1]]
dtypes += ["int32", "float32", "float32"]
py_reader = fluid.layers.py_reader(
capacity=capacity,
shapes=shapes,
lod_levels=[0] * (2 * self._max_turn_num + 3),
dtypes=dtypes,
name=name,
use_double_buffer=True)
data_vars = fluid.layers.read_file(py_reader)
self.turns_data = data_vars[0:self._max_turn_num]
self.turns_mask = data_vars[self._max_turn_num:2 * self._max_turn_num]
self.response = data_vars[-3]
self.response_mask = data_vars[-2]
self.label = data_vars[-1]
return py_reader
def create_data_layers(self):
self._feed_names = []
self.turns_data = []
for i in six.moves.xrange(self._max_turn_num):
name = "turn_%d" % i
turn = fluid.layers.data(
name="turn_%d" % i,
shape=[self._max_turn_len, 1],
dtype="int32")
turns_data.append(turn)
name=name, shape=[self._max_turn_len, 1], dtype="int32")
self.turns_data.append(turn)
self._feed_names.append(name)
turns_mask = []
self.turns_mask = []
for i in six.moves.xrange(self._max_turn_num):
name = "turn_mask_%d" % i
turn_mask = fluid.layers.data(
name="turn_mask_%d" % i,
shape=[self._max_turn_len, 1],
dtype="float32")
turns_mask.append(turn_mask)
name=name, shape=[self._max_turn_len, 1], dtype="float32")
self.turns_mask.append(turn_mask)
self._feed_names.append(name)
response = fluid.layers.data(
self.response = fluid.layers.data(
name="response", shape=[self._max_turn_len, 1], dtype="int32")
response_mask = fluid.layers.data(
self.response_mask = fluid.layers.data(
name="response_mask",
shape=[self._max_turn_len, 1],
dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="float32")
self.label = fluid.layers.data(name="label", shape=[1], dtype="float32")
self._feed_names += ["response", "response_mask", "label"]
def get_feed_names(self):
return self._feed_names
def set_word_embedding(self, word_emb, place):
word_emb_param = fluid.global_scope().find_var(
self.word_emb_name).get_tensor()
word_emb_param.set(word_emb, place)
def create_network(self):
mask_cache = dict() if self.use_mask_cache else None
response_emb = fluid.layers.embedding(
input=response,
input=self.response,
size=[self._vocab_size + 1, self._emb_size],
is_sparse=self.use_sparse_embedding,
param_attr=fluid.ParamAttr(
......@@ -71,8 +111,8 @@ class Net(object):
key=Hr,
value=Hr,
d_key=self._emb_size,
q_mask=response_mask,
k_mask=response_mask,
q_mask=self.response_mask,
k_mask=self.response_mask,
mask_cache=mask_cache)
Hr_stack.append(Hr)
......@@ -80,7 +120,7 @@ class Net(object):
sim_turns = []
for t in six.moves.xrange(self._max_turn_num):
Hu = fluid.layers.embedding(
input=turns_data[t],
input=self.turns_data[t],
size=[self._vocab_size + 1, self._emb_size],
is_sparse=self.use_sparse_embedding,
param_attr=fluid.ParamAttr(
......@@ -96,8 +136,8 @@ class Net(object):
key=Hu,
value=Hu,
d_key=self._emb_size,
q_mask=turns_mask[t],
k_mask=turns_mask[t],
q_mask=self.turns_mask[t],
k_mask=self.turns_mask[t],
mask_cache=mask_cache)
Hu_stack.append(Hu)
......@@ -111,8 +151,8 @@ class Net(object):
key=Hr_stack[index],
value=Hr_stack[index],
d_key=self._emb_size,
q_mask=turns_mask[t],
k_mask=response_mask,
q_mask=self.turns_mask[t],
k_mask=self.response_mask,
mask_cache=mask_cache)
r_a_t = layers.block(
name="r_attend_t_" + str(index),
......@@ -120,8 +160,8 @@ class Net(object):
key=Hu_stack[index],
value=Hu_stack[index],
d_key=self._emb_size,
q_mask=response_mask,
k_mask=turns_mask[t],
q_mask=self.response_mask,
k_mask=self.turns_mask[t],
mask_cache=mask_cache)
t_a_r_stack.append(t_a_r)
......@@ -158,5 +198,5 @@ class Net(object):
sim = fluid.layers.concat(input=sim_turns, axis=2)
final_info = layers.cnn_3d(sim, self._channel1_num, self._channel2_num)
loss, logits = layers.loss(final_info, label)
loss, logits = layers.loss(final_info, self.label)
return loss, logits
......@@ -7,7 +7,7 @@ import multiprocessing
import paddle
import paddle.fluid as fluid
import utils.reader as reader
from utils.util import print_arguments
from utils.util import print_arguments, mkdir
try:
import cPickle as pickle #python 2
......@@ -49,6 +49,10 @@ def parse_args():
'--use_cuda',
action='store_true',
help='If set, use cuda for training.')
parser.add_argument(
'--use_pyreader',
action='store_true',
help='If set, use pyreader for reading data.')
parser.add_argument(
'--ext_eval',
action='store_true',
......@@ -105,7 +109,75 @@ def parse_args():
#yapf: enable
def evaluate(score_path, result_file_path):
if args.ext_eval:
import utils.douban_evaluation as eva
else:
import utils.evaluation as eva
#write evaluation result
result = eva.evaluate(score_path)
with open(result_file_path, 'w') as out_file:
for p_at in result:
out_file.write(str(p_at) + '\n')
print('finish evaluation')
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
def test_with_feed(exe, program, feed_names, fetch_list, score_path, batches,
batch_num, dev_count):
score_file = open(score_path, 'w')
for it in six.moves.xrange(batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
batch_data = reader.make_one_batch_input(batches, val_index)
feed_dict = dict(zip(feed_names, batch_data))
feed_list.append(feed_dict)
predicts = exe.run(feed=feed_list, fetch_list=fetch_list)
scores = np.array(predicts[0])
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t' + str(
batches["label"][val_index][i]) + '\n')
score_file.close()
def test_with_pyreader(exe, program, pyreader, fetch_list, score_path, batches,
batch_num, dev_count):
def data_provider():
for index in six.moves.xrange(batch_num):
yield reader.make_one_batch_input(batches, index)
score_file = open(score_path, 'w')
pyreader.decorate_tensor_provider(data_provider)
it = 0
pyreader.start()
while True:
try:
predicts = exe.run(fetch_list=fetch_list)
scores = np.array(predicts[0])
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t' + str(
batches["label"][val_index][i]) + '\n')
it += 1
except fluid.core.EOFException:
pyreader.reset()
break
score_file.close()
def train(args):
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
# data data_config
data_conf = {
"batch_size": args.batch_size,
......@@ -117,27 +189,47 @@ def train(args):
dam = Net(args.max_turn_num, args.max_turn_len, args.vocab_size,
args.emb_size, args.stack_num, args.channel1_num,
args.channel2_num)
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
train_program = fluid.default_main_program()
test_program = fluid.default_main_program().clone(for_test=True)
# gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate,
decay_steps=400,
decay_rate=0.9,
staircase=True))
optimizer.minimize(loss)
fluid.memory_optimize(train_program)
train_program = fluid.Program()
train_startup = fluid.Program()
with fluid.program_guard(train_program, train_startup):
with fluid.unique_name.guard():
if args.use_pyreader:
train_pyreader = dam.create_py_reader(
capacity=10, name='train_reader')
else:
dam.create_data_layers()
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
# gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate,
decay_steps=400,
decay_rate=0.9,
staircase=True))
optimizer.minimize(loss)
fluid.memory_optimize(train_program)
test_program = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_program, test_startup):
with fluid.unique_name.guard():
if args.use_pyreader:
test_pyreader = dam.create_py_reader(
capacity=10, name='test_reader')
else:
dam.create_data_layers()
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
test_program = test_program.clone(for_test=True)
if args.use_cuda:
place = fluid.CUDAPlace(0)
......@@ -152,7 +244,8 @@ def train(args):
program=train_program, batch_size=args.batch_size))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run(train_startup)
exe.run(test_startup)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, loss_name=loss.name, main_program=train_program)
......@@ -162,11 +255,6 @@ def train(args):
main_program=test_program,
share_vars_from=train_exe)
if args.ext_eval:
import utils.douban_evaluation as eva
else:
import utils.evaluation as eva
if args.word_emb_init is not None:
print("start loading word embedding init ...")
if six.PY2:
......@@ -199,17 +287,15 @@ def train(args):
print("begin model training ...")
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
step = 0
for epoch in six.moves.xrange(args.num_scan_data):
shuffle_train = reader.unison_shuffle(train_data)
train_batches = reader.build_batches(shuffle_train, data_conf)
# train on one epoch data by feeding
def train_with_feed(step):
ave_cost = 0.0
for it in six.moves.xrange(batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
index = it * dev_count + dev
feed_dict = reader.make_one_batch_input(train_batches, index)
batch_data = reader.make_one_batch_input(train_batches, index)
feed_dict = dict(zip(dam.get_feed_names(), batch_data))
feed_list.append(feed_dict)
cost = train_exe.run(feed=feed_list, fetch_list=[loss.name])
......@@ -226,41 +312,73 @@ def train(args):
print("Save model at step %d ... " % step)
print(time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path)
fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path, 'score.' + str(step))
score_file = open(score_path, 'w')
for it in six.moves.xrange(val_batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
feed_dict = reader.make_one_batch_input(val_batches,
val_index)
feed_list.append(feed_dict)
predicts = test_exe.run(feed=feed_list,
fetch_list=[logits.name])
scores = np.array(predicts[0])
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t'
+ str(val_batches["label"][val_index][
i]) + '\n')
score_file.close()
#write evaluation result
result = eva.evaluate(score_path)
test_with_feed(test_exe, test_program,
dam.get_feed_names(), [logits.name], score_path,
val_batches, val_batch_num, dev_count)
result_file_path = os.path.join(args.save_path,
'result.' + str(step))
with open(result_file_path, 'w') as out_file:
for p_at in result:
out_file.write(str(p_at) + '\n')
print('finish evaluation')
print(time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))
evaluate(score_path, result_file_path)
return step
# train on one epoch with pyreader
def train_with_pyreader(step):
def data_provider():
for index in six.moves.xrange(batch_num):
yield reader.make_one_batch_input(train_batches, index)
train_pyreader.decorate_tensor_provider(data_provider)
ave_cost = 0.0
train_pyreader.start()
while True:
try:
cost = train_exe.run(fetch_list=[loss.name])
ave_cost += np.array(cost[0]).mean()
step = step + 1
if step % print_step == 0:
print("processed: [" + str(step * dev_count * 1.0 /
batch_num) + "] ave loss: [" +
str(ave_cost / print_step) + "]")
ave_cost = 0.0
if (args.save_path is not None) and (step % save_step == 0):
save_path = os.path.join(args.save_path,
"step_" + str(step))
print("Save model at step %d ... " % step)
print(time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path,
'score.' + str(step))
test_with_pyreader(test_exe, test_program, test_pyreader,
[logits.name], score_path, val_batches,
val_batch_num, dev_count)
result_file_path = os.path.join(args.save_path,
'result.' + str(step))
evaluate(score_path, result_file_path)
except fluid.core.EOFException:
train_pyreader.reset()
break
return step
# train over different epoches
global_step = 0
for epoch in six.moves.xrange(args.num_scan_data):
shuffle_train = reader.unison_shuffle(train_data)
train_batches = reader.build_batches(shuffle_train, data_conf)
if args.use_pyreader:
global_step = train_with_pyreader(global_step)
else:
global_step = train_with_feed(global_step)
if __name__ == '__main__':
......
......@@ -4,6 +4,7 @@ python -u ../train_and_evaluate.py --use_cuda \
--data_path ./data/data.pkl \
--word_emb_init ./data/word_embedding.pkl \
--save_path ./models \
--use_pyreader \
--batch_size 256 \
--vocab_size 434512 \
--emb_size 200 \
......
......@@ -202,30 +202,30 @@ def make_one_batch_input(data_batches, index):
every_turn_len[:, i] for i in six.moves.xrange(max_turn_num)
]
feed_dict = {}
feed_list = []
for i, turn in enumerate(turns_list):
feed_dict["turn_%d" % i] = turn
feed_dict["turn_%d" % i] = np.expand_dims(
feed_dict["turn_%d" % i], axis=-1)
turn = np.expand_dims(turn, axis=-1)
feed_list.append(turn)
for i, turn_len in enumerate(every_turn_len_list):
feed_dict["turn_mask_%d" % i] = np.ones(
(batch_size, max_turn_len, 1)).astype("float32")
turn_mask = np.ones((batch_size, max_turn_len, 1)).astype("float32")
for row in six.moves.xrange(batch_size):
feed_dict["turn_mask_%d" % i][row, turn_len[row]:, 0] = 0
turn_mask[row, turn_len[row]:, 0] = 0
feed_list.append(turn_mask)
feed_dict["response"] = response
feed_dict["response"] = np.expand_dims(feed_dict["response"], axis=-1)
response = np.expand_dims(response, axis=-1)
feed_list.append(response)
feed_dict["response_mask"] = np.ones(
(batch_size, max_turn_len, 1)).astype("float32")
response_mask = np.ones((batch_size, max_turn_len, 1)).astype("float32")
for row in six.moves.xrange(batch_size):
feed_dict["response_mask"][row, response_len[row]:, 0] = 0
response_mask[row, response_len[row]:, 0] = 0
feed_list.append(response_mask)
feed_dict["label"] = np.array([data_batches["label"][index]]).reshape(
label = np.array([data_batches["label"][index]]).reshape(
[-1, 1]).astype("float32")
feed_list.append(label)
return feed_dict
return feed_list
if __name__ == '__main__':
......
......@@ -22,10 +22,7 @@ def parse_args():
help='If set, run \
the task with continuous evaluation logs.')
parser.add_argument(
'--num_devices',
type=int,
default=1,
help='Number of GPU devices')
'--num_devices', type=int, default=1, help='Number of GPU devices')
args = parser.parse_args()
return args
......@@ -129,15 +126,15 @@ def train(train_reader,
newest_ppl = 0
for data in train_reader():
i += 1
lod_src_wordseq = utils.to_lodtensor(
[dat[0] for dat in data], place)
lod_dst_wordseq = utils.to_lodtensor(
[dat[1] for dat in data], place)
lod_src_wordseq = utils.to_lodtensor([dat[0] for dat in data],
place)
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place)
ret_avg_cost = train_exe.run(feed={
"src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=fetch_list)
fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
if i % 100 == 0:
......@@ -145,8 +142,8 @@ def train(train_reader,
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" % (epoch_idx, i,
total_time / epoch_idx))
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, i, total_time / epoch_idx))
if pass_idx == pass_num - 1 and args.enable_ce:
#Note: The following logs are special for CE monitoring.
......
......@@ -236,8 +236,8 @@ def do_train(train_reader,
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" % (epoch_idx, i,
total_time / epoch_idx))
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, i, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["src_wordseq", "dst_wordseq"]
......
export CUDA_VISIBLE_DEVICES=0
cd data
sh download_data.sh
cd ..
python train.py \
--data_path data/simple-examples/data/ \
--model_type small \
--use_gpu True \
--enable_ce | python _ce.py
# lstm lm
以下是本例的简要目录结构及说明:
```text
.
├── README.md # 文档
├── train.py # 训练脚本
├── reader.py # 数据读取
└── lm_model.py # 模型定义文件
```
## 简介
循环神经网络语言模型的介绍可以参阅论文[Recurrent Neural Network Regularization](https://arxiv.org/abs/1409.2329),本文主要是说明基于lstm的语言的模型的实现,数据是采用ptb dataset,下载地址为
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
## 数据下载
用户可以自行下载数据,并解压, 也可以利用目录中的脚本
cd data; sh download_data.sh
## 训练
运行命令
`CUDA_VISIBLE_DEVICES=0 python train.py --data_path data/simple-examples/data/ --model_type small --use_gpu True`
开始训练模型。
model_type 为模型配置的大小,目前支持 small,medium, large 三种配置形式
实现采用双层的lstm,具体的参数和网络配置 可以参考 train.py, lm_model.py 文件中的设置
## 训练结果示例
p40中训练日志如下(small config), test 测试集仅在最后一个epoch完成后进行测试
```text
epoch id 0
ppl 232 865.86505 1.0
ppl 464 632.76526 1.0
ppl 696 510.47153 1.0
ppl 928 437.60617 1.0
ppl 1160 393.38422 1.0
ppl 1392 353.05365 1.0
ppl 1624 325.73267 1.0
ppl 1856 305.488 1.0
ppl 2088 286.3128 1.0
ppl 2320 270.91504 1.0
train ppl 270.86246
valid ppl 181.867964379
...
ppl 2320 40.975872 0.001953125
train ppl 40.974102
valid ppl 117.85741214
test ppl 113.939103843
```
## 与tf结果对比
tf采用的版本是1.6
```text
small config
train valid test
fluid 1.0 40.962 118.111 112.617
tf 1.6 40.492 118.329 113.788
medium config
train valid test
fluid 1.0 45.620 87.398 83.682
tf 1.6 45.594 87.363 84.015
large config
train valid test
fluid 1.0 37.221 82.358 78.137
tf 1.6 38.342 82.311 78.121
```
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
imikolov_20_avg_ppl_kpi = CostKpi('lstm_language_model_loss', 0.02, 0)
imikolov_20_pass_duration_kpi = DurationKpi(
'lstm_language_model_duration', 0.02, 0, actived=True)
tracking_kpis = [
imikolov_20_avg_ppl_kpi,
imikolov_20_pass_duration_kpi,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
kpi_name = fs[0]
kpi_value = float(fs[1])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model_type",
type=str,
default="small",
help="model_type [test|small|med|big]")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument('--para_init', action='store_true')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
parser.add_argument(
'--log_path',
help='path of the log file. If not set, logs are printed to console')
parser.add_argument('--enable_ce', action='store_true')
args = parser.parse_args()
return args
wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -xzvf simple-examples.tgz
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
import numpy as np
def lm_model(hidden_size,
vocab_size,
batch_size,
num_layers=2,
num_steps=20,
init_scale=0.1,
dropout=None):
def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
weight_1_arr = []
weight_2_arr = []
bias_arr = []
hidden_array = []
cell_array = []
mask_array = []
for i in range(num_layers):
weight_1 = layers.create_parameter([hidden_size * 2, hidden_size*4], dtype="float32", name="fc_weight1_"+str(i), \
default_initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
weight_1_arr.append(weight_1)
bias_1 = layers.create_parameter(
[hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
bias_arr.append(bias_1)
pre_hidden = layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1])
pre_cell = layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1])
pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size])
pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size])
hidden_array.append(pre_hidden)
cell_array.append(pre_cell)
input_embedding = layers.transpose(input_embedding, perm=[1, 0, 2])
rnn = PaddingRNN()
with rnn.step():
input = rnn.step_input(input_embedding)
for k in range(num_layers):
pre_hidden = rnn.memory(init=hidden_array[k])
pre_cell = rnn.memory(init=cell_array[k])
weight_1 = weight_1_arr[k]
bias = bias_arr[k]
nn = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=nn, y=weight_1)
gate_input = layers.elementwise_add(gate_input, bias)
#i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
i = layers.slice(
gate_input, axes=[1], starts=[0], ends=[hidden_size])
j = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size],
ends=[hidden_size * 2])
f = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size * 2],
ends=[hidden_size * 3])
o = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size * 3],
ends=[hidden_size * 4])
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
rnn.update_memory(pre_hidden, m)
rnn.update_memory(pre_cell, c)
rnn.step_output(m)
rnn.step_output(c)
input = m
if dropout != None and dropout > 0.0:
input = layers.dropout(
input,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
rnn.step_output(input)
#real_res = layers.concat(res, 0)
rnnout = rnn()
last_hidden_array = []
last_cell_array = []
real_res = rnnout[-1]
for i in range(num_layers):
m = rnnout[i * 2]
c = rnnout[i * 2 + 1]
m.stop_gradient = True
c.stop_gradient = True
last_h = layers.slice(
m, axes=[0], starts=[num_steps - 1], ends=[num_steps])
last_hidden_array.append(last_h)
last_c = layers.slice(
c, axes=[0], starts=[num_steps - 1], ends=[num_steps])
last_cell_array.append(last_c)
'''
else:
real_res = rnnout[-1]
for i in range( num_layers ):
m1, c1, m2, c2 = rnnout
real_res = m2
m1.stop_gradient = True
c1.stop_gradient = True
c2.stop_gradient = True
'''
#layers.Print( first_hidden, message="22", summarize=10)
#layers.Print( rnnout[1], message="11", summarize=10)
#real_res = ( rnnout[1] + rnnout[2] + rnnout[3] + rnnout[4]) / 4.0
real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
last_hidden = layers.concat(last_hidden_array, 0)
last_cell = layers.concat(last_cell_array, 0)
'''
last_hidden = layers.concat( hidden_array, 1 )
last_hidden = layers.reshape( last_hidden, shape=[-1, num_layers, hidden_size])
last_hidden = layers.transpose( x = last_hidden, perm = [1, 0, 2])
last_cell = layers.concat( cell_array, 1)
last_cell = layers.reshape( last_cell, shape=[ -1, num_layers, hidden_size])
last_cell = layers.transpose( x = last_cell, perm = [1, 0, 2])
'''
return real_res, last_hidden, last_cell
def encoder_static(input_embedding, len=3, init_hidden=None,
init_cell=None):
weight_1_arr = []
weight_2_arr = []
bias_arr = []
hidden_array = []
cell_array = []
mask_array = []
for i in range(num_layers):
weight_1 = layers.create_parameter([hidden_size * 2, hidden_size*4], dtype="float32", name="fc_weight1_"+str(i), \
default_initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
weight_1_arr.append(weight_1)
bias_1 = layers.create_parameter(
[hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
bias_arr.append(bias_1)
pre_hidden = layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1])
pre_cell = layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1])
pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size])
pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size])
hidden_array.append(pre_hidden)
cell_array.append(pre_cell)
res = []
for index in range(len):
input = layers.slice(
input_embedding, axes=[1], starts=[index], ends=[index + 1])
input = layers.reshape(input, shape=[-1, hidden_size])
for k in range(num_layers):
pre_hidden = hidden_array[k]
pre_cell = cell_array[k]
weight_1 = weight_1_arr[k]
bias = bias_arr[k]
nn = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=nn, y=weight_1)
gate_input = layers.elementwise_add(gate_input, bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
hidden_array[k] = m
cell_array[k] = c
input = m
if dropout != None and dropout > 0.0:
input = layers.dropout(
input,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
res.append(layers.reshape(input, shape=[1, -1, hidden_size]))
real_res = layers.concat(res, 0)
real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
last_hidden = layers.concat(hidden_array, 1)
last_hidden = layers.reshape(
last_hidden, shape=[-1, num_layers, hidden_size])
last_hidden = layers.transpose(x=last_hidden, perm=[1, 0, 2])
last_cell = layers.concat(cell_array, 1)
last_cell = layers.reshape(
last_cell, shape=[-1, num_layers, hidden_size])
last_cell = layers.transpose(x=last_cell, perm=[1, 0, 2])
return real_res, last_hidden, last_cell
x = layers.data(name="x", shape=[-1, 1, 1], dtype='int64')
y = layers.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = layers.data(name="init_hidden", shape=[1], dtype='float32')
init_cell = layers.data(name="init_cell", shape=[1], dtype='float32')
init_hidden = layers.reshape(
init_hidden, shape=[num_layers, -1, hidden_size])
init_cell = layers.reshape(init_cell, shape=[num_layers, -1, hidden_size])
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=True,
param_attr=fluid.ParamAttr(
name='embedding_para',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
x_emb = layers.reshape(x_emb, shape=[-1, num_steps, hidden_size])
if dropout != None and dropout > 0.0:
x_emb = layers.dropout(
x_emb,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
rnn_out, last_hidden, last_cell = padding_rnn(
x_emb, len=num_steps, init_hidden=init_hidden, init_cell=init_cell)
rnn_out = layers.reshape(rnn_out, shape=[-1, num_steps, hidden_size])
softmax_weight = layers.create_parameter([hidden_size, vocab_size], dtype="float32", name="softmax_weight", \
default_initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
softmax_bias = layers.create_parameter([vocab_size], dtype="float32", name='softmax_bias', \
default_initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale))
projection = layers.matmul(rnn_out, softmax_weight)
projection = layers.elementwise_add(projection, softmax_bias)
projection = layers.reshape(projection, shape=[-1, vocab_size])
#y = layers.reshape( y, shape=[-1, vocab_size])
loss = layers.softmax_with_cross_entropy(
logits=projection, label=y, soft_label=False)
loss = layers.reshape(loss, shape=[-1, num_steps])
loss = layers.reduce_mean(loss, dim=[0])
loss = layers.reduce_sum(loss)
loss.permissions = True
feeding_list = ['x', 'y', 'init_hidden', 'init_cell']
return loss, last_hidden, last_cell, feeding_list
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import sys
import numpy as np
Py3 = sys.version_info[0] == 3
def _read_words(filename):
data = []
with open(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
print("vocab word num", len(words))
word_to_id = dict(zip(words, range(len(words))))
return word_to_id
def _file_to_word_ids(filename, word_to_id):
data = _read_words(filename)
return [word_to_id[word] for word in data if word in word_to_id]
def ptb_raw_data(data_path=None):
"""Load PTB raw data from data directory "data_path".
Reads PTB text files, converts strings to integer ids,
and performs mini-batching of the inputs.
The PTB dataset comes from Tomas Mikolov's webpage:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Args:
data_path: string path to the directory where simple-examples.tgz has
been extracted.
Returns:
tuple (train_data, valid_data, test_data, vocabulary)
where each of the data objects can be passed to PTBIterator.
"""
train_path = os.path.join(data_path, "ptb.train.txt")
#train_path = os.path.join(data_path, "train.fake")
valid_path = os.path.join(data_path, "ptb.valid.txt")
test_path = os.path.join(data_path, "ptb.test.txt")
word_to_id = _build_vocab(train_path)
train_data = _file_to_word_ids(train_path, word_to_id)
valid_data = _file_to_word_ids(valid_path, word_to_id)
test_data = _file_to_word_ids(test_path, word_to_id)
vocabulary = len(word_to_id)
return train_data, valid_data, test_data, vocabulary
def get_data_iter(raw_data, batch_size, num_steps):
data_len = len(raw_data)
raw_data = np.asarray(raw_data, dtype="int64")
#print( "raw", raw_data[:20] )
batch_len = data_len // batch_size
data = raw_data[0:batch_size * batch_len].reshape((batch_size, batch_len))
#h = data.reshape( (-1))
#print( "h", h[:20])
epoch_size = (batch_len - 1) // num_steps
for i in range(epoch_size):
start = i * num_steps
#print( i * num_steps )
x = np.copy(data[:, i * num_steps:(i + 1) * num_steps])
y = np.copy(data[:, i * num_steps + 1:(i + 1) * num_steps + 1])
yield (x, y)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import random
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
import reader
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append('..')
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from args import *
import lm_model
import logging
import pickle
SEED = 123
def get_current_model_para(train_prog, train_exe):
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
vals = {}
for p_name in param_name_list:
p_array = np.array(fluid.global_scope().find_var(p_name).get_tensor())
vals[p_name] = p_array
return vals
def save_para_npz(train_prog, train_exe):
print("begin to save model to model_base")
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
vals = {}
for p_name in param_name_list:
p_array = np.array(fluid.global_scope().find_var(p_name).get_tensor())
vals[p_name] = p_array
emb = vals["embedding_para"]
print("begin to save model to model_base")
np.savez("mode_base", **vals)
def train():
args = parse_args()
model_type = args.model_type
logger = logging.getLogger("lm")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
if args.enable_ce:
fluid.default_startup_program().random_seed = SEED
if args.log_path:
file_handler = logging.FileHandler(args.log_path)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
else:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
logger.info('Running with args : {}'.format(args))
vocab_size = 10000
if model_type == "test":
num_layers = 1
batch_size = 2
hidden_size = 10
num_steps = 3
init_scale = 0.1
max_grad_norm = 5.0
epoch_start_decay = 1
max_epoch = 1
dropout = 0.0
lr_decay = 0.5
base_learning_rate = 1.0
elif model_type == "small":
num_layers = 2
batch_size = 20
hidden_size = 200
num_steps = 20
init_scale = 0.1
max_grad_norm = 5.0
epoch_start_decay = 4
max_epoch = 13
dropout = 0.0
lr_decay = 0.5
base_learning_rate = 1.0
elif model_type == "medium":
num_layers = 2
batch_size = 20
hidden_size = 650
num_steps = 35
init_scale = 0.05
max_grad_norm = 5.0
epoch_start_decay = 6
max_epoch = 39
dropout = 0.5
lr_decay = 0.8
base_learning_rate = 1.0
elif model_type == "large":
num_layers = 2
batch_size = 20
hidden_size = 1500
num_steps = 35
init_scale = 0.04
max_grad_norm = 10.0
epoch_start_decay = 14
max_epoch = 55
dropout = 0.65
lr_decay = 1.0 / 1.15
base_learning_rate = 1.0
else:
print("model type not support")
return
# Training process
loss, last_hidden, last_cell, feed_order = lm_model.lm_model(
hidden_size,
vocab_size,
batch_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale,
dropout=dropout)
# clone from default main program and use it as the validation program
main_program = fluid.default_main_program()
inference_program = fluid.default_main_program().clone(for_test=True)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm))
learning_rate = fluid.layers.create_global_var(
name="learning_rate",
shape=[1],
value=1.0,
dtype='float32',
persistable=True)
optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
optimizer.minimize(loss)
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
data_path = args.data_path
print("begin to load data")
raw_data = reader.ptb_raw_data(data_path)
print("finished load data")
train_data, valid_data, test_data, _ = raw_data
def prepare_input(batch, init_hidden, init_cell, epoch_id=0, with_lr=True):
x, y = batch
new_lr = base_learning_rate * (lr_decay**max(
epoch_id + 1 - epoch_start_decay, 0.0))
lr = np.ones((1), dtype='float32') * new_lr
res = {}
x = x.reshape((-1, num_steps, 1))
y = y.reshape((-1, 1))
res['x'] = x
res['y'] = y
res['init_hidden'] = init_hidden
res['init_cell'] = init_cell
if with_lr:
res['learning_rate'] = lr
return res
def eval(data):
# when eval the batch_size set to 1
eval_data_iter = reader.get_data_iter(data, 1, num_steps)
total_loss = 0.0
iters = 0
init_hidden = np.zeros((num_layers, 1, hidden_size), dtype='float32')
init_cell = np.zeros((num_layers, 1, hidden_size), dtype='float32')
for batch_id, batch in enumerate(eval_data_iter):
input_data_feed = prepare_input(
batch, init_hidden, init_cell, epoch_id, with_lr=False)
fetch_outs = exe.run(
inference_program,
feed=input_data_feed,
fetch_list=[loss.name, last_hidden.name, last_cell.name])
cost_train = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1])
init_cell = np.array(fetch_outs[2])
total_loss += cost_train
iters += num_steps
ppl = np.exp(total_loss / iters)
return ppl
# get train epoch size
batch_len = len(train_data) // batch_size
epoch_size = (batch_len - 1) // num_steps
log_interval = epoch_size // 10
total_time = 0.0
for epoch_id in range(max_epoch):
start_time = time.time()
print("epoch id", epoch_id)
train_data_iter = reader.get_data_iter(train_data, batch_size,
num_steps)
total_loss = 0
init_hidden = None
init_cell = None
#debug_para(fluid.framework.default_main_program(), parallel_executor)
total_loss = 0
iters = 0
init_hidden = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
for batch_id, batch in enumerate(train_data_iter):
input_data_feed = prepare_input(
batch, init_hidden, init_cell, epoch_id=epoch_id)
fetch_outs = exe.run(feed=input_data_feed,
fetch_list=[
loss.name, last_hidden.name,
last_cell.name, 'learning_rate'
])
cost_train = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1])
init_cell = np.array(fetch_outs[2])
lr = np.array(fetch_outs[3])
total_loss += cost_train
iters += num_steps
if batch_id > 0 and batch_id % log_interval == 0:
ppl = np.exp(total_loss / iters)
print("ppl ", batch_id, ppl[0], lr[0])
ppl = np.exp(total_loss / iters)
if epoch_id == 0 and ppl[0] > 1000:
# for bad init, after first epoch, the loss is over 1000
# no more need to continue
return
end_time = time.time()
total_time += end_time - start_time
print("train ppl", ppl[0])
if epoch_id == max_epoch - 1 and args.enable_ce:
print("lstm_language_model_duration\t%s" % (total_time / max_epoch))
print("lstm_language_model_loss\t%s" % ppl[0])
model_path = os.path.join("model_new/", str(epoch_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=main_program)
valid_ppl = eval(valid_data)
print("valid ppl", valid_ppl[0])
test_ppl = eval(test_data)
print("test ppl", test_ppl[0])
if __name__ == '__main__':
train()
......@@ -313,6 +313,15 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
return ave_loss, bleu_rouge
def l2_loss(train_prog):
param_list = train_prog.block(0).all_parameters()
para_sum = []
for para in param_list:
para_mul = fluid.layers.elementwise_mul(x=para, y=para, axis=0)
para_sum.append(fluid.layers.reduce_sum(input=para_mul, dim=None))
return fluid.layers.sums(para_sum) * 0.5
def train(logger, args):
logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
......@@ -351,24 +360,22 @@ def train(logger, args):
# build optimizer
if args.optim == 'sgd':
optimizer = fluid.optimizer.SGD(
learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
learning_rate=args.learning_rate)
elif args.optim == 'adam':
optimizer = fluid.optimizer.Adam(
learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
learning_rate=args.learning_rate)
elif args.optim == 'rprop':
optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
learning_rate=args.learning_rate)
else:
logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1)
optimizer.minimize(avg_cost)
if args.weight_decay > 0.0:
obj_func = avg_cost + args.weight_decay * l2_loss(main_program)
optimizer.minimize(obj_func)
else:
obj_func = avg_cost
optimizer.minimize(obj_func)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
......@@ -411,7 +418,7 @@ def train(logger, args):
feed_data = batch_reader(batch_list, args)
fetch_outs = parallel_executor.run(
feed=list(feeder.feed_parallel(feed_data, dev_count)),
fetch_list=[avg_cost.name],
fetch_list=[obj_func.name],
return_numpy=False)
cost_train = np.array(fetch_outs[0]).mean()
total_num += args.batch_size * dev_count
......
......@@ -18,5 +18,5 @@ python run.py \
--max_p_len 500 \
--max_q_len 60 \
--max_a_len 200 \
--weight_decay 0.0 \
--weight_decay 0.0001 \
--drop_rate 0.2 $@\
......@@ -93,9 +93,13 @@ python -u train.py \
python train.py --help
```
更多模型训练相关的参数则在 `config.py` 中的 `ModelHyperParams``TrainTaskConfig` 内定义;`ModelHyperParams` 定义了 embedding 维度等模型超参数,`TrainTaskConfig` 定义了 warmup 步数等训练需要的参数。这些参数默认使用了 Transformer 论文中 base model 的配置,如需调整可以在该脚本中进行修改。另外这些参数同样可在执行训练脚本的命令行中设置,传入的配置会合并并覆盖 `config.py` 中的配置,如可以通过以下命令来训练 Transformer 论文中的 big model (如显存不够可适当减小 batch size 的值):
更多模型训练相关的参数则在 `config.py` 中的 `ModelHyperParams``TrainTaskConfig` 内定义;`ModelHyperParams` 定义了 embedding 维度等模型超参数,`TrainTaskConfig` 定义了 warmup 步数等训练需要的参数。这些参数默认使用了 Transformer 论文中 base model 的配置,如需调整可以在该脚本中进行修改。另外这些参数同样可在执行训练脚本的命令行中设置,传入的配置会合并并覆盖 `config.py` 中的配置,如可以通过以下命令来训练 Transformer 论文中的 big model (如显存不够可适当减小 batch size 的值,或设置 `max_length 200` 过滤过长的句子,或修改某些显存使用相关环境变量的值):
```sh
# 显存使用的比例,显存不足可适当增大,最大为1
export FLAGS_fraction_of_gpu_memory_to_use=1.0
# 显存清理的阈值,显存不足可适当减小,最小为0,为负数时不启用
export FLAGS_eager_delete_tensor_gb=0.8
python -u train.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
......@@ -115,18 +119,17 @@ python -u train.py \
```
有关这些参数更详细信息的请参考 `config.py` 中的注释说明。
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--divice CPU` 设置),训练速度相对较慢。在训练过程中,每隔一定 iteration 后(通过参数 `save_freq` 设置,默认为10000)保存模型到参数 `model_dir` 指定的目录,每个 epoch 结束后也会保存 checkpiont 到 `ckpt_dir` 指定的目录,每个 iteration 将打印如下的日志到标准输出:
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--divice CPU` 设置),训练速度相对较慢。在训练过程中,每隔一定 iteration 后(通过参数 `save_freq` 设置,默认为10000)保存模型到参数 `model_dir` 指定的目录,每个 epoch 结束后也会保存 checkpiont 到 `ckpt_dir` 指定的目录,每隔一定数目的 iteration (通过参数 `--fetch_steps` 设置,默认为100)将打印如下的日志到标准输出:
```txt
step_idx: 0, epoch: 0, batch: 0, avg loss: 11.059394, normalized loss: 9.682427, ppl: 63538.027344
step_idx: 1, epoch: 0, batch: 1, avg loss: 11.053112, normalized loss: 9.676146, ppl: 63140.144531
step_idx: 2, epoch: 0, batch: 2, avg loss: 11.054576, normalized loss: 9.677609, ppl: 63232.640625
step_idx: 3, epoch: 0, batch: 3, avg loss: 11.046638, normalized loss: 9.669671, ppl: 62732.664062
step_idx: 4, epoch: 0, batch: 4, avg loss: 11.030095, normalized loss: 9.653129, ppl: 61703.449219
step_idx: 5, epoch: 0, batch: 5, avg loss: 11.047491, normalized loss: 9.670525, ppl: 62786.230469
step_idx: 6, epoch: 0, batch: 6, avg loss: 11.044509, normalized loss: 9.667542, ppl: 62599.273438
step_idx: 7, epoch: 0, batch: 7, avg loss: 11.011090, normalized loss: 9.634124, ppl: 60541.859375
step_idx: 8, epoch: 0, batch: 8, avg loss: 10.985243, normalized loss: 9.608276, ppl: 58997.058594
step_idx: 9, epoch: 0, batch: 9, avg loss: 10.993434, normalized loss: 9.616467, ppl: 59482.292969
[2018-10-26 00:49:24,705 INFO train.py:536] step_idx: 0, epoch: 0, batch: 0, avg loss: 10.999878, normalized loss: 9.624138, ppl: 59866.832031
[2018-10-26 00:50:08,717 INFO train.py:545] step_idx: 100, epoch: 0, batch: 100, avg loss: 9.454134, normalized loss: 8.078394, ppl: 12760.809570, speed: 2.27 step/s
[2018-10-26 00:50:52,655 INFO train.py:545] step_idx: 200, epoch: 0, batch: 200, avg loss: 8.643907, normalized loss: 7.268166, ppl: 5675.458496, speed: 2.28 step/s
[2018-10-26 00:51:36,529 INFO train.py:545] step_idx: 300, epoch: 0, batch: 300, avg loss: 7.916654, normalized loss: 6.540914, ppl: 2742.579346, speed: 2.28 step/s
[2018-10-26 00:52:20,692 INFO train.py:545] step_idx: 400, epoch: 0, batch: 400, avg loss: 7.902879, normalized loss: 6.527138, ppl: 2705.058350, speed: 2.26 step/s
[2018-10-26 00:53:04,537 INFO train.py:545] step_idx: 500, epoch: 0, batch: 500, avg loss: 7.818271, normalized loss: 6.442531, ppl: 2485.604492, speed: 2.28 step/s
[2018-10-26 00:53:48,580 INFO train.py:545] step_idx: 600, epoch: 0, batch: 600, avg loss: 7.554341, normalized loss: 6.178601, ppl: 1909.012451, speed: 2.27 step/s
[2018-10-26 00:54:32,878 INFO train.py:545] step_idx: 700, epoch: 0, batch: 700, avg loss: 7.177765, normalized loss: 5.802025, ppl: 1309.977661, speed: 2.26 step/s
[2018-10-26 00:55:17,108 INFO train.py:545] step_idx: 800, epoch: 0, batch: 800, avg loss: 7.005494, normalized loss: 5.629754, ppl: 1102.674805, speed: 2.26 step/s
```
### 模型预测
......@@ -138,10 +141,9 @@ python -u infer.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--use_wordpiece False \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_199999.infer.model \
model_path trained_models/iter_100000.infer.model \
beam_size 4 \
max_out_len 255
```
......@@ -164,7 +166,7 @@ BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| BLEU | 26.05 | 28.75 | 33.27 |
| BLEU | 26.25 | 29.15 | 33.64 |
### 分布式训练
......
#!/bin/bash
set -x
unset http_proxy
unset https_proxy
#pserver
export TRAINING_ROLE=PSERVER
export PADDLE_PORT=30134
export PADDLE_PSERVERS=127.0.0.1
export PADDLE_IS_LOCAL=0
export PADDLE_INIT_TRAINER_COUNT=1
export POD_IP=127.0.0.1
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/:/usr/local/lib/:/workspace/brpc
export PYTHONPATH=$PYTHONPATH:/paddle/build/build_reader_RelWithDebInfo_gpu/python
#GLOG_v=7 GLOG_logtostderr=1
CUDA_VISIBLE_DEVICES=4,5,6,7 python -u train.py \
--src_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--trg_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--special_token '<s>' '<e>' '<unk>' \
--token_delimiter '\x01' \
--train_file_pattern 'cluster_test_data_en_fr/train/train.wordpiece.en-fr.0' \
--val_file_pattern 'cluster_test_data_en_fr/thirdparty/newstest2014.wordpiece.en-fr' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
--local False > pserver.log 2>&1 &
pserver_pid=$(echo $!)
echo $pserver_pid
sleep 30s
#trainer
export TRAINING_ROLE=TRAINER
export PADDLE_PORT=30134
export PADDLE_PSERVERS=127.0.0.1
export PADDLE_IS_LOCAL=0
export PADDLE_INIT_TRAINER_COUNT=1
export POD_IP=127.0.0.1
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
CUDA_VISIBLE_DEVICES=4,5,6,7 python -u train.py \
--src_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--trg_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--special_token '<s>' '<e>' '<unk>' \
--token_delimiter '\x01' \
--train_file_pattern 'cluster_test_data_en_fr/train/train.wordpiece.en-fr.0' \
--val_file_pattern 'cluster_test_data_en_fr/thirdparty/newstest2014.wordpiece.en-fr' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
--local False > trainer.log 2>&1 &
#sleep 80
#kill -9 $pserver_pid
......@@ -80,7 +80,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head])
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -99,7 +99,9 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
......@@ -122,8 +124,15 @@ def multi_head_attention(queries,
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
k = cache["k"] = layers.concat([cache["k"], k], axis=1)
v = cache["v"] = layers.concat([cache["v"], v], axis=1)
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
......@@ -523,8 +532,7 @@ def transformer(src_vocab_size,
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=layers.reshape(
predict, shape=[-1, trg_vocab_size]),
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
......@@ -637,6 +645,9 @@ def wrap_decoder(trg_vocab_size,
preprocess_cmd,
postprocess_cmd,
caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
if weight_sharing:
predict = layers.matmul(
x=dec_output,
......@@ -751,7 +762,6 @@ def fast_decode(
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=pre_enc_output,
caches=pre_caches)
logits = layers.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
......
import argparse
import ast
import contextlib
import multiprocessing
import os
import six
......@@ -79,8 +80,7 @@ def parse_args():
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
"For EN-DE BPE data we provided, use spaces as token delimiter.")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
......@@ -98,9 +98,14 @@ def parse_args():
help="The iteration number to run in profiling.")
parser.add_argument(
"--use_parallel_exe",
type=bool,
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
"--profile_ops",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to profile operators.")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -125,6 +130,8 @@ def parse_args():
def main(args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
train_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
sum_cost, avg_cost, predict, token_num, pyreader = transformer(
......@@ -243,24 +250,33 @@ def main(args):
if args.use_py_reader:
pyreader.reset()
pyreader.start()
break
return reader_time, run_time
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', '/tmp/profile_file'):
yield
else:
yield
# start-up
init_flag = True
run(1)
run(5)
init_flag = False
# profiling
start = time.time()
# currently only support profiling on one device
with profiler.profiler('All', 'total', '/tmp/profile_file'):
with profile_context(args.profile_ops):
reader_time, run_time = run(args.iter_num)
end = time.time()
total_time = end - start
print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
total_time, np.sum(reader_time), np.sum(run_time)))
print(
"Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
format(total_time, np.sum(reader_time), np.sum(run_time),
args.iter_num))
if __name__ == "__main__":
......
......@@ -297,9 +297,14 @@ class DataReader(object):
infos = self._sample_infos
if self._sort_type == SortType.POOL:
reverse = True
for i in range(0, len(infos), self._pool_size):
# to avoid placing short next to long sentences
reverse = not reverse
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size], key=lambda x: x.max_len)
infos[i:i + self._pool_size],
key=lambda x: x.max_len,
reverse=reverse)
# concat batch
batches = []
......
import argparse
import ast
import copy
import logging
import multiprocessing
import os
import six
import sys
import time
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.transpiler.details import program_to_code
import reader
from config import *
......@@ -97,6 +101,11 @@ def parse_args():
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--update_method',
choices=("pserver", "nccl2"),
default="pserver",
help='Update method.')
parser.add_argument(
'--sync', type=ast.literal_eval, default=True, help="sync mode.")
parser.add_argument(
......@@ -115,6 +124,11 @@ def parse_args():
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--fetch_steps",
type=int,
default=100,
help="The frequency to fetch and print output.")
args = parser.parse_args()
# Append args related to dict
......@@ -131,6 +145,26 @@ def parse_args():
return args
def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint):
assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
current_endpoint in worker_endpoints)
eps = copy.deepcopy(worker_endpoints)
eps.remove(current_endpoint)
nccl_id_var = startup_prog.global_block().create_var(
name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
startup_prog.global_block().append_op(
type="gen_nccl_id",
inputs={},
outputs={"NCCLID": nccl_id_var},
attrs={
"endpoint": current_endpoint,
"endpoint_list": eps,
"trainer_id": trainer_id
})
return nccl_id_var
def pad_batch_data(insts,
pad_idx,
n_head,
......@@ -370,7 +404,7 @@ def test_context(exe, train_exe, dev_count):
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=True)
test_prog = test_prog.clone(for_test=True)
test_data = prepare_data_generator(
args, is_test=True, count=dev_count, pyreader=pyreader)
......@@ -410,15 +444,25 @@ def test_context(exe, train_exe, dev_count):
return test
def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
token_num, predict, pyreader):
def train_loop(exe,
train_prog,
startup_prog,
dev_count,
sum_cost,
avg_cost,
token_num,
predict,
pyreader,
nccl2_num_trainers=1,
nccl2_trainer_id=0):
# Initialize the parameters.
if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
else:
print("init fluid.framework.default_startup_program")
logging.info("init fluid.framework.default_startup_program")
exe.run(startup_prog)
logging.info("begin reader")
train_data = prepare_data_generator(
args, is_test=False, count=dev_count, pyreader=pyreader)
......@@ -431,12 +475,16 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
# use token average cost among multi-devices. and the gradient scale is
# `1 / token_number` for average cost.
# build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
logging.info("begin executor")
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
loss_name=avg_cost.name,
main_program=train_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
exec_strategy=exec_strategy,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
if args.val_file_pattern is not None:
test = test_context(exe, train_exe, dev_count)
......@@ -450,6 +498,8 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
step_idx = 0
init_flag = True
logging.info("begin train")
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
......@@ -464,25 +514,38 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
try:
feed_dict_list = prepare_feed_dict_list(data_generator,
init_flag, dev_count)
outs = train_exe.run(
fetch_list=[sum_cost.name, token_num.name],
fetch_list=[sum_cost.name, token_num.name]
if step_idx % args.fetch_steps == 0 else [],
feed=feed_dict_list)
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[
1])
# sum the cost from multi-devices
total_sum_cost = sum_cost_val.sum()
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
print("step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
if step_idx % int(TrainTaskConfig.
save_freq) == TrainTaskConfig.save_freq - 1:
if step_idx % args.fetch_steps == 0:
sum_cost_val, token_num_val = np.array(outs[0]), np.array(
outs[1])
# sum the cost from multi-devices
total_sum_cost = sum_cost_val.sum()
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.fetch_steps / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
......@@ -492,6 +555,7 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
os.path.join(TrainTaskConfig.model_dir,
"iter_" + str(step_idx) + ".infer.model"),
train_prog)
init_flag = False
batch_id += 1
step_idx += 1
......@@ -505,13 +569,13 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
# Validate and save the persistable.
if args.val_file_pattern is not None:
val_avg_cost, val_ppl = test()
print(
logging.info(
"epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
" consumed %fs" % (pass_id, val_avg_cost,
val_avg_cost - loss_normalizer, val_ppl,
time_consumed))
else:
print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
if not args.enable_ce:
fluid.io.save_persistables(
exe,
......@@ -531,7 +595,7 @@ def train(args):
is_local = os.getenv("PADDLE_IS_LOCAL", "1")
if is_local == '0':
args.local = False
print(args)
logging.info(args)
if args.device == 'CPU':
TrainTaskConfig.use_gpu = False
......@@ -576,15 +640,21 @@ def train(args):
use_py_reader=args.use_py_reader,
is_test=False)
if args.local:
optimizer = None
if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
logging.info("before adam")
with fluid.default_main_program()._lr_schedule_guard():
learning_rate = lr_decay * TrainTaskConfig.learning_rate
optimizer = fluid.optimizer.Adam(
learning_rate=lr_decay * TrainTaskConfig.learning_rate,
learning_rate=learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
elif args.sync == False:
else:
optimizer = fluid.optimizer.SGD(0.003)
optimizer.minimize(avg_cost)
......@@ -592,10 +662,32 @@ def train(args):
fluid.memory_optimize(train_prog)
if args.local:
print("local start_up:")
logging.info("local start_up:")
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
token_num, predict, pyreader)
else:
if args.update_method == "nccl2":
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
port = os.getenv("PADDLE_PORT")
worker_ips = os.getenv("PADDLE_TRAINERS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
trainers_num = len(worker_endpoints)
current_endpoint = os.getenv("POD_IP") + ":" + port
if trainer_id == 0:
logging.info("train_id == 0, sleep 60s")
time.sleep(60)
logging.info("trainers_num:{}".format(trainers_num))
logging.info("worker_endpoints:{}".format(worker_endpoints))
logging.info("current_endpoint:{}".format(current_endpoint))
append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
current_endpoint)
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader, trainers_num,
trainer_id)
return
port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVERS") # ip,ip...
eplist = []
......@@ -605,6 +697,13 @@ def train(args):
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
logging.info("pserver_endpoints:{}".format(pserver_endpoints))
logging.info("current_endpoint:{}".format(current_endpoint))
logging.info("trainer_id:{}".format(trainer_id))
logging.info("pserver_ips:{}".format(pserver_ips))
logging.info("port:{}".format(port))
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
......@@ -614,32 +713,34 @@ def train(args):
startup_program=startup_prog)
if training_role == "PSERVER":
logging.info("distributed: pserver started")
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
print("need env SERVER_ENDPOINT")
logging.critical("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
print("psserver begin run")
with open('pserver_startup.desc', 'w') as f:
f.write(str(pserver_startup))
with open('pserver_prog.desc', 'w') as f:
f.write(str(pserver_prog))
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
logging.info("distributed: trainer started")
trainer_prog = t.get_trainer_program()
with open('trainer_prog.desc', 'w') as f:
f.write(str(trainer_prog))
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader)
else:
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
logging.critical(
"environment var TRAINER_ROLE should be TRAINER os PSERVER")
exit(1)
if __name__ == "__main__":
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(
stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
args = parse_args()
train(args)
......@@ -17,12 +17,12 @@
## 简介,模型详解
在PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md)中对于命名实体识别任务有较详细的介绍,在本例中不再重复介绍。
在PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/legacy/sequence_tagging_for_ner/README.md)中对于命名实体识别任务有较详细的介绍,在本例中不再重复介绍。
在模型上,我们沿用了v2版本的模型结构,唯一区别是我们使用LSTM代替原始的RNN。
## 数据获取
完整数据的获取请参考PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) 一节中的方式。本例的示例数据同样可以通过运行data/download.sh来获取。
完整数据的获取请参考PaddlePaddle v2版本[命名实体识别](https://github.com/PaddlePaddle/models/blob/develop/legacy/sequence_tagging_for_ner/README.md) 一节中的方式。本例的示例数据同样可以通过运行data/download.sh来获取。
## 训练
......
......@@ -87,8 +87,8 @@ def evaluate(epoch_id, exe, inference_program, dev_reader, test_reader, fetch_li
def train_and_evaluate(train_reader,
test_reader,
dev_reader,
test_reader,
network,
optimizer,
global_config,
......@@ -246,7 +246,10 @@ def main():
# use cuda or not
if not global_config.has_member('use_cuda'):
global_config.use_cuda = 'CUDA_VISIBLE_DEVICES' in os.environ
if 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES'] != '':
global_config.use_cuda = True
else:
global_config.use_cuda = False
global_config.list_config()
......
# TagSpace
以下是本例的简要目录结构及说明:
```text
.
├── README.md # 文档
├── train.py # 训练脚本
├── utils # 通用函数
├── small_train.txt # 小样本训练集
└── small_test.txt # 小样本测试集
```
## 简介
TagSpace模型的介绍可以参阅论文[#TagSpace: Semantic Embeddings from Hashtags](https://research.fb.com/publications/tagspace-semantic-embeddings-from-hashtags/),在本例中,我们实现了TagSpace的模型。
## 数据下载
[ag news dataset](https://github.com/mhjabreel/CharCNN/tree/master/data/ag_news_csv)
数据格式如下
```
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
```
## 训练
'--use_cuda 1' 表示使用gpu, 缺省表示使用cpu
GPU 环境
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file --use_cuda 1` 开始训练模型。
```
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1
```
CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。
```
python train.py small_train.txt small_test.txt
```
## 未来工作
添加预测部分
添加多种负例采样方式
因为 它太大了无法显示 source diff 。你可以改为 查看blob
因为 它太大了无法显示 source diff 。你可以改为 查看blob
import os
import sys
import time
import six
import numpy as np
import math
import argparse
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io
import time
import utils
SEED = 102
def parse_args():
parser = argparse.ArgumentParser("TagSpace benchmark.")
parser.add_argument('train_file')
parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
args = parser.parse_args()
return args
def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=5, margin=0.1):
""" network definition """
text = io.data(name="text", shape=[1], lod_level=1, dtype='int64')
pos_tag = io.data(name="pos_tag", shape=[1], lod_level=1, dtype='int64')
neg_tag = io.data(name="neg_tag", shape=[1], lod_level=1, dtype='int64')
text_emb = nn.embedding(
input=text, size=[vocab_text_size, emb_dim], param_attr="text_emb")
pos_tag_emb = nn.embedding(
input=pos_tag, size=[vocab_tag_size, emb_dim], param_attr="tag_emb")
neg_tag_emb = nn.embedding(
input=neg_tag, size=[vocab_tag_size, emb_dim], param_attr="tag_emb")
conv_1d = fluid.nets.sequence_conv_pool(
input=text_emb,
num_filters=hid_dim,
filter_size=win_size,
act="tanh",
pool_type="max",
param_attr="cnn")
text_hid = fluid.layers.fc(input=conv_1d, size=emb_dim, param_attr="text_hid")
cos_pos = nn.cos_sim(pos_tag_emb, text_hid)
cos_neg = nn.cos_sim(neg_tag_emb, text_hid)
loss_part1 = nn.elementwise_sub(
tensor.fill_constant_batch_size_like(
input=cos_pos,
shape=[-1, 1],
value=margin,
dtype='float32'),
cos_pos)
loss_part2 = nn.elementwise_add(loss_part1, cos_neg)
loss_part3 = nn.elementwise_max(
tensor.fill_constant_batch_size_like(
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_part2)
avg_cost = nn.mean(loss_part3)
less = tensor.cast(cf.less_than(cos_neg, cos_pos), dtype='float32')
correct = nn.reduce_sum(less)
return text, pos_tag, neg_tag, avg_cost, correct, cos_pos
def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size,
pass_num, use_cuda, model_dir):
""" train network """
args = parse_args()
vocab_text_size = len(vocab_text)
vocab_tag_size = len(vocab_tag)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# Train program
text, pos_tag, neg_tag, avg_cost, correct, pos_cos = network(vocab_text_size, vocab_tag_size)
# Optimization to minimize lost
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=base_lr)
sgd_optimizer.minimize(avg_cost)
# Initialize executor
startup_program = fluid.default_startup_program()
loop_program = fluid.default_main_program()
exe = fluid.Executor(place)
exe.run(startup_program)
total_time = 0.0
for pass_idx in range(pass_num):
epoch_idx = pass_idx + 1
print("epoch_%d start" % epoch_idx)
t0 = time.time()
for batch_id, data in enumerate(train_reader()):
lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
lod_pos_tag = utils.to_lodtensor([dat[1] for dat in data], place)
lod_neg_tag = utils.to_lodtensor([dat[2] for dat in data], place)
loss_val, correct_val = exe.run(
loop_program,
feed={
"text": lod_text_seq,
"pos_tag": lod_pos_tag,
"neg_tag": lod_neg_tag},
fetch_list=[avg_cost, correct])
if batch_id % 10 == 0:
print("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}"
.format(pass_idx, batch_id, loss_val,
float(correct_val) / batch_size))
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, batch_id, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["text", "pos_tag"]
fetch_vars = [pos_cos]
fluid.io.save_inference_model(save_dir ,feed_var_names, fetch_vars, exe)
print("finish training")
def train_net():
""" do training """
args = parse_args()
train_file = args.train_file
test_file = args.test_file
use_cuda = True if args.use_cuda else False
batch_size = 100
vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data(
train_file, test_file, batch_size=batch_size, buffer_size=batch_size*100, word_freq_threshold=0)
train(
train_reader=train_reader,
vocab_text=vocab_text,
vocab_tag=vocab_tag,
base_lr=0.01,
batch_size=batch_size,
pass_num=10,
use_cuda=use_cuda,
model_dir="model_dim10_2")
if __name__ == "__main__":
train_net()
import re
import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
import csv
def to_lodtensor(data, place):
""" convert to LODtensor """
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def prepare_data(train_filename,
test_filename,
batch_size,
neg_size=1,
buffer_size=1000,
word_freq_threshold=0,
enable_ce=False):
""" prepare the AG's News Topic Classification data """
print("start constuct word dict")
vocab_text = build_dict(2, word_freq_threshold, train_filename, test_filename)
vocab_tag = build_dict(0, word_freq_threshold, train_filename, test_filename)
print("construct word dict done\n")
train_reader = sort_batch(
paddle.reader.shuffle(
train(
train_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ),
buf_size=buffer_size),
batch_size, batch_size * 20)
test_reader = sort_batch(
test(
test_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ),
batch_size, batch_size * 20)
return vocab_text, vocab_tag, train_reader, test_reader
def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
"""
Create a batched reader.
:param reader: the data reader to read from.
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param sort_group_size: size of partial sorted batch
:type sort_group_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
def batch_reader():
r = reader()
b = []
for instance in r:
b.append(instance)
if len(b) == sort_group_size:
sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
b = []
c = []
for sort_i in sortl:
c.append(sort_i)
if (len(c) == batch_size):
yield c
c = []
if drop_last == False and len(b) != 0:
sortl = sorted(b, key=lambda x: len(x[0]), reverse=True)
c = []
for sort_i in sortl:
c.append(sort_i)
if (len(c) == batch_size):
yield c
c = []
# Batch size check
batch_size = int(batch_size)
if batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
return batch_reader
class DataType(object):
SEQ = 2
def word_count(column_num, input_file, word_freq=None):
"""
compute word count from corpus
"""
if word_freq is None:
word_freq = collections.defaultdict(int)
data_file = csv.reader(input_file)
for row in data_file:
for w in re.split(r'\W+',row[column_num].strip()):
word_freq[w]+= 1
return word_freq
def build_dict(column_num=2, min_word_freq=50, train_filename="", test_filename=""):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
"""
with open(train_filename) as trainf:
with open(test_filename) as testf:
word_freq = word_count(column_num, testf, word_count(column_num, trainf))
word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
return word_idx
def reader_creator(filename, text_idx, tag_idx, n, data_type):
def reader():
with open(filename) as input_file:
data_file = csv.reader(input_file)
for row in data_file:
text_raw = re.split(r'\W+', row[2].strip())
text = [text_idx.get(w) for w in text_raw]
tag_raw = re.split(r'\W+', row[0].strip())
pos_index = tag_idx.get(tag_raw[0])
pos_tag=[]
pos_tag.append(pos_index)
neg_tag=[]
max_iter = 100
now_iter = 0
sum_n = 0
while(sum_n < 1) :
now_iter += 1
if now_iter > max_iter:
print("error : only one class")
sys.exit(0)
rand_i = np.random.randint(0, len(tag_idx))
if rand_i != pos_index:
neg_index=rand_i
neg_tag.append(neg_index)
sum_n += 1
if n > 0 and len(text) > n: continue
yield text, pos_tag, neg_tag
return reader
def train(filename, text_idx, tag_idx, n, data_type=DataType.SEQ):
return reader_creator(filename, text_idx, tag_idx, n, data_type)
def test(filename, text_idx, tag_idx, n, data_type=DataType.SEQ):
return reader_creator(filename, text_idx, tag_idx, n, data_type)
# 基于DNN模型的点击率预估模型
## 介绍
本模型实现了下述论文中提出的DNN模型:
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
## 运行环境
需要先安装PaddlePaddle Fluid,然后运行:
```shell
pip install -r requirements.txt
```
## 数据集
本文使用的是Kaggle公司举办的[展示广告竞赛](https://www.kaggle.com/c/criteo-display-ad-challenge/)中所使用的Criteo数据集。
每一行是一次广告展示的特征,第一列是一个标签,表示这次广告展示是否被点击。总共有39个特征,其中13个特征采用整型值,另外26个特征是类别类特征。测试集中是没有标签的。
下载数据集:
```bash
cd data && ./download.sh && cd ..
```
## 模型
本例子只实现了DeepFM论文中介绍的模型的DNN部分,DeepFM会在其他例子中给出。
## 数据准备
处理原始数据集,整型特征使用min-max归一化方法规范到[0, 1],类别类特征使用了one-hot编码。原始数据集分割成两部分:90%用于训练,其他10%用于训练过程中的验证。
## 训练
训练的命令行选项可以通过`python train.py -h`列出。
### 单机训练:
```bash
python train.py \
--train_data_path data/raw/train.txt \
2>&1 | tee train.log
```
训练到第1轮的第40000个batch后,测试的AUC为0.801178,误差(cost)为0.445196。
### 分布式训练
本地启动一个2 trainer 2 pserver的分布式训练任务,分布式场景下训练数据会按照trainer的id进行切分,保证trainer之间的训练数据不会重叠,提高训练效率
```bash
sh cluster_train.sh
```
## 预测
预测的命令行选项可以通过`python infer.py -h`列出。
对测试集进行预测:
```bash
python infer.py \
--model_path models/pass-0/ \
--data_path data/raw/valid.txt
```
注意:infer.py跑完最后输出的AUC才是整个预测文件的整体AUC。
## 在百度云上运行集群训练
1. 参考文档 [在百度云上启动Fluid分布式训练](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst) 在百度云上部署一个CPU集群。
1. 用preprocess.py处理训练数据生成train.txt。
1. 将train.txt切分成集群机器份,放到每台机器上。
1. 用上面的 `分布式训练` 中的命令行启动分布式训练任务.
# DNN for Click-Through Rate prediction
## Introduction
This model implements the DNN part proposed in the following paper:
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
The DeepFm combines factorization machine and deep neural networks to model
both low order and high order feature interactions. For details of the
factorization machines, please refer to the paper [factorization
machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
## Environment
You should install PaddlePaddle Fluid first, and run:
```shell
pip install -r requirements.txt
```
## Dataset
This example uses Criteo dataset which was used for the [Display Advertising
Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)
hosted by Kaggle.
Each row is the features for an ad display and the first column is a label
indicating whether this ad has been clicked or not. There are 39 features in
total. 13 features take integer values and the other 26 features are
categorical features. For the test dataset, the labels are omitted.
Download dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
This Demo only implement the DNN part of the model described in DeepFM paper.
DeepFM model will be provided in other model.
## Data Preprocessing method
To preprocess the raw dataset, the integer features are clipped then min-max
normalized to [0, 1] and the categorical features are one-hot encoded. The raw
training dataset are splited such that 90% are used for training and the other
10% are used for validation during training. In reader.py, training data is the first
90% of data in train.txt, and validation data is the left.
## Train
The command line options for training can be listed by `python train.py -h`.
### Local Train:
```bash
python train.py \
--train_data_path data/raw/train.txt \
2>&1 | tee train.log
```
After training pass 1 batch 40000, the testing AUC is `0.801178` and the testing
cost is `0.445196`.
### Distributed Train
Run a 2 pserver 2 trainer distribute training on a single machine.
In distributed training setting, training data is splited by trainer_id, so that training data
do not overlap among trainers
```bash
sh cluster_train.sh
```
## Infer
The command line options for infering can be listed by `python infer.py -h`.
To make inference for the test dataset:
```bash
python infer.py \
--model_path models/ \
--data_path data/raw/train.txt
```
Note: The AUC value in the last log info is the total AUC for all test dataset. Here, train.txt is splited inside the reader.py so that validation data does not have overlap with training data.
## Train on Baidu Cloud
1. Please prepare some CPU machines on Baidu Cloud following the steps in [train_on_baidu_cloud](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst)
1. Prepare dataset using preprocess.py.
1. Split the train.txt to trainer_num parts and put them on the machines.
1. Run training with the cluster train using the command in `Distributed Train` above.
\ No newline at end of file
#!/bin/bash
# start pserver0
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
\ No newline at end of file
#!/bin/bash
wget --no-check-certificate https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
tar zxf dac.tar.gz
rm -f dac.tar.gz
mkdir raw
mv ./*.txt raw/
import argparse
import logging
import numpy as np
# disable gpu training for this example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
import reader
from network_conf import ctr_dnn_model
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_path',
type=str,
required=True,
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=True,
help="The path of the dataset to infer")
parser.add_argument(
'--embedding_size',
type=int,
default=10,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help="The size for embedding layer (default:1000001)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
def infer():
args = parse_args()
place = fluid.CPUPlace()
inference_scope = fluid.core.Scope()
dataset = reader.CriteoDataset(args.sparse_feature_dim)
test_reader = paddle.batch(dataset.test([args.data_path]), batch_size=args.batch_size)
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
with fluid.framework.program_guard(test_program, startup_program):
loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(args.embedding_size, args.sparse_feature_dim)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
with fluid.scope_guard(inference_scope):
[inference_program, _, fetch_targets] = fluid.io.load_inference_model(args.model_path, exe)
def set_zero(var_name):
param = inference_scope.var(var_name).get_tensor()
param_array = np.zeros(param._get_dims()).astype("int64")
param.set(param_array, place)
auc_states_names = ['_generated_var_2', '_generated_var_3']
for name in auc_states_names:
set_zero(name)
for batch_id, data in enumerate(test_reader()):
loss_val, auc_val = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_targets)
if batch_id % 100 == 0:
logger.info("TEST --> batch: {} loss: {} auc: {}".format(batch_id, loss_val/args.batch_size, auc_val))
if __name__ == '__main__':
infer()
import paddle.fluid as fluid
import math
dense_feature_dim = 13
def ctr_dnn_model(embedding_size, sparse_feature_dim):
dense_input = fluid.layers.data(
name="dense_input", shape=[dense_feature_dim], dtype='float32')
sparse_input_ids = [
fluid.layers.data(
name="C" + str(i), shape=[1], lod_level=1, dtype='int64')
for i in range(1, 27)
]
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors", initializer=fluid.initializer.Normal(scale=1/math.sqrt(sparse_feature_dim))))
sparse_embed_seq = map(embedding_layer, sparse_input_ids)
concated = fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)
fc1 = fluid.layers.fc(input=concated, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(scale=1/math.sqrt(concated.shape[1]))))
fc2 = fluid.layers.fc(input=fc1, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(scale=1/math.sqrt(fc1.shape[1]))))
fc3 = fluid.layers.fc(input=fc2, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(scale=1/math.sqrt(fc2.shape[1]))))
predict = fluid.layers.fc(input=fc3, size=2, act='softmax',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(scale=1/math.sqrt(fc3.shape[1]))))
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
data_list = [dense_input] + sparse_input_ids + [label]
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, label=label, num_thresholds=2**12, slide_steps=20)
return avg_cost, data_list, auc_var, batch_auc_var
"""
Preprocess Criteo dataset. This dataset was used for the Display Advertising
Challenge (https://www.kaggle.com/c/criteo-display-ad-challenge).
"""
import os
import sys
import click
import random
import collections
# There are 13 integer features and 26 categorical features
continous_features = range(1, 14)
categorial_features = range(14, 40)
# Clip integer features. The clip point for each integer feature
# is derived from the 95% quantile of the total values in each feature
continous_clip = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
class CategoryDictGenerator:
"""
Generate dictionary for each of the categorical features
"""
def __init__(self, num_feature):
self.dicts = []
self.num_feature = num_feature
for i in range(0, num_feature):
self.dicts.append(collections.defaultdict(int))
def build(self, datafile, categorial_features, cutoff=0):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
if features[categorial_features[i]] != '':
self.dicts[i][features[categorial_features[i]]] += 1
for i in range(0, self.num_feature):
self.dicts[i] = filter(lambda x: x[1] >= cutoff,
self.dicts[i].items())
self.dicts[i] = sorted(self.dicts[i], key=lambda x: (-x[1], x[0]))
vocabs, _ = list(zip(*self.dicts[i]))
self.dicts[i] = dict(zip(vocabs, range(1, len(vocabs) + 1)))
self.dicts[i]['<unk>'] = 0
def gen(self, idx, key):
if key not in self.dicts[idx]:
res = self.dicts[idx]['<unk>']
else:
res = self.dicts[idx][key]
return res
def dicts_sizes(self):
return map(len, self.dicts)
class ContinuousFeatureGenerator:
"""
Normalize the integer features to [0, 1] by min-max normalization
"""
def __init__(self, num_feature):
self.num_feature = num_feature
self.min = [sys.maxint] * num_feature
self.max = [-sys.maxint] * num_feature
def build(self, datafile, continous_features):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
val = features[continous_features[i]]
if val != '':
val = int(val)
if val > continous_clip[i]:
val = continous_clip[i]
self.min[i] = min(self.min[i], val)
self.max[i] = max(self.max[i], val)
def gen(self, idx, val):
if val == '':
return 0.0
val = float(val)
return (val - self.min[idx]) / (self.max[idx] - self.min[idx])
@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw criteo dataset")
@click.option("--outdir", type=str, help="Path to save the processed data")
def preprocess(datadir, outdir):
"""
All 13 integer features are normalized to continuous values and these continuous
features are combined into one vector with dimension of 13.
Each of the 26 categorical features are one-hot encoded and all the one-hot
vectors are combined into one sparse binary vector.
"""
dists = ContinuousFeatureGenerator(len(continous_features))
dists.build(os.path.join(datadir, 'train.txt'), continous_features)
dicts = CategoryDictGenerator(len(categorial_features))
dicts.build(
os.path.join(datadir, 'train.txt'), categorial_features, cutoff=200)
dict_sizes = dicts.dicts_sizes()
categorial_feature_offset = [0]
for i in range(1, len(categorial_features)):
offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1]
categorial_feature_offset.append(offset)
random.seed(0)
# 90% of the data are used for training, and 10% of the data are used
# for validation.
with open(os.path.join(outdir, 'train.txt'), 'w') as out_train:
with open(os.path.join(outdir, 'valid.txt'), 'w') as out_valid:
with open(os.path.join(datadir, 'train.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i]])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i]]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
label = features[0]
if random.randint(0, 9999) % 10 != 0:
out_train.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
else:
out_valid.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
with open(os.path.join(outdir, 'test.txt'), 'w') as out:
with open(os.path.join(datadir, 'test.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i] - 1])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i] - 1]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
out.write('\t'.join([continous_vals, categorial_vals]) + '\n')
if __name__ == "__main__":
preprocess()
class Dataset:
def __init__(self):
pass
class CriteoDataset(Dataset):
def __init__(self, sparse_feature_dim):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.hash_dim_ = sparse_feature_dim
# here, training data are lines with line_index < train_idx_
self.train_idx_ = 41256555
self.continuous_range_ = range(1, 14)
self.categorical_range_ = range(14, 40)
def _reader_creator(self, file_list, is_train, trainer_num, trainer_id):
def reader():
for file in file_list:
with open(file, 'r') as f:
line_idx = 0
for line in f:
line_idx += 1
if is_train and line_idx > self.train_idx_:
continue
elif not is_train and line_idx <= self.train_idx_:
continue
if trainer_id > 0 and line_idx % trainer_num != trainer_id:
continue
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in self.continuous_range_:
if features[idx] == '':
dense_feature.append(0.0)
else:
dense_feature.append((float(features[idx]) - self.cont_min_[idx - 1]) / self.cont_diff_[idx - 1])
for idx in self.categorical_range_:
sparse_feature.append([hash("%d_%s" % (idx, features[idx])) % self.hash_dim_])
label = [int(features[0])]
yield [dense_feature] + sparse_feature + [label]
return reader
def train(self, file_list, trainer_num, trainer_id):
return self._reader_creator(file_list, True, trainer_num, trainer_id)
def test(self, file_list):
return self._reader_creator(file_list, False, -1)
def infer(self, file_list):
return self._reader_creator(file_list, False, -1)
from __future__ import print_function
import argparse
import logging
import os
# disable gpu training for this example
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
import reader
from network_conf import ctr_dnn_model
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/raw/train.txt',
help="The path of training dataset")
parser.add_argument(
'--test_data_path',
type=str,
default='./data/raw/valid.txt',
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
parser.add_argument(
'--embedding_size',
type=int,
default=10,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--is_local',
type=int,
default=1,
help='Local train or distributed train (default: 1)')
# the following arguments is used for distributed train, if is_local == false, then you should set them
parser.add_argument(
'--role',
type=str,
default='pserver', # trainer or pserver
help='The path for model to store (default: models)')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The path for model to store (default: 127.0.0.1:6000)')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='The path for model to store (default: models)')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
return parser.parse_args()
def train_loop(args, train_program, data_list, loss, auc_var, batch_auc_var,
trainer_num, trainer_id):
dataset = reader.CriteoDataset(args.sparse_feature_dim)
train_reader = paddle.batch(
paddle.reader.shuffle(
dataset.train([args.train_data_path], trainer_num, trainer_id),
buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
data_name_list = [var.name for var in data_list]
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for pass_id in range(args.num_passes):
for batch_id, data in enumerate(train_reader()):
loss_val, auc_val, batch_auc_val = exe.run(
train_program,
feed=feeder.feed(data),
fetch_list=[loss, auc_var, batch_auc_var]
)
logger.info("TRAIN --> pass: {} batch: {} loss: {} auc: {}, batch_auc: {}"
.format(pass_id, batch_id, loss_val/args.batch_size, auc_val, batch_auc_val))
if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(batch_id)
if args.trainer_id == 0:
fluid.io.save_inference_model(model_dir, data_name_list, [loss, auc_var], exe)
model_dir = args.model_output_dir + '/pass-' + str(pass_id)
if args.trainer_id == 0:
fluid.io.save_inference_model(model_dir, data_name_list, [loss, auc_var], exe)
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(args.embedding_size, args.sparse_feature_dim)
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer.minimize(loss)
if args.is_local:
logger.info("run local training")
main_program = fluid.default_main_program()
train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var, 1, -1)
else:
logger.info("run dist training")
t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
logger.info("run pserver")
prog = t.get_pserver_program(args.current_endpoint)
startup = t.get_startup_program(args.current_endpoint, pserver_program=prog)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif args.role == "trainer":
logger.info("run trainer")
train_prog = t.get_trainer_program()
train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var,
args.trainers, args.trainer_id + 1)
if __name__ == '__main__':
train()
......@@ -17,7 +17,13 @@
## 简介
GRU4REC模型的介绍可以参阅论文[Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939),在本例中,我们实现了GRU4REC的模型。
GRU4REC模型的介绍可以参阅论文[Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939)
论文的贡献在于首次将RNN(GRU)运用于session-based推荐,相比传统的KNN和矩阵分解,效果有明显的提升。
论文的核心思想史在一个session中,用户点击一系列item的行为看做一个序列,用来训练RNN模型。预测阶段,给定已知的点击序列作为输入,预测下一个可能点击的item。
session-based推荐应用场景非常广泛,比如用户的商品浏览、新闻点击、地点签到等序列数据。
## RSC15 数据下载及预处理
运行命令 下载RSC15官网数据集
......@@ -74,14 +80,16 @@ python convert_format.py
```
## 训练
GPU 环境 默认配置
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` 开始训练模型。
```python
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.file
'--use_cuda 1' 表示使用gpu, 缺省表示使用cpu '--parallel 1' 表示使用多卡,缺省表示使用单卡
GPU 环境
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file --use_cuda 1` 开始训练模型。
```
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1
```
CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。
```python
```
python train.py small_train.txt small_test.txt
```
......@@ -100,8 +108,8 @@ python train.py small_train.txt small_test.txt
base_lr=0.01, # base learning rate
batch_size=batch_size,
pass_num=10, # the number of passed for training
use_cuda=True, # whether to use GPU card
parallel=False, # whether to be parallel
use_cuda=use_cuda, # whether to use GPU card
parallel=parallel, # whether to be parallel
model_dir="model_recall20", # directory to save model
init_low_bound=-0.1, # uniform parameter initialization lower bound
init_high_bound=0.1) # uniform parameter initialization upper bound
......@@ -198,9 +206,9 @@ model saved in model_recall20/epoch_1
```
## 预测
运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如
运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测.其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如
```python
CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt# prediction from epoch 1 to epoch 10 small_train.txt small_test.txt
CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt
```
## 预测结果示例
......
......@@ -17,7 +17,8 @@ def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file')
parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
parser.add_argument('--parallel', help='whether parallel')
parser.add_argument(
'--enable_ce',
action='store_true',
......@@ -182,6 +183,9 @@ def train_net():
args = parse_args()
train_file = args.train_file
test_file = args.test_file
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
print("use_cuda:", use_cuda, "parallel:", parallel)
batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\
......@@ -194,8 +198,8 @@ def train_net():
base_lr=0.01,
batch_size=batch_size,
pass_num=10,
use_cuda=True,
parallel=False,
use_cuda=use_cuda,
parallel=parallel,
model_dir="model_recall20",
init_low_bound=-0.1,
init_high_bound=0.1)
......
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v1.2.3
hooks:
- id: trailing-whitespace
\ No newline at end of file
# 个性化推荐中的多视角Simnet模型
## 介绍
在个性化推荐场景中,推荐系统给用户提供的项目(Item)列表通常是通过个性化的匹配模型计算出来的。在现实世界中,一个用户可能有很多个视角的特征,比如用户Id,年龄,项目的点击历史等。一个项目,举例来说,新闻资讯,也会有多种视角的特征比如新闻标题,新闻类别等。多视角Simnet模型是可以融合用户以及推荐项目的多个视角的特征并进行个性化匹配学习的一体化模型。这类模型在很多工业化的场景中都会被使用到,比如百度的Feed产品中。
## 数据集
目前,本项目使用机器生成的数据集来介绍多视角Simnet模型的概念,未来我们会逐渐加入真是世界中的数据集并在这个模型上进行效果验证。
## 模型
本项目的目标是提供一个在个性化匹配场景下利用Paddle搭建的模型。多视角Simnet模型包括多个编码器模块,每个编码器被用在不同的特征视角上。当前,项目中提供Bag-of-Embedding编码器,Temporal-Convolutional编码器,和Gated-Recurrent-Unit编码器。我们会逐渐加入稀疏特征场景下比较实用的编码器到这个项目中。模型的训练方法,当前采用的是Pairwise ranking模式进行训练,即针对一对具有关联的User-Item组合,随机实用一个Item作为负例进行排序学习。
## 训练
如下
如下命令行可以获得训练工具的具体选项,`python train.py -h`内容可以参考说明
```bash
python train.py
```
## 未来的工作
- 多种pairwise的损失函数会被加入到这个项目中。对于不同视角的特征,用户-项目之间的匹配关系可以使用不同的损失函数进行联合优化。整个模型会在真实数据中进行验证。
- 推理工具会被加入
- Parallel Executor选项会被加入
- 分布式训练能力会被加入
# Multi-view Simnet for Personalized recommendation
## Introduction
In personalized recommendation scenario, a user often is provided with several items from personalized interest matching model. In real world application, a user may have multiple views of features, say user-id, age, click-history of items, search queries. A item, e.g. news, may also have multiple views of features like news title, news category, images in news and so on. Multi-view Simnet is matching a model that combine users' and items' multiple views of features into one unified model. The model can be used in many industrial product like Baidu's feed news. The model is adapted from the paper A Multi-View Deep Learning(MV-DNN) Approach for Cross Domain User Modeling in Recommendation Systems, WWW 2015. The difference between our model and the MV-DNN is that we also consider multiple feature views of users.
## Dataset
Currently, synthetic dataset is provided for proof of concept and we aim to add more real world dataset in this project in the future.
## Model
This project aims to provide practical usage of Paddle in personalized matching scenario. The model provides several encoder modules for different views of features. Currently, Bag-of-Embedding encoder, Temporal-Convolutional encoder, Gated-Recurrent-Unit encoder are provided. We will add more practical encoder for sparse features commonly used in recommender systems. Training algorithms used in this model is pairwise ranking in that a negative item with multiple views will be sampled given a pair of positive user-item pair.
## Train
The command line options for training can be listed by `python train.py -h`
```bash
python train.py
```
## Future work
- Multiple types of pairwise loss will be added in this project. For different views of features between a user and an item, multiple losses will be supported. The model will be verified in real world dataset.
- infer will be added
- Parallel Executor will be added in this project
- Distributed Training will be added
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io
class BowEncoder(object):
""" bow-encoder """
def __init__(self):
self.param_name = ""
def forward(self, emb):
return nn.sequence_pool(input=emb, pool_type='sum')
class CNNEncoder(object):
""" cnn-encoder"""
def __init__(self,
param_name="cnn.w",
win_size=3,
ksize=128,
act='tanh',
pool_type='max'):
self.param_name = param_name
self.win_size = win_size
self.ksize = ksize
self.act = act
self.pool_type = pool_type
def forward(self, emb):
return fluid.nets.sequence_conv_pool(
input=emb,
num_filters=self.ksize,
filter_size=self.win_size,
act=self.act,
pool_type=self.pool_type,
param_attr=str(self.param_name))
class GrnnEncoder(object):
""" grnn-encoder """
def __init__(self, param_name="grnn.w", hidden_size=128):
self.param_name = param_name
self.hidden_size = hidden_size
def forward(self, emb):
fc0 = nn.fc(
input=emb,
size=self.hidden_size * 3,
param_attr=str(str(self.param_name) + "_fc")
)
gru_h = nn.dynamic_gru(
input=fc0,
size=self.hidden_size,
is_reverse=False,
param_attr=str(self.param_name))
return nn.sequence_pool(input=gru_h, pool_type='max')
'''this is a very simple Encoder factory
most default argument values are used'''
class SimpleEncoderFactory(object):
def __init__(self):
pass
''' create an encoder through create function '''
def create(self, enc_type, enc_hid_size):
if enc_type == "bow":
bow_encode = BowEncoder()
return bow_encode
elif enc_type == "cnn":
cnn_encode = CNNEncoder(ksize=enc_hid_size)
return cnn_encode
elif enc_type == "gru":
rnn_encode = GrnnEncoder(hidden_size=enc_hid_size)
return rnn_encode
class MultiviewSimnet(object):
""" multi-view simnet """
def __init__(self, embedding_size, embedding_dim, hidden_size):
self.embedding_size = embedding_size
self.embedding_dim = embedding_dim
self.emb_shape = [self.embedding_size, self.embedding_dim]
self.hidden_size = hidden_size
self.margin = 0.1
def set_query_encoder(self, encoders):
self.query_encoders = encoders
def set_title_encoder(self, encoders):
self.title_encoders = encoders
def get_correct(self, x, y):
less = tensor.cast(cf.less_than(x, y), dtype='float32')
correct = nn.reduce_sum(less)
return correct
def train_net(self):
# input fields for query, pos_title, neg_title
q_slots = [
io.data(
name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
for i in range(len(self.query_encoders))
]
pt_slots = [
io.data(
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
nt_slots = [
io.data(
name="nt%d" % i, shape=[1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
# lookup embedding for each slot
q_embs = [
nn.embedding(
input=query, size=self.emb_shape, param_attr="emb.w")
for query in q_slots
]
pt_embs = [
nn.embedding(
input=title, size=self.emb_shape, param_attr="emb.w")
for title in pt_slots
]
nt_embs = [
nn.embedding(
input=title, size=self.emb_shape, param_attr="emb.w")
for title in nt_slots
]
# encode each embedding field with encoder
q_encodes = [
self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs)
]
pt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
]
nt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(nt_embs)
]
# concat multi view for query, pos_title, neg_title
q_concat = nn.concat(q_encodes)
pt_concat = nn.concat(pt_encodes)
nt_concat = nn.concat(nt_encodes)
# projection of hidden layer
q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w')
pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w')
nt_hid = nn.fc(nt_concat, size=self.hidden_size, param_attr='t_fc.w')
# cosine of hidden layers
cos_pos = nn.cos_sim(q_hid, pt_hid)
cos_neg = nn.cos_sim(q_hid, nt_hid)
# pairwise hinge_loss
loss_part1 = nn.elementwise_sub(
tensor.fill_constant_batch_size_like(
input=cos_pos,
shape=[-1, 1],
value=self.margin,
dtype='float32'),
cos_pos)
loss_part2 = nn.elementwise_add(loss_part1, cos_neg)
loss_part3 = nn.elementwise_max(
tensor.fill_constant_batch_size_like(
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_part2)
avg_cost = nn.mean(loss_part3)
correct = self.get_correct(cos_neg, cos_pos)
return q_slots + pt_slots + nt_slots, avg_cost, correct
def pred_net(self, query_fields, pos_title_fields, neg_title_fields):
q_slots = [
io.data(
name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
for i in range(len(self.query_encoders))
]
pt_slots = [
io.data(
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
# lookup embedding for each slot
q_embs = [
nn.embedding(
input=query, size=self.emb_shape, param_attr="emb.w")
for query in q_slots
]
pt_embs = [
nn.embedding(
input=title, size=self.emb_shape, param_attr="emb.w")
for title in pt_slots
]
# encode each embedding field with encoder
q_encodes = [
self.query_encoder[i].forward(emb) for i, emb in enumerate(q_embs)
]
pt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
]
# concat multi view for query, pos_title, neg_title
q_concat = nn.concat(q_encodes)
pt_concat = nn.concat(pt_encodes)
# projection of hidden layer
q_hid = nn.fc(q_concat, size=self.hidden_size, param_attr='q_fc.w')
pt_hid = nn.fc(pt_concat, size=self.hidden_size, param_attr='t_fc.w')
# cosine of hidden layers
cos = nn.cos_sim(q_hid, pt_hid)
return cos
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
class Dataset:
def __init__(self):
pass
class SyntheticDataset(Dataset):
def __init__(self, sparse_feature_dim, query_slot_num, title_slot_num):
# ids are randomly generated
self.ids_per_slot = 10
self.sparse_feature_dim = sparse_feature_dim
self.query_slot_num = query_slot_num
self.title_slot_num = title_slot_num
self.dataset_size = 10000
def _reader_creator(self, is_train):
def generate_ids(num, space):
return [random.randint(0, space - 1) for i in range(num)]
def reader():
for i in range(self.dataset_size):
query_slots = []
pos_title_slots = []
neg_title_slots = []
for i in range(self.query_slot_num):
qslot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
query_slots.append(qslot)
for i in range(self.title_slot_num):
pt_slot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
pos_title_slots.append(pt_slot)
if is_train:
for i in range(self.title_slot_num):
nt_slot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
neg_title_slots.append(nt_slot)
yield query_slots + pos_title_slots + neg_title_slots
else:
yield query_slots + pos_title_slots
return reader
def train(self):
return self._reader_creator(True)
def valid(self):
return self._reader_creator(True)
def test(self):
return self._reader_creator(False)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import six
import numpy as np
import math
import argparse
import logging
import paddle.fluid as fluid
import paddle
import time
import reader as reader
from nets import MultiviewSimnet, SimpleEncoderFactory
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("multi-view simnet")
parser.add_argument("--train_file", type=str, help="Training file")
parser.add_argument("--valid_file", type=str, help="Validation file")
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs for training")
parser.add_argument(
"--model_output_dir",
type=str,
default='model_output',
help="Model output folder")
parser.add_argument(
"--query_slots", type=int, default=1, help="Number of query slots")
parser.add_argument(
"--title_slots", type=int, default=1, help="Number of title slots")
parser.add_argument(
"--query_encoder",
type=str,
default="bow",
help="Encoder module for slot encoding")
parser.add_argument(
"--title_encoder",
type=str,
default="bow",
help="Encoder module for slot encoding")
parser.add_argument(
"--query_encode_dim",
type=int,
default=128,
help="Dimension of query encoder output")
parser.add_argument(
"--title_encode_dim",
type=int,
default=128,
help="Dimension of title encoder output")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training")
parser.add_argument(
"--embedding_dim",
type=int,
default=128,
help="Default Dimension of Embedding")
parser.add_argument(
"--sparse_feature_dim",
type=int,
default=1000001,
help="Sparse feature hashing space"
"for index processing")
parser.add_argument(
"--hidden_size", type=int, default=128, help="Hidden dim")
return parser.parse_args()
def start_train(args):
dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
args.title_slots)
train_reader = paddle.batch(
paddle.reader.shuffle(
dataset.train(), buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
factory = SimpleEncoderFactory()
query_encoders = [
factory.create(args.query_encoder, args.query_encode_dim)
for i in range(args.query_slots)
]
title_encoders = [
factory.create(args.title_encoder, args.title_encode_dim)
for i in range(args.title_slots)
]
m_simnet = MultiviewSimnet(args.sparse_feature_dim, args.embedding_dim,
args.hidden_size)
m_simnet.set_query_encoder(query_encoders)
m_simnet.set_title_encoder(title_encoders)
all_slots, avg_cost, correct = m_simnet.train_net()
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer.minimize(avg_cost)
startup_program = fluid.default_startup_program()
loop_program = fluid.default_main_program()
feeder = fluid.DataFeeder(feed_list=all_slots, place=place)
exe = fluid.Executor(place)
exe.run(startup_program)
for pass_id in range(args.epochs):
for batch_id, data in enumerate(train_reader()):
loss_val, correct_val = exe.run(loop_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, correct])
logger.info("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}"
.format(pass_id, batch_id, loss_val,
float(correct_val) / args.batch_size))
fluid.io.save_inference_model(args.model_output_dir,
[val.name for val in all_slots],
[avg_cost, correct], exe)
def main():
args = parse_args()
start_train(args)
if __name__ == "__main__":
main()
# Sequence Semantic Retrieval Model
## Introduction
In news recommendation scenarios, different from traditional systems that recommend entertainment items such as movies or music, there are several new problems to solve.
- Very sparse user profile features exist that a user may login a news recommendation app anonymously and a user is likely to read a fresh news item.
- News are generated or disappeared very fast compare with movies or musics. Usually, there will be thousands of news generated in a news recommendation app. The Consumption of news is also fast since users care about newly happened things.
- User interests may change frequently in the news recommendation setting. The content of news will affect users' reading behaviors a lot even the category of the news does not belong to users' long-term interest. In news recommendation, reading behaviors are determined by both short-term interest and long-term interest of users.
[GRU4Rec](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/gru4rec) models a user's short-term and long-term interest by applying a gated-recurrent-unit on the user's reading history. The generalization ability of recurrent neural network captures users' similarity of reading sequences that alleviates the user profile sparsity problem. However, the paper of GRU4Rec operates on close domain of items that the model predicts which item a user will be interested in through classification method. In news recommendation, news items are dynamic through time that GRU4Rec model can not predict items that do not exist in training dataset.
Sequence Semantic Retrieval(SSR) Model shares the similar idea with Multi-Rate Deep Learning for Temporal Recommendation, SIGIR 2016. Sequence Semantic Retrieval Model has two components, one is the matching model part, the other one is the retrieval part.
- The idea of SSR is to model a user's personalized interest of an item through matching model structure, and the representation of a news item can be computed online even the news item does not exist in training dataset.
- With the representation of news items, we are able to build an vector indexing service online for news prediction and this is the retrieval part of SSR.
## Dataset
Dataset preprocessing follows the method of [GRU4Rec Project](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/gru4rec). Note that you should reuse scripts from GRU4Rec project for data preprocessing.
## Training
Before training, you should set PYTHONPATH environment
```
export PYTHONPATH=./models/fluid:$PYTHONPATH
```
The command line options for training can be listed by `python train.py -h`
``` bash
python train.py --train_file rsc15_train_tr_paddle.txt
```
## Build Index
TBA
## Retrieval
TBA
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io
from PaddleRec.multiview_simnet.nets import BowEncoder
from PaddleRec.multiview_simnet.nets import GrnnEncoder
class PairwiseHingeLoss(object):
def __init__(self, margin=0.8):
self.margin = margin
def forward(self, pos, neg):
loss_part1 = nn.elementwise_sub(
tensor.fill_constant_batch_size_like(
input=pos,
shape=[-1, 1],
value=self.margin,
dtype='float32'),
pos)
loss_part2 = nn.elementwise_add(loss_part1, neg)
loss_part3 = nn.elementwise_max(
tensor.fill_constant_batch_size_like(
input=loss_part2,
shape=[-1, 1],
value=0.0,
dtype='float32'),
loss_part2)
return loss_part3
class SequenceSemanticRetrieval(object):
""" sequence semantic retrieval model """
def __init__(self, embedding_size, embedding_dim, hidden_size):
self.embedding_size = embedding_size
self.embedding_dim = embedding_dim
self.emb_shape = [self.embedding_size, self.embedding_dim]
self.hidden_size = hidden_size
self.user_encoder = GrnnEncoder(hidden_size=hidden_size)
self.item_encoder = BowEncoder()
self.pairwise_hinge_loss = PairwiseHingeLoss()
def get_correct(self, x, y):
less = tensor.cast(cf.less_than(x, y), dtype='float32')
correct = nn.reduce_sum(less)
return correct
def train(self):
user_data = io.data(
name="user", shape=[1], dtype="int64", lod_level=1
)
pos_item_data = io.data(
name="p_item", shape=[1], dtype="int64", lod_level=1
)
neg_item_data = io.data(
name="n_item", shape=[1], dtype="int64", lod_level=1
)
user_emb = nn.embedding(
input=user_data, size=self.emb_shape, param_attr="emb.item"
)
pos_item_emb = nn.embedding(
input=pos_item_data, size=self.emb_shape, param_attr="emb.item"
)
neg_item_emb = nn.embedding(
input=neg_item_data, size=self.emb_shape, param_attr="emb.item"
)
user_enc = self.user_encoder.forward(user_emb)
pos_item_enc = self.item_encoder.forward(pos_item_emb)
neg_item_enc = self.item_encoder.forward(neg_item_emb)
user_hid = nn.fc(
input=user_enc, size=self.hidden_size, param_attr='user.w', bias_attr="user.b"
)
pos_item_hid = nn.fc(
input=pos_item_enc, size=self.hidden_size, param_attr='item.w', bias_attr="item.b"
)
neg_item_hid = nn.fc(
input=neg_item_enc, size=self.hidden_size, param_attr='item.w', bias_attr="item.b"
)
cos_pos = nn.cos_sim(user_hid, pos_item_hid)
cos_neg = nn.cos_sim(user_hid, neg_item_hid)
hinge_loss = self.pairwise_hinge_loss.forward(cos_pos, cos_neg)
avg_cost = nn.mean(hinge_loss)
correct = self.get_correct(cos_neg, cos_pos)
return [user_data, pos_item_data, neg_item_data], \
pos_item_hid, neg_item_hid, avg_cost, correct
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
class Dataset:
def __init__(self):
pass
class Vocab:
def __init__(self):
pass
class YoochooseVocab(Vocab):
def __init__(self):
self.vocab = {}
self.word_array = []
def load(self, filelist):
idx = 0
for f in filelist:
with open(f, "r") as fin:
for line in fin:
group = line.strip().split()
for item in group:
if item not in self.vocab:
self.vocab[item] = idx
self.word_array.append(idx)
idx += 1
else:
self.word_array.append(self.vocab[item])
def get_vocab(self):
return self.vocab
def _get_word_array(self):
return self.word_array
class YoochooseDataset(Dataset):
def __init__(self, y_vocab):
self.vocab_size = len(y_vocab.get_vocab())
self.word_array = y_vocab._get_word_array()
self.vocab = y_vocab.get_vocab()
def sample_neg(self):
return random.randint(0, self.vocab_size - 1)
def sample_neg_from_seq(self, seq):
return seq[random.randint(0, len(seq) - 1)]
# TODO(guru4elephant): wait memory, should be improved
def sample_from_word_freq(self):
return self.word_array[random.randint(0, len(self.word_array) - 1)]
def _reader_creator(self, filelist, is_train):
def reader():
for f in filelist:
with open(f, 'r') as fin:
line_idx = 0
for line in fin:
ids = line.strip().split()
if len(ids) <= 1:
continue
conv_ids = [self.vocab[i] if i in self.vocab else 0 for i in ids]
# random select an index as boundary
# make ids before boundary as sequence
# make id next to boundary right as target
boundary = random.randint(1, len(ids) - 1)
src = conv_ids[:boundary]
pos_tgt = [conv_ids[boundary]]
if is_train:
neg_tgt = [self.sample_from_word_freq()]
yield [src, pos_tgt, neg_tgt]
else:
yield [src, pos_tgt]
return reader
def train(self, file_list):
return self._reader_creator(file_list, True)
def test(self, file_list):
return self._reader_creator(file_list, False)
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import logging
import paddle.fluid as fluid
import paddle
import reader as reader
from nets import SequenceSemanticRetrieval
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("sequence semantic retrieval")
parser.add_argument("--train_file", type=str, help="Training file")
parser.add_argument("--valid_file", type=str, help="Validation file")
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs for training")
parser.add_argument(
"--model_output_dir",
type=str,
default='model_output',
help="Model output folder")
parser.add_argument(
"--sequence_encode_dim",
type=int,
default=128,
help="Dimension of sequence encoder output")
parser.add_argument(
"--matching_dim",
type=int,
default=128,
help="Dimension of hidden layer")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training")
parser.add_argument(
"--embedding_dim",
type=int,
default=128,
help="Default Dimension of Embedding")
return parser.parse_args()
def start_train(args):
y_vocab = reader.YoochooseVocab()
y_vocab.load([args.train_file])
logger.info("Load yoochoose vocabulary size: {}".format(len(y_vocab.get_vocab())))
y_data = reader.YoochooseDataset(y_vocab)
train_reader = paddle.batch(
paddle.reader.shuffle(
y_data.train([args.train_file]), buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
ssr = SequenceSemanticRetrieval(
len(y_vocab.get_vocab()), args.embedding_dim, args.matching_dim
)
input_data, user_rep, item_rep, avg_cost, acc = ssr.train()
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer.minimize(avg_cost)
startup_program = fluid.default_startup_program()
loop_program = fluid.default_main_program()
data_list = [var.name for var in input_data]
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
exe = fluid.Executor(place)
exe.run(startup_program)
for pass_id in range(args.epochs):
for batch_id, data in enumerate(train_reader()):
loss_val, correct_val = exe.run(loop_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc])
logger.info("Train --> pass: {} batch_id: {} avg_cost: {}, acc: {}".
format(pass_id, batch_id, loss_val,
float(correct_val) / args.batch_size))
fluid.io.save_inference_model(args.model_output_dir,
[var.name for val in input_data],
[user_rep, item_rep, avg_cost, acc], exe)
def main():
args = parse_args()
start_train(args)
if __name__ == "__main__":
main()
......@@ -8,21 +8,21 @@ Fluid 模型库
在深度学习时代,图像分类的准确率大幅度提升,在图像分类任务中,我们向大家介绍了如何在经典的数据集ImageNet上,训练常用的模型,包括AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、MobileNet、DPN(Dual
Path
Network)、SE-ResNeXt模型,也开源了\ `训练的模型 <https://github.com/PaddlePaddle/models/blob/develop/fluid/image_classification/README_cn.md#已有模型及其性能>`__\ 方便用户下载使用。同时提供了能够将Caffe模型转换为PaddlePaddle
Network)、SE-ResNeXt模型,也开源了\ `训练的模型 <https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/image_classification/README_cn.md#已有模型及其性能>`__\ 方便用户下载使用。同时提供了能够将Caffe模型转换为PaddlePaddle
Fluid模型配置和参数文件的工具。
- `AlexNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `VGG <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `GoogleNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `AlexNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `VGG <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `GoogleNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `Residual
Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `Inception-v4 <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `MobileNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `Inception-v4 <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `MobileNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `Dual Path
Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
- `SE-ResNeXt <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models>`__
Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `SE-ResNeXt <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models>`__
- `Caffe模型转换为Paddle
Fluid配置和模型文件工具 <https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid>`__
Fluid配置和模型文件工具 <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/caffe2fluid>`__
目标检测
--------
......@@ -36,8 +36,8 @@ COCO <http://cocodataset.org/#home>`__\ 数据训练通用物体检测模型,
开放环境中的检测人脸,尤其是小的、模糊的和部分遮挡的人脸也是一个具有挑战的任务。我们也介绍了如何基于 `WIDER FACE <http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/>`_ 数据训练百度自研的人脸检测PyramidBox模型,该算法于2018年3月份在WIDER FACE的多项评测中均获得 `第一名 <http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/WiderFace_Results.html>`_。
- `Single Shot MultiBox
Detector <https://github.com/PaddlePaddle/models/blob/develop/fluid/object_detection/README_cn.md>`__
- `Face Detector: PyramidBox <https://github.com/PaddlePaddle/models/tree/develop/fluid/face_detection/README_cn.md>`_
Detector <https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/object_detection/README_cn.md>`__
- `Face Detector: PyramidBox <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/face_detection/README_cn.md>`_
图像语义分割
------------
......@@ -47,7 +47,7 @@ COCO <http://cocodataset.org/#home>`__\ 数据训练通用物体检测模型,
在图像语义分割任务中,我们介绍如何基于图像级联网络(Image Cascade
Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准确率和速度。
- `ICNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/icnet>`__
- `ICNet <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/icnet>`__
图像生成
-----------
......@@ -57,8 +57,8 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
在图像生成任务中,我们介绍了如何使用DCGAN和ConditioanlGAN来进行手写数字的生成,另外还介绍了用于风格迁移的CycleGAN.
- `DCGAN & ConditionalGAN <https://github.com/PaddlePaddle/models/tree/develop/fluid/gan/c_gan>`__
- `CycleGAN <https://github.com/PaddlePaddle/models/tree/develop/fluid/gan/cycle_gan>`__
- `DCGAN & ConditionalGAN <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/gan/c_gan>`__
- `CycleGAN <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/gan/cycle_gan>`__
场景文字识别
------------
......@@ -67,8 +67,8 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
在场景文字识别任务中,我们介绍如何将基于CNN的图像特征提取和基于RNN的序列翻译技术结合,免除人工定义特征,避免字符分割,使用自动学习到的图像特征,完成字符识别。当前,介绍了CRNN-CTC模型和基于注意力机制的序列到序列模型。
- `CRNN-CTC模型 <https://github.com/PaddlePaddle/models/tree/develop/fluid/ocr_recognition>`__
- `Attention模型 <https://github.com/PaddlePaddle/models/tree/develop/fluid/ocr_recognition>`__
- `CRNN-CTC模型 <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/ocr_recognition>`__
- `Attention模型 <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/ocr_recognition>`__
度量学习
......@@ -77,7 +77,7 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
度量学习也称作距离度量学习、相似度学习,通过学习对象之间的距离,度量学习能够用于分析对象时间的关联、比较关系,在实际问题中应用较为广泛,可应用于辅助分类、聚类问题,也广泛用于图像检索、人脸识别等领域。以往,针对不同的任务,需要选择合适的特征并手动构建距离函数,而度量学习可根据不同的任务来自主学习出针对特定任务的度量距离函数。度量学习和深度学习的结合,在人脸识别/验证、行人再识别(human Re-ID)、图像检索等领域均取得较好的性能,在这个任务中我们主要介绍了基于Fluid的深度度量学习模型,包含了三元组、四元组等损失函数。
- `Metric Learning <https://github.com/PaddlePaddle/models/tree/develop/fluid/metric_learning>`__
- `Metric Learning <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/metric_learning>`__
视频分类
......@@ -86,7 +86,7 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
视频分类是视频理解任务的基础,与图像分类不同的是,分类的对象不再是静止的图像,而是一个由多帧图像构成的、包含语音数据、包含运动信息等的视频对象,因此理解视频需要获得更多的上下文信息,不仅要理解每帧图像是什么、包含什么,还需要结合不同帧,知道上下文的关联信息。视频分类方法主要包含基于卷积神经网络、基于循环神经网络、或将这两者结合的方法。该任务中我们介绍基于Fluid的视频分类模型,目前包含Temporal Segment Network(TSN)模型,后续会持续增加更多模型。
- `TSN <https://github.com/PaddlePaddle/models/tree/develop/fluid/video_classification>`__
- `TSN <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/video_classification>`__
......@@ -122,7 +122,7 @@ RNN 结构的 NMT 得以应运而生,例如基于卷积神经网络 CNN
Attention 学习语言中的上下文依赖。相较于RNN/CNN,
这种结构在单层内计算复杂度更低、易于并行化、对长程依赖更易建模,最终在多种语言之间取得了最好的翻译效果。
- `Transformer <https://github.com/PaddlePaddle/models/blob/develop/fluid/neural_machine_translation/transformer/README_cn.md>`__
- `Transformer <https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleNLP/neural_machine_translation/transformer/README_cn.md>`__
强化学习
--------
......@@ -163,7 +163,7 @@ DQN 及其变体,并测试了它们在 Atari 游戏中的表现。
本例所开放的DAM (Deep Attention Matching Network)为百度自然语言处理部发表于ACL-2018的工作,用于检索式聊天机器人多轮对话中应答的选择。DAM受Transformer的启发,其网络结构完全基于注意力(attention)机制,利用栈式的self-attention结构分别学习不同粒度下应答和语境的语义表示,然后利用cross-attention获取应答与语境之间的相关性,在两个大规模多轮对话数据集上的表现均好于其它模型。
- `Deep Attention Matching Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/deep_attention_matching_net>`__
- `Deep Attention Matching Network <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleNLP/deep_attention_matching_net>`__
AnyQ
----
......@@ -184,4 +184,18 @@ SimNet是百度自然语言处理部于2013年自主研发的语义匹配框架
百度阅读理解数据集是由百度自然语言处理部开源的一个真实世界数据集,所有的问题、原文都来源于实际数据(百度搜索引擎数据和百度知道问答社区),答案是由人类回答的。每个问题都对应多个答案,数据集包含200k问题、1000k原文和420k答案,是目前最大的中文MRC数据集。百度同时开源了对应的阅读理解模型,称为DuReader,采用当前通用的网络分层结构,通过双向attention机制捕捉问题和原文之间的交互关系,生成query-aware的原文表示,最终基于query-aware的原文表示通过point network预测答案范围。
- `DuReader in PaddlePaddle Fluid] <https://github.com/PaddlePaddle/models/blob/develop/fluid/machine_reading_comprehension/README.md>`__
- `DuReader in PaddlePaddle Fluid <https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleNLP/machine_reading_comprehension/README.md>`__
个性化推荐
-------
推荐系统在当前的互联网服务中正在发挥越来越大的作用,目前大部分电子商务系统、社交网络,广告推荐,搜索引擎,都不同程度的使用了各种形式的个性化推荐技术,帮助用户快速找到他们想要的信息。
在工业可用的推荐系统中,推荐策略一般会被划分为多个模块串联执行。以新闻推荐系统为例,存在多个可以使用深度学习技术的环节,例如新闻的自动化标注,个性化新闻召回,个性化匹配与排序等。PaddlePaddle对推荐算法的训练提供了完整的支持,并提供了多种模型配置供用户选择。
- `TagSpace <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/TagSpace>`_
- `GRU4Rec <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/gru4rec>`_
- `SequenceSemanticRetrieval <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/ssr>`_
- `DeepCTR <https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleRec/ctr/README.cn.md>`_
- `Multiview-Simnet <https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/multiview_simnet>`_
......@@ -6,18 +6,18 @@ Fluid 模型库
图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。
在深度学习时代,图像分类的准确率大幅度提升,在图像分类任务中,我们向大家介绍了如何在经典的数据集ImageNet上,训练常用的模型,包括AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、MobileNet、DPN(Dual Path Network)、SE-ResNeXt模型,也开源了[训练的模型](https://github.com/PaddlePaddle/models/blob/develop/fluid/image_classification/README_cn.md#已有模型及其性能) 方便用户下载使用。同时提供了能够将Caffe模型转换为PaddlePaddle
在深度学习时代,图像分类的准确率大幅度提升,在图像分类任务中,我们向大家介绍了如何在经典的数据集ImageNet上,训练常用的模型,包括AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、MobileNet、DPN(Dual Path Network)、SE-ResNeXt模型,也开源了[训练的模型](https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/image_classification/README_cn.md#已有模型及其性能) 方便用户下载使用。同时提供了能够将Caffe模型转换为PaddlePaddle
Fluid模型配置和参数文件的工具。
- [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [VGG](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [GoogleNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [MobileNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [Dual Path Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [SE-ResNeXt](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/models)
- [Caffe模型转换为Paddle Fluid配置和模型文件工具](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid)
- [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [VGG](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [GoogleNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [MobileNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [Dual Path Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [SE-ResNeXt](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/models)
- [Caffe模型转换为Paddle Fluid配置和模型文件工具](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification/caffe2fluid)
目标检测
--------
......@@ -30,9 +30,9 @@ Fluid模型配置和参数文件的工具。
Faster RCNN 是典型的两阶段目标检测器,相较于传统提取区域的方法,Faster RCNN中RPN网络通过共享卷积层参数大幅提高提取区域的效率,并提出高质量的候选区域。
- [Single Shot MultiBox Detector](https://github.com/PaddlePaddle/models/blob/develop/fluid/object_detection/README_cn.md)
- [Face Detector: PyramidBox](https://github.com/PaddlePaddle/models/tree/develop/fluid/face_detection/README_cn.md)
- [Faster RCNN](https://github.com/PaddlePaddle/models/tree/develop/fluid/faster_rcnn/README_cn.md)
- [Single Shot MultiBox Detector](https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleCV/object_detection/README_cn.md)
- [Face Detector: PyramidBox](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/face_detection/README_cn.md)
- [Faster RCNN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/faster_rcnn/README_cn.md)
图像语义分割
------------
......@@ -42,7 +42,7 @@ Faster RCNN 是典型的两阶段目标检测器,相较于传统提取区域
在图像语义分割任务中,我们介绍如何基于图像级联网络(Image Cascade
Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准确率和速度。
- [ICNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/icnet)
- [ICNet](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/icnet)
图像生成
-----------
......@@ -52,8 +52,8 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
在图像生成任务中,我们介绍了如何使用DCGAN和ConditioanlGAN来进行手写数字的生成,另外还介绍了用于风格迁移的CycleGAN.
- [DCGAN & ConditionalGAN](https://github.com/PaddlePaddle/models/tree/develop/fluid/gan/c_gan)
- [CycleGAN](https://github.com/PaddlePaddle/models/tree/develop/fluid/gan/cycle_gan)
- [DCGAN & ConditionalGAN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/gan/c_gan)
- [CycleGAN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/gan/cycle_gan)
场景文字识别
------------
......@@ -62,8 +62,8 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
在场景文字识别任务中,我们介绍如何将基于CNN的图像特征提取和基于RNN的序列翻译技术结合,免除人工定义特征,避免字符分割,使用自动学习到的图像特征,完成字符识别。当前,介绍了CRNN-CTC模型和基于注意力机制的序列到序列模型。
- [CRNN-CTC模型](https://github.com/PaddlePaddle/models/tree/develop/fluid/ocr_recognition)
- [Attention模型](https://github.com/PaddlePaddle/models/tree/develop/fluid/ocr_recognition)
- [CRNN-CTC模型](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/ocr_recognition)
- [Attention模型](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/ocr_recognition)
度量学习
......@@ -72,7 +72,7 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
度量学习也称作距离度量学习、相似度学习,通过学习对象之间的距离,度量学习能够用于分析对象时间的关联、比较关系,在实际问题中应用较为广泛,可应用于辅助分类、聚类问题,也广泛用于图像检索、人脸识别等领域。以往,针对不同的任务,需要选择合适的特征并手动构建距离函数,而度量学习可根据不同的任务来自主学习出针对特定任务的度量距离函数。度量学习和深度学习的结合,在人脸识别/验证、行人再识别(human Re-ID)、图像检索等领域均取得较好的性能,在这个任务中我们主要介绍了基于Fluid的深度度量学习模型,包含了三元组、四元组等损失函数。
- [Metric Learning](https://github.com/PaddlePaddle/models/tree/develop/fluid/metric_learning)
- [Metric Learning](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/metric_learning)
视频分类
......@@ -81,7 +81,7 @@ Network,ICNet)进行语义分割,相比其他分割算法,ICNet兼顾了准
视频分类是视频理解任务的基础,与图像分类不同的是,分类的对象不再是静止的图像,而是一个由多帧图像构成的、包含语音数据、包含运动信息等的视频对象,因此理解视频需要获得更多的上下文信息,不仅要理解每帧图像是什么、包含什么,还需要结合不同帧,知道上下文的关联信息。视频分类方法主要包含基于卷积神经网络、基于循环神经网络、或将这两者结合的方法。该任务中我们介绍基于Fluid的视频分类模型,目前包含Temporal Segment Network(TSN)模型,后续会持续增加更多模型。
- [TSN](https://github.com/PaddlePaddle/models/tree/develop/fluid/video_classification)
- [TSN](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/video_classification)
语音识别
......@@ -101,7 +101,7 @@ Machine Translation, NMT)等阶段。在 NMT 成熟后,机器翻译才真正
本实例所实现的 Transformer 就是一个基于自注意力机制的机器翻译模型,其中不再有RNN或CNN结构,而是完全利用 Attention 学习语言中的上下文依赖。相较于RNN/CNN, 这种结构在单层内计算复杂度更低、易于并行化、对长程依赖更易建模,最终在多种语言之间取得了最好的翻译效果。
- [Transformer](https://github.com/PaddlePaddle/models/blob/develop/fluid/neural_machine_translation/transformer/README_cn.md)
- [Transformer](https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleNLP/neural_machine_translation/transformer/README_cn.md)
强化学习
--------
......@@ -135,7 +135,7 @@ Machine Translation, NMT)等阶段。在 NMT 成熟后,机器翻译才真正
本例所开放的DAM (Deep Attention Matching Network)为百度自然语言处理部发表于ACL-2018的工作,用于检索式聊天机器人多轮对话中应答的选择。DAM受Transformer的启发,其网络结构完全基于注意力(attention)机制,利用栈式的self-attention结构分别学习不同粒度下应答和语境的语义表示,然后利用cross-attention获取应答与语境之间的相关性,在两个大规模多轮对话数据集上的表现均好于其它模型。
- [Deep Attention Matching Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/deep_attention_matching_net)
- [Deep Attention Matching Network](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleNLP/deep_attention_matching_net)
AnyQ
----
......@@ -153,4 +153,18 @@ SimNet是百度自然语言处理部于2013年自主研发的语义匹配框架
百度阅读理解数据集是由百度自然语言处理部开源的一个真实世界数据集,所有的问题、原文都来源于实际数据(百度搜索引擎数据和百度知道问答社区),答案是由人类回答的。每个问题都对应多个答案,数据集包含200k问题、1000k原文和420k答案,是目前最大的中文MRC数据集。百度同时开源了对应的阅读理解模型,称为DuReader,采用当前通用的网络分层结构,通过双向attention机制捕捉问题和原文之间的交互关系,生成query-aware的原文表示,最终基于query-aware的原文表示通过point network预测答案范围。
- [DuReader in PaddlePaddle Fluid](https://github.com/PaddlePaddle/models/blob/develop/fluid/machine_reading_comprehension/README.md)
- [DuReader in PaddlePaddle Fluid](https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleNLP/machine_reading_comprehension/README.md)
个性化推荐
-------
推荐系统在当前的互联网服务中正在发挥越来越大的作用,目前大部分电子商务系统、社交网络,广告推荐,搜索引擎,都不同程度的使用了各种形式的个性化推荐技术,帮助用户快速找到他们想要的信息。
在工业可用的推荐系统中,推荐策略一般会被划分为多个模块串联执行。以新闻推荐系统为例,存在多个可以使用深度学习技术的环节,例如新闻的自动化标注,个性化新闻召回,个性化匹配与排序等。PaddlePaddle对推荐算法的训练提供了完整的支持,并提供了多种模型配置供用户选择。
- [TagSpace](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/TagSpace)
- [GRU4Rec](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/gru4rec)
- [SequenceSemanticRetrieval](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/ssr)
- [DeepCTR](https://github.com/PaddlePaddle/models/blob/develop/fluid/PaddleRec/ctr/README.cn.md)
- [Multiview-Simnet](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/multiview_simnet)
Subproject commit 870651e257750f2c237f0b0bc9a27e5d062d1909
Subproject commit 4dbe7f7b0e76c188eb7f448d104f0165f0a12229
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
__all__ = ['parse_args', ]
BENCHMARK_MODELS = [
"ResNet50", "ResNet101", "ResNet152"
]
def parse_args():
parser = argparse.ArgumentParser('Distributed Image Classification Training.')
parser.add_argument(
'--model',
type=str,
choices=BENCHMARK_MODELS,
default='resnet',
help='The model to run benchmark with.')
parser.add_argument(
'--batch_size', type=int, default=32, help='The minibatch size.')
# args related to learning rate
parser.add_argument(
'--learning_rate', type=float, default=0.001, help='The learning rate.')
# TODO(wuyi): add "--use_fake_data" option back.
parser.add_argument(
'--skip_batch_num',
type=int,
default=5,
help='The first num of minibatch num to skip, for better performance test'
)
parser.add_argument(
'--iterations', type=int, default=80, help='The number of minibatches.')
parser.add_argument(
'--pass_num', type=int, default=100, help='The number of passes.')
parser.add_argument(
'--data_format',
type=str,
default='NCHW',
choices=['NCHW', 'NHWC'],
help='The data data_format, now only support NCHW.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help='The device type.')
parser.add_argument(
'--gpus',
type=int,
default=1,
help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
# this option is available only for vgg and resnet.
parser.add_argument(
'--cpus',
type=int,
default=1,
help='If cpus > 1, will set ParallelExecutor to use multiple threads.')
parser.add_argument(
'--data_set',
type=str,
default='flowers',
choices=['cifar10', 'flowers', 'imagenet'],
help='Optional dataset for benchmark.')
parser.add_argument(
'--no_test',
action='store_true',
help='If set, do not test the testset during training.')
parser.add_argument(
'--memory_optimize',
action='store_true',
help='If set, optimize runtime memory before start.')
parser.add_argument(
'--update_method',
type=str,
default='local',
choices=['local', 'pserver', 'nccl2'],
help='Choose parameter update method, can be local, pserver, nccl2.')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--no_random',
action='store_true',
help='If set, keep the random seed and do not shuffle the data.')
parser.add_argument(
'--reduce_strategy',
type=str,
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--data_dir',
type=str,
default="../data/ILSVRC2012",
help="The ImageNet dataset root dir."
)
args = parser.parse_args()
return args
./neural_machine_translation/rnn_search
\ No newline at end of file
./neural_machine_translation/transformer
\ No newline at end of file
......@@ -86,7 +86,7 @@ SSD使用一个卷积神经网络实现“端到端”的检测:输入为原
文件共两个字段,第一个字段为图像文件的相对路径,第二个字段为对应标注文件的相对路径。
### 预训练模型准备
下载预训练的VGG-16模型,我们提供了一个转换好的模型,下载模型[http://paddlepaddle.bj.bcebos.com/model_zoo/detection/ssd_model/vgg_model.tar.gz](http://paddlepaddle.bj.bcebos.com/model_zoo/detection/ssd_model/vgg_model.tar.gz),并将其放置路径为```vgg/vgg_model.tar.gz```。
下载预训练的VGG-16模型,我们提供了一个转换好的模型,下载模型[http://paddlemodels.bj.bcebos.com/v2/vgg_model.tar.gz](http://paddlemodels.bj.bcebos.com/v2/vgg_model.tar.gz),并将其放置路径为```vgg/vgg_model.tar.gz```。
### 模型训练
直接执行```python train.py```即可进行训练。需要注意本示例仅支持CUDA GPU环境,无法在CPU上训练,主要因为使用CPU训练速度很慢,实践中一般使用GPU来处理图像任务,这里实现采用硬编码方式使用cuDNN,不提供CPU版本。```train.py```的一些关键执行逻辑:
......
......@@ -77,7 +77,7 @@ The first field is the relative path of the image file, and the second field is
### To Use Pre-trained Model
We also provide a pre-trained model using VGG-16 with good performance. To use the model, download the file http://paddlepaddle.bj.bcebos.com/model_zoo/detection/ssd_model/vgg_model.tar.gz, and place it as ```vgg/vgg_model.tar.gz```
We also provide a pre-trained model using VGG-16 with good performance. To use the model, download the file http://paddlemodels.bj.bcebos.com/v2/vgg_model.tar.gz, and place it as ```vgg/vgg_model.tar.gz```.
### Training
Next, run ```python train.py``` to train the model. Note that this example only supports the CUDA GPU environment, and can not be trained using only CPU. This is mainly because the training is very slow using CPU only.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册