diff --git a/fluid/DeepQNetwork/DuelingDQN_agent.py b/fluid/DeepQNetwork/DuelingDQN_agent.py index d6224ef34a2cb1ec0a09d9ed2e87a2f89ab82142..271a767b7b5841cf1abe213fc477859e3cf5dd05 100644 --- a/fluid/DeepQNetwork/DuelingDQN_agent.py +++ b/fluid/DeepQNetwork/DuelingDQN_agent.py @@ -158,7 +158,8 @@ class DuelingDQNModel(object): for i, var in enumerate(policy_vars): sync_op = fluid.layers.assign(policy_vars[i], target_vars[i]) sync_ops.append(sync_op) - sync_program = sync_program.prune(sync_ops) + # The prune API is deprecated, please don't use it any more. + sync_program = sync_program._prune(sync_ops) return sync_program def act(self, state, train_or_test): diff --git a/fluid/DeepQNetwork/atari.py b/fluid/DeepQNetwork/atari.py index 46b7542019121b36e3e8923dba350e1d8a71fa34..ec793cba15ddc1c42986689eaad5773875a4ffde 100644 --- a/fluid/DeepQNetwork/atari.py +++ b/fluid/DeepQNetwork/atari.py @@ -9,7 +9,7 @@ import gym from gym import spaces from gym.envs.atari.atari_env import ACTION_MEANING -from ale_python_interface import ALEInterface +from atari_py import ALEInterface __all__ = ['AtariPlayer'] diff --git a/fluid/faster_rcnn/README.md b/fluid/faster_rcnn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..99438c6768197238d7889d8c63e1e8bdbfde0202 --- /dev/null +++ b/fluid/faster_rcnn/README.md @@ -0,0 +1,163 @@ +# Faster RCNN Objective Detection + +--- +## Table of Contents + +- [Installation](#installation) +- [Introduction](#introduction) +- [Data preparation](#data-preparation) +- [Training](#training) +- [Finetuning](#finetuning) +- [Evaluation](#evaluation) +- [Inference and Visualization](#inference-and-visualization) +- [Appendix](#appendix) + +## Installation + +Running sample code in this directory requires PaddelPaddle Fluid v.1.0.0 and later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in [installation document](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle) and make an update. + +## Introduction + +[Faster Rcnn](https://arxiv.org/abs/1506.01497) is a typical two stage detector. The total framework of network can be divided into four parts, as shown below: +

+
+Faster RCNN model +

+ +1. Base conv layer。As a CNN objective dection, Faster RCNN extract feature maps using a basic convolutional network. The feature maps then can be shared by RPN and fc layers. This sampel uses [ResNet-50](https://arxiv.org/abs/1512.03385) as base conv layer. +2. Region Proposal Network (RPN)。RPN generates proposals for detection。This block generates anchors by a set of size and ratio and classifies anchors into fore-ground and back-ground by softmax. Then refine anchors to obtain more precise proposals using box regression. +3. RoI pooling。This layer takes feature maps and proposals as input. The proposals are mapped to feature maps and pooled to the same size. The output are sent to fc layers for classification and regression. +4. Detection layer。Using the output of roi pooling to compute the class and locatoin of each proposal in two fc layers. + +## Data preparation + +Train the model on [MS-COCO dataset](http://cocodataset.org/#download), download dataset as below: + + cd dataset/coco + ./download.sh + + +## Training + +After data preparation, one can start the training step by: + + python train.py \ + --max_size=1333 \ + --scales=800 \ + --batch_size=8 \ + --model_save_dir=output/ + +- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specifiy 8 GPU to train. +- For more help on arguments: + + python train.py --help + +**download the pre-trained model:** This sample provides Resnet-50 pre-trained model which is converted from Caffe. The model fuses the parameters in batch normalization layer. One can download pre-trained model as: + + sh ./pretrained/download.sh + +Set `pretrained_model` to load pre-trained model. In addition, this parameter is used to load trained model when finetuning as well. + +**data reader introduction:** + +* Data reader is defined in `reader.py`. +* Scaling the short side of all images to `scales`. If the long side is larger than `max_size`, then scaling the long side to `max_size`. +* In training stage, images are horizontally flipped. +* Images in the same batch can be padding to the same size. + +**model configuration:** + +* Use RoIPooling. +* NMS threshold=0.7. During training, pre\_nms=12000, post\_nms=2000; during test, pre\_nms=6000, post\_nms=1000. +* In generating proposal lables, fg\_fraction=0.25, fg\_thresh=0.5, bg\_thresh_hi=0.5, bg\_thresh\_lo=0.0. +* In rpn target assignment, rpn\_fg\_fraction=0.5, rpn\_positive\_overlap=0.7, rpn\_negative\_overlap=0.3. + +**training strategy:** + +* Use momentum optimizer with momentum=0.9. +* Weight decay is 0.0001. +* In first 500 iteration, the learning rate increases linearly from 0.00333 to 0.01. Then lr is decayed at 120000, 160000 iteration with multiplier 0.1, 0.01. The maximum iteration is 180000. +* Set the learning rate of bias to two times as global lr in non basic convolutional layers. +* In basic convolutional layers, parameters of affine layers and res body do not update. +* Use Nvidia Tesla V100 8GPU, total time for training is about 40 hours. + +Training result is shown as below: +

+
+Faster RCNN train loss +

+* Fluid all padding: Each image padding to 1333\*1333. +* Fluid minibatch padding: Images in one batch padding to the same size. This method is same as detectron. +* Fluid no padding: Images without padding. + +## Finetuning + +Finetuning is to finetune model weights in a specific task by loading pretrained weights. After initializing ```pretrained_model```, one can finetune a model as: + + python train.py + --max_size=1333 \ + --scales=800 \ + --pretrained_model=${path_to_pretrain_model} \ + --batch_size= 8\ + --model_save_dir=output/ + +## Evaluation + +Evaluation is to evaluate the performance of a trained model. This sample provides `eval_coco_map.py` which uses a COCO-specific mAP metric defined by [COCO committee](http://cocodataset.org/#detections-eval). To use `eval_coco_map.py` , [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi: + + # COCOAPI=/path/to/clone/cocoapi + git clone https://github.com/cocodataset/cocoapi.git $COCOAPI + cd $COCOAPI/PythonAPI + # if cython is not installed + pip install Cython + # Install into global site-packages + make install + # Alternatively, if you do not have permissions or prefer + # not to install the COCO API into global site-packages + python2 setup.py install --user + +`eval_coco_map.py` is the main executor for evalution, one can start evalution step by: + + python eval_coco_map.py \ + --dataset=coco2017 \ + --pretrained_mode=${path_to_pretrain_model} \ + --batch_size=1 \ + --nms_threshold=0.5 \ + --score_threshold=0.05 + +Evalutaion result is shown as below: +

+
+Faster RCNN mAP +

+ +| Model | Batch size | Max iteration | mAP | +| :------------------------------ | :------------: | :-------------------:|------: | +| Detectron | 8 | 180000 | 0.315 | +| Fluid minibatch padding | 8 | 180000 | 0.314 | +| Fluid all padding | 8 | 180000 | 0.308 | +| Fluid no padding |6 | 240000 | 0.317 | + +* Fluid all padding: Each image padding to 1333\*1333. +* Fluid minibatch padding: Images in one batch padding to the same size. This method is same as detectron. +* Fluid no padding: Images without padding. + +## Inference and Visualization + +Inference is used to get prediction score or image features based on trained models. `infer.py` is the main executor for inference, one can start infer step by: + + python infer.py \ + --dataset=coco2017 \ + --pretrained_model=${path_to_pretrain_model} \ + --image_path=data/COCO17/val2017/ \ + --image_name=000000000139.jpg \ + --draw_threshold=0.6 + +Visualization of infer result is shown as below: +

+ + + +
+Faster RCNN Visualization Examples +

diff --git a/fluid/faster_rcnn/README_cn.md b/fluid/faster_rcnn/README_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..cd238a86c8ec82f1519280c5552699c1b75438ec --- /dev/null +++ b/fluid/faster_rcnn/README_cn.md @@ -0,0 +1,147 @@ +# Faster RCNN 目标检测 + +--- +## 内容 + +- [安装](#安装) +- [简介](#简介) +- [数据准备](#数据准备) +- [模型训练](#模型训练) +- [参数微调](#参数微调) +- [模型评估](#模型评估) +- [模型推断及可视化](#模型推断及可视化) +- [附录](#附录) + +## 安装 + +在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.0.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。 + +## 简介 + +[Faster Rcnn](https://arxiv.org/abs/1506.01497) 是典型的两阶段目标检测器。如下图所示,整体网络可以分为4个主要内容: +

+
+Faster RCNN 目标检测模型 +

+ +1. 基础卷积层。作为一种卷积神经网络目标检测方法,Faster RCNN首先使用一组基础的卷积网络提取图像的特征图。特征图被后续RPN层和全连接层共享。本示例采用[ResNet-50](https://arxiv.org/abs/1512.03385)作为基础卷积层。 +2. 区域生成网络(RPN)。RPN网络用于生成候选区域(proposals)。该层通过一组固定的尺寸和比例得到一组锚点(anchors), 通过softmax判断锚点属于前景或者背景,再利用区域回归修正锚点从而获得精确的候选区域。 +3. RoI池化。该层收集输入的特征图和候选区域,将候选区域映射到特征图中并池化为统一大小的区域特征图,送入全连接层判定目标类别。 +4. 检测层。利用区域特征图计算候选区域的类别,同时再次通过区域回归获得检测框最终的精确位置。 + +## 数据准备 + +在[MS-COCO数据集](http://cocodataset.org/#download)上进行训练,通过如下方式下载数据集。 + + cd dataset/coco + ./download.sh + +## 模型训练 + +数据准备完毕后,可以通过如下的方式启动训练: + + python train.py \ + --max_size=1333 \ + --scales=800 \ + --batch_size=8 \ + --model_save_dir=output/ \ + --pretrained_model=${path_to_pretrain_model} + +- 通过设置export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7指定8卡GPU训练。 +- 可选参数见: + + python train.py --help + +**下载预训练模型:** 本示例提供Resnet-50预训练模型,该模性转换自Caffe,并对批标准化层(Batch Normalization Layer)进行参数融合。采用如下命令下载预训练模型: + + sh ./pretrained/download.sh + +通过初始化`pretrained_model` 加载预训练模型。同时在参数微调时也采用该设置加载已训练模型。 + +**数据读取器说明:** 数据读取器定义在reader.py中。所有图像将短边等比例缩放至`scales`,若长边大于`max_size`, 则再次将长边等比例缩放至`max_iter`。在训练阶段,对图像采用水平翻转。支持将同一个batch内的图像padding为相同尺寸。 + +**模型设置:** + +* 使用RoIPooling。 +* 训练过程pre\_nms=12000, post\_nms=2000,测试过程pre\_nms=6000, post\_nms=1000。nms阈值为0.7。 +* RPN网络得到labels的过程中,fg\_fraction=0.25,fg\_thresh=0.5,bg\_thresh_hi=0.5,bg\_thresh\_lo=0.0 +* RPN选择anchor时,rpn\_fg\_fraction=0.5,rpn\_positive\_overlap=0.7,rpn\_negative\_overlap=0.3 + + +下图为模型训练结果: +

+
+Faster RCNN 训练loss +

+* Fluid all padding: 每张图像填充为1333\*1333大小。 +* Fluid minibatch padding: 同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同。 +* Fluid no padding: 不对图像做填充处理。 + +**训练策略:** + +* 采用momentum优化算法训练Faster RCNN,momentum=0.9。 +* 权重衰减系数为0.0001,前500轮学习率从0.00333线性增加至0.01。在120000,160000轮时使用0.1,0.01乘子进行学习率衰减,最大训练180000轮。 +* 非基础卷积层卷积bias学习率为整体学习率2倍。 +* 基础卷积层中,affine_layers参数不更新,res2层参数不更新。 +* 使用Nvidia Tesla V100 8卡并行,总共训练时长大约40小时。 + +## 模型评估 + +模型评估是指对训练完毕的模型评估各类性能指标。本示例采用[COCO官方评估](http://cocodataset.org/#detections-eval),使用前需要首先下载[cocoapi](https://github.com/cocodataset/cocoapi): + + # COCOAPI=/path/to/clone/cocoapi + git clone https://github.com/cocodataset/cocoapi.git $COCOAPI + cd $COCOAPI/PythonAPI + # if cython is not installed + pip install Cython + # Install into global site-packages + make install + # Alternatively, if you do not have permissions or prefer + # not to install the COCO API into global site-packages + python2 setup.py install --user + +`eval_coco_map.py`是评估模块的主要执行程序,调用示例如下: + + python eval_coco_map.py \ + --dataset=coco2017 \ + --pretrained_mode=${path_to_pretrain_model} \ + --batch_size=1 \ + --nms_threshold=0.5 \ + --score_threshold=0.05 + +下图为模型评估结果: +

+
+Faster RCNN mAP +

+ +| 模型 | 批量大小 | 迭代次数 | mAP | +| :------------------------------ | :------------: | :------------------: |------: | +| Detectron | 8 | 180000 | 0.315 | +| Fluid minibatch padding | 8 | 180000 | 0.314 | +| Fluid all padding | 8 | 180000 | 0.308 | +| Fluid no padding |6 | 240000 | 0.317 | + +* Fluid all padding: 每张图像填充为1333\*1333大小。 +* Fluid minibatch padding: 同一个batch内的图像填充为相同尺寸。该方法与detectron处理相同。 +* Fluid no padding: 不对图像做填充处理。 + +## 模型推断及可视化 + +模型推断可以获取图像中的物体及其对应的类别,`infer.py`是主要执行程序,调用示例如下: + + python infer.py \ + --dataset=coco2017 \ + --pretrained_model=${path_to_pretrain_model} \ + --image_path=data/COCO17/val2017/ \ + --image_name=000000000139.jpg \ + --draw_threshold=0.6 + +下图为模型可视化预测结果: +

+ + + +
+Faster RCNN 预测可视化 +

diff --git a/fluid/faster_rcnn/image/000000000139.jpg b/fluid/faster_rcnn/image/000000000139.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c83a2cc4a9a2f13534d81f0c4ede78ae32c58cb Binary files /dev/null and b/fluid/faster_rcnn/image/000000000139.jpg differ diff --git a/fluid/faster_rcnn/image/000000127517.jpg b/fluid/faster_rcnn/image/000000127517.jpg new file mode 100644 index 0000000000000000000000000000000000000000..23d30251a5e386137b5881f4af48072abffad8dd Binary files /dev/null and b/fluid/faster_rcnn/image/000000127517.jpg differ diff --git a/fluid/faster_rcnn/image/000000203864.jpg b/fluid/faster_rcnn/image/000000203864.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f16ce4e05004404ff7353a8096318fd121e835a1 Binary files /dev/null and b/fluid/faster_rcnn/image/000000203864.jpg differ diff --git a/fluid/faster_rcnn/image/000000515077.jpg b/fluid/faster_rcnn/image/000000515077.jpg new file mode 100644 index 0000000000000000000000000000000000000000..61df889539b72f9b0a0b36c5731ff660a0955c46 Binary files /dev/null and b/fluid/faster_rcnn/image/000000515077.jpg differ diff --git a/fluid/faster_rcnn/image/Faster_RCNN.jpg b/fluid/faster_rcnn/image/Faster_RCNN.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68ea75863902094337165d50d8ec0960264d9e25 Binary files /dev/null and b/fluid/faster_rcnn/image/Faster_RCNN.jpg differ diff --git a/fluid/faster_rcnn/image/mAP.jpg b/fluid/faster_rcnn/image/mAP.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9202f8f1a06200c2e9c111eac562c2d59b460cb9 Binary files /dev/null and b/fluid/faster_rcnn/image/mAP.jpg differ diff --git a/fluid/faster_rcnn/image/train_loss.jpg b/fluid/faster_rcnn/image/train_loss.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75d4b0a5f02aacd8fb3a0dfa7a62e60fac9ccb6b Binary files /dev/null and b/fluid/faster_rcnn/image/train_loss.jpg differ diff --git a/fluid/faster_rcnn/pretrained/download.sh b/fluid/faster_rcnn/pretrained/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..267247d5683d8bc9323d201d893419b3355f7283 --- /dev/null +++ b/fluid/faster_rcnn/pretrained/download.sh @@ -0,0 +1,8 @@ +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +# Download the data. +echo "Downloading..." +wget http://paddlemodels.bj.bcebos.com/faster_rcnn/imagenet_resnet50_fusebn.tar.gz +echo "Extracting..." +tar -xf imagenet_resnet50_fusebn.tar.gz diff --git a/fluid/faster_rcnn/reader.py b/fluid/faster_rcnn/reader.py index f134230c73b5b1b79a64cd99aaacffa775cad234..986db9888d004bf625593c641aa6cec71a8d60f3 100644 --- a/fluid/faster_rcnn/reader.py +++ b/fluid/faster_rcnn/reader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -150,6 +150,8 @@ def coco(settings, else: for roidb in roidbs: + if settings.image_name not in roidb['image']: + continue im, im_info, im_id = roidb_reader(roidb, mode) batch_out = [(im, im_info, im_id)] yield batch_out diff --git a/fluid/gan/cycle_gan/README.md b/fluid/gan/cycle_gan/README.md index c2c22ec8df7cb98cb0d76682e4db6811c90393b5..6520c123f80f423366287ea53a36a3969d1a73c9 100644 --- a/fluid/gan/cycle_gan/README.md +++ b/fluid/gan/cycle_gan/README.md @@ -21,21 +21,23 @@ TODO horse2zebra训练集包含1069张野马图片,1336张斑马图片。测试集包含121张野马图片和141张斑马图片。 -数据下载处理完毕后,并组织为以下路径: +数据下载处理完毕后,并组织为以下路径结构: ``` -horse2zebra/ -|-- testA -|-- testA.txt -|-- testB -|-- testB.txt -|-- trainA -|-- trainA.txt -|-- trainB -`-- trainB.txt +data +|-- horse2zebra +| |-- testA +| |-- testA.txt +| |-- testB +| |-- testB.txt +| |-- trainA +| |-- trainA.txt +| |-- trainB +| `-- trainB.txt + ``` -以上数据文件中,‘testA’为存放野马测试图片的文件夹,‘testB’为存放斑马测试图片的文件夹,'testA.txt'和'testB.txt'分别为野马和斑马测试图片路径列表文件,格式如下: +以上数据文件中,`data`文件夹需要放在训练脚本`train.py`同级目录下。`testA`为存放野马测试图片的文件夹,`testB`为存放斑马测试图片的文件夹,`testA.txt`和`testB.txt`分别为野马和斑马测试图片路径列表文件,格式如下: ``` testA/n02381460_9243.jpg @@ -53,7 +55,7 @@ testA/n02381460_9245.jpg 在GPU单卡上训练: ``` -env CUDA_VISIABLE_DEVICES=0 python train.py +env CUDA_VISIBLE_DEVICES=0 python train.py ``` 执行`python train.py --help`可查看更多使用方式和参数详细说明。 @@ -72,7 +74,7 @@ env CUDA_VISIABLE_DEVICES=0 python train.py ``` env CUDA_VISIBLE_DEVICE=0 python infer.py \ - --model_path="models/1" --input="./data/inputA/*" \ + --init_model="models/1" --input="./data/inputA/*" \ --output="./output" ``` diff --git a/fluid/image_classification/dist_train/README.md b/fluid/image_classification/dist_train/README.md index 02bbea17f423fe2e16fd3115058ba92805a313ab..282a026acf1ee6d5b1c17aa05a2a8f734047c006 100644 --- a/fluid/image_classification/dist_train/README.md +++ b/fluid/image_classification/dist_train/README.md @@ -52,7 +52,7 @@ In this example, we launched 4 parameter server instances and 4 trainer instance 1. launch trainer process ``` python - PADDLE_TRAINING_ROLE=PSERVER \ + PADDLE_TRAINING_ROLE=TRAINER \ PADDLE_TRAINERS=4 \ PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \ PADDLE_TRAINER_ID=0 \ @@ -110,4 +110,4 @@ Training acc1 curves ### Performance -TBD \ No newline at end of file +TBD diff --git a/fluid/image_classification/dist_train/dist_train.py b/fluid/image_classification/dist_train/dist_train.py index 8f3953d1320441030c3cdee899ec85ffc8f8f401..127aff6e527ded5475cbe7d785d268baa6df43d8 100644 --- a/fluid/image_classification/dist_train/dist_train.py +++ b/fluid/image_classification/dist_train/dist_train.py @@ -22,6 +22,7 @@ import numpy as np import paddle import paddle.fluid as fluid import paddle.fluid.core as core +import six import sys sys.path.append("..") import models @@ -172,7 +173,7 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog): def test_parallel(exe, test_args, args, test_prog, feeder): acc_evaluators = [] - for i in xrange(len(test_args[2])): + for i in six.moves.xrange(len(test_args[2])): acc_evaluators.append(fluid.metrics.Accuracy()) to_fetch = [v.name for v in test_args[2]] @@ -291,7 +292,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, def print_arguments(args): print('----------- Configuration Arguments -----------') - for arg, value in sorted(vars(args).iteritems()): + for arg, value in sorted(six.iteritems(vars(args))): print('%s: %s' % (arg, value)) print('------------------------------------------------') @@ -307,7 +308,7 @@ def print_paddle_envs(): print('----------- Configuration envs -----------') for k in os.environ: if "PADDLE_" in k: - print "ENV %s:%s" % (k, os.environ[k]) + print("ENV %s:%s" % (k, os.environ[k])) print('------------------------------------------------') diff --git a/fluid/image_classification/reader.py b/fluid/image_classification/reader.py index 50be1cdef5ff3ad612d4d447a87174a767867a02..639b0b01200e3d81c57d75e560d6911f3e74b710 100644 --- a/fluid/image_classification/reader.py +++ b/fluid/image_classification/reader.py @@ -140,7 +140,7 @@ def _reader_creator(file_list, # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) - per_node_lines = len(full_lines) / trainer_count + per_node_lines = len(full_lines) // trainer_count lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) * per_node_lines] print( diff --git a/fluid/image_classification/train.py b/fluid/image_classification/train.py index bfc5f8b1412a11606d54b020f29bef969bae2a62..238c322bea9abf1ce086a0228b491e82cb69ae45 100644 --- a/fluid/image_classification/train.py +++ b/fluid/image_classification/train.py @@ -33,7 +33,7 @@ add_arg('lr', float, 0.1, "set learning rate.") add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.") -add_arg('data_dir' str, "./data/ILSVRC2012", "The ImageNet dataset root dir.") +add_arg('data_dir', str, "./data/ILSVRC2012", "The ImageNet dataset root dir.") # yapf: enable model_list = [m for m in dir(models) if "__" not in m] diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index 26ce487a925636d68a75c031416486456086c1b6..2d0e5accca7747ee30c861a9f711462b9c92fa35 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -171,7 +171,7 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece): ]) # This is used here to set dropout to the test mode. - infer_program = fluid.default_main_program().inference_optimize() + infer_program = fluid.default_main_program().clone(for_test=True) for batch_id, data in enumerate(test_data.batch_generator()): data_input = prepare_batch_input( diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index fc76dbbd6fe4af5d5924451d2d96cf66fce230f9..ee2d1d501f25881e401bb725b67b0a78b71f7605 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -639,7 +639,8 @@ def wrap_decoder(trg_vocab_size, if weight_sharing: predict = layers.matmul( x=dec_output, - y=fluid.get_var(word_emb_param_names[0]), + y=fluid.default_main_program().global_block().var( + word_emb_param_names[0]), transpose_y=True) else: predict = layers.fc(input=dec_output, diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 97d9e0f45f7cb5f5091f60530990e4924b03f9d8..d34d4cc108b7ef220f82935387cc79ea93a8623f 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -462,13 +462,14 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, # Since the token number differs among devices, customize gradient scale to # use token average cost among multi-devices. and the gradient scale is # `1 / token_number` for average cost. - build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized - logging.info("begin read executor") + #build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized + exec_strategy = fluid.ExecutionStrategy() if args.update_method == "nccl2": exec_strategy.num_threads = 1 + logging.info("begin executor") train_exe = fluid.ParallelExecutor( use_cuda=TrainTaskConfig.use_gpu, loss_name=avg_cost.name, diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py index 1106635dabab26ff70ccad81d477af25819cec17..6e763ea1d4ae1a2579238aa4388bc6425b1400f7 100644 --- a/fluid/object_detection/train.py +++ b/fluid/object_detection/train.py @@ -126,7 +126,7 @@ def train(args, devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" devices_num = len(devices.split(",")) batch_size = train_params['batch_size'] - epoc_num = train_params['epoch_num'] + epoc_num = train_params['epoc_num'] batch_size_per_device = batch_size // devices_num iters_per_epoc = train_params["train_images"] // batch_size num_workers = 8 diff --git a/fluid/video_classification/README.md b/fluid/video_classification/README.md index 34c361f2ab685b683bfbd73b5b41aa65c69e4f98..822c3ccf64cb1c5567e574425229974524a34471 100644 --- a/fluid/video_classification/README.md +++ b/fluid/video_classification/README.md @@ -111,7 +111,6 @@ According to the congfiguration of evaluation, the output log is like: Inference is used to get prediction score or video features based on trained models. ``` python infer.py \ - --batch_size=128 \ --class_dim=101 \ --image_shape=3,224,224 \ --with_mem_opt=True \ diff --git a/fluid/video_classification/data/generate_train_data.py b/fluid/video_classification/data/generate_train_data.py index 547ebbaa746ab87987c5207784b2a9c1b212315d..1a5fa2edee7353d978d0329fcdc5af1f85b50645 100644 --- a/fluid/video_classification/data/generate_train_data.py +++ b/fluid/video_classification/data/generate_train_data.py @@ -9,27 +9,33 @@ for line in f.readlines(): dd[name.lower()] = int(label) - 1 f.close() -# generate pkl -path = 'train/' -savepath = 'train_pkl/' -if not os.path.exists(savepath): - os.makedirs(savepath) - -fw = open('train.list', 'w') -for folder in os.listdir(path): - vidid = folder.split('_', 1)[1] - this_label = dd[folder.split('_')[1].lower()] - this_feat = [] - for img in sorted(os.listdir(path + folder)): - fout = open(path + folder + '/' + img, 'rb') - this_feat.append(fout.read()) - fout.close() - - res = [vidid, this_label, this_feat] - - outp = open(savepath + vidid + '.pkl', 'wb') - cPickle.dump(res, outp, protocol=cPickle.HIGHEST_PROTOCOL) - outp.close() - - fw.write('data/train_pkl/%s.pkl\n' % vidid) -fw.close() + +def generate_pkl(mode): + # generate pkl + path = '%s/' % mode + savepath = '%s_pkl/' % mode + if not os.path.exists(savepath): + os.makedirs(savepath) + + fw = open('%s.list' % mode, 'w') + for folder in os.listdir(path): + vidid = folder.split('_', 1)[1] + this_label = dd[folder.split('_')[1].lower()] + this_feat = [] + for img in sorted(os.listdir(path + folder)): + fout = open(path + folder + '/' + img, 'rb') + this_feat.append(fout.read()) + fout.close() + + res = [vidid, this_label, this_feat] + + outp = open(savepath + vidid + '.pkl', 'wb') + cPickle.dump(res, outp, protocol=cPickle.HIGHEST_PROTOCOL) + outp.close() + + fw.write('data/%s/%s.pkl\n' % (savepath, vidid)) + fw.close() + + +generate_pkl('train') +generate_pkl('test') diff --git a/fluid/video_classification/reader.py b/fluid/video_classification/reader.py index 530383f9cf1443c96bc8651fbd5d2e4beba2efd0..e688b66487f2615e19d468fa36f4cb2dc7578a54 100644 --- a/fluid/video_classification/reader.py +++ b/fluid/video_classification/reader.py @@ -16,8 +16,8 @@ THREAD = 8 BUF_SIZE = 1024 TRAIN_LIST = 'data/train.list' -TEST_LIST = 'data/val.list' -INFER_LIST = 'data/val.list' +TEST_LIST = 'data/test.list' +INFER_LIST = 'data/test.list' img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) diff --git a/fluid/video_classification/train.py b/fluid/video_classification/train.py index 6e8552d4514440e0a3b9f7eab8781dd0be54632f..a4171b2faf5fb63c37607fbe2fea21416c3e0441 100644 --- a/fluid/video_classification/train.py +++ b/fluid/video_classification/train.py @@ -2,6 +2,7 @@ import os import numpy as np import time import sys +import paddle.v2 as paddle import paddle.fluid as fluid from resnet import TSN_ResNet import reader