未验证 提交 f6628ee8 编写于 作者: T tangjiji 提交者: GitHub

Repro (#628)

* add other tasks

* Update README_zh.md

* Update README.md

* Update README.md
上级 69a9e2fa
...@@ -5,6 +5,11 @@ English| [简体中文](./README_zh.md) ...@@ -5,6 +5,11 @@ English| [简体中文](./README_zh.md)
- [Pre-trained models](#pre-trained-models) - [Pre-trained models](#pre-trained-models)
- [Downstream tasks](#downstream-tasks) - [Downstream tasks](#downstream-tasks)
* [VCR](#VCR) * [VCR](#VCR)
* [VQA](#VQA)
* [IR&TR](#Retrieval)
* [RefCOCO+](#RefCOCO+)
- [Usage](#usage) - [Usage](#usage)
* [Install PaddlePaddle](#install-paddlepaddle) * [Install PaddlePaddle](#install-paddlepaddle)
* [Fine-tuning on ERNIE-ViL](#fine-tuning-on-ernie-vil) * [Fine-tuning on ERNIE-ViL](#fine-tuning-on-ernie-vil)
...@@ -43,11 +48,15 @@ Based on the scene graph parsed from the text using Scene Graph Parser, we const ...@@ -43,11 +48,15 @@ Based on the scene graph parsed from the text using Scene Graph Parser, we const
## Pre-trained Models ## Pre-trained Models
ERNIE-ViL adopts large-scale image-text aligned datasets as the pre-training data. We provide ERNIE-ViL models of two scale settings which are pretrained on [**Conceptual Captions**](https://www.aclweb.org/anthology/P18-1238.pdf) and [**SBU Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio). ERNIE-ViL adopts large-scale image-text aligned datasets as the pre-training data. We provide ERNIE-ViL models of two scale settings which are pretrained on two out-of-domain datasets, e.g., [**Conceptual Captions**](https://www.aclweb.org/anthology/P18-1238.pdf) and [**SBU Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio).
- [**ERNIE-ViL _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-en.1.tar.gz) (_lowercased | 12-text-stream-layer, 6-visual-stream-layer_) - [**ERNIE-ViL _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-en.1.tar.gz) (_lowercased | 12-text-stream-layer, 6-visual-stream-layer_)
- [**ERNIE-ViL _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_) - [**ERNIE-ViL _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_)
We also provide large scale settings model which are pretrained on both out-of-domain datasets([**Conceptual Captions**](https://www.aclweb.org/anthology/P18-1238.pdf), [**SBU Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio)) and in-domain([**MS-COCO**](https://arxiv.org/abs/1405.0312)[**Visual-Genome**](https://arxiv.org/abs/1602.07332)) datasets.
- [**ERNIE-ViL-Out&in-domain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-all-domain-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_)
## Downstream tasks ## Downstream tasks
We finetune ERNIE-ViL on five vision-langage downstream tasks, i.e., Visual Commensense Reasoning([**VCR**](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zellers_From_Recognition_to_Cognition_Visual_Commonsense_Reasoning_CVPR_2019_paper.pdf)), We finetune ERNIE-ViL on five vision-langage downstream tasks, i.e., Visual Commensense Reasoning([**VCR**](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zellers_From_Recognition_to_Cognition_Visual_Commonsense_Reasoning_CVPR_2019_paper.pdf)),
Visual Question Answering([**VQA**](https://openaccess.thecvf.com/content_iccv_2015/papers/Antol_VQA_Visual_Question_ICCV_2015_paper.pdf)), Visual Question Answering([**VQA**](https://openaccess.thecvf.com/content_iccv_2015/papers/Antol_VQA_Visual_Question_ICCV_2015_paper.pdf)),
...@@ -59,7 +68,7 @@ _Code and pre-trained models related to VCR task are made public now, and those ...@@ -59,7 +68,7 @@ _Code and pre-trained models related to VCR task are made public now, and those
### VCR ### VCR
* datasets * datasets
* The training, validation and testing data of VCR task are provided by [**VCR Website**](https://visualcommonsense.com/download/). * The training, validation and testing data of **VCR** task are provided by [**VCR Website**](https://visualcommonsense.com/download/).
* Organization of visual features is modified from [**ViLBERT**](https://github.com/jiasenlu/vilbert_beta), we directly use the data from it. Data can be downloaded [here](https://github.com/jiasenlu/vilbert_beta/tree/master/data). * Organization of visual features is modified from [**ViLBERT**](https://github.com/jiasenlu/vilbert_beta), we directly use the data from it. Data can be downloaded [here](https://github.com/jiasenlu/vilbert_beta/tree/master/data).
* Put all downloaded files under diretory "data/vcr". * Put all downloaded files under diretory "data/vcr".
...@@ -67,19 +76,79 @@ _Code and pre-trained models related to VCR task are made public now, and those ...@@ -67,19 +76,79 @@ _Code and pre-trained models related to VCR task are made public now, and those
* Task pre-training: We perform task-pretraining on VCR task, which is also known as task-specific-pretraining. The trained models are as follows: * Task pre-training: We perform task-pretraining on VCR task, which is also known as task-specific-pretraining. The trained models are as follows:
* [**ERNIE-ViL-VCR-task-pretrain _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-VCR-task-pre-en.1.tar.gz) * [**ERNIE-ViL-VCR-task-pretrain _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-VCR-task-pre-en.1.tar.gz)
* [**ERNIE-ViL-VCR-task-pretrain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-VCR-task-pre-en.1.tar.gz) * [**ERNIE-ViL-VCR-task-pretrain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-VCR-task-pre-en.1.tar.gz)
* Performance: Results of VCR task for ERNIE-ViL model, compared with previous state-of-the-art pre-trained models([**VILLA**](https://arxiv.org/pdf/2006.06195.pdf)). * Performance: Results of VCR task for different scale settings of ERNIE-ViL model
| Models | <strong>Q->A</strong> | <strong>QA->R</strong> | <strong>Q->AR</strong> | | Models | <strong>Q->A</strong> | <strong>QA->R</strong> | <strong>Q->AR</strong> |
| :--------------------------------------| :---------------------------: | :----------------------------: | :-----------------------------: | | :--------------------------------------| :---------------------------: | :----------------------------: | :-----------------------------: |
| VILLA (task-pretrain) _base_ | 75.54(76.4) | 78.78(79.1) | 59.75(60.6) |
| ERNIE-ViL (task-pretrain) _base_ | 76.37(77.0) | 79.65(80.3) | 61.24(62.1) | | ERNIE-ViL (task-pretrain) _base_ | 76.37(77.0) | 79.65(80.3) | 61.24(62.1) |
| VILLA (task-pretrain) _large_ | 78.45(78.9) | 82.57(82.8) | 65.18(65.7) |
| ERNIE-ViL (task-pretrain) _large_ | <strong>78.52(79.2)</strong> | <strong>83.37(83.5)</strong> | <strong/>65.81(66.3) </strong> | | ERNIE-ViL (task-pretrain) _large_ | <strong>78.52(79.2)</strong> | <strong>83.37(83.5)</strong> | <strong/>65.81(66.3) </strong> |
_Numerical results outside and inside parentheses represent the dev and test performance of VCR task respectively. _Numerical results outside and inside parentheses represent the dev and test performance of VCR task respectively.
Test results are obtained from the [**VCR leadborad**](https://visualcommonsense.com/leaderboard/)._ Test results are obtained from the [**VCR leadborad**](https://visualcommonsense.com/leaderboard/)._
### VQA
* datasets
* The training, validation and testing data of VCR task are provided by[**VQA Website**](https://visualqa.org/).
* Visual features are extracted by using tools in [bottom-up attention](https://github.com/jiasenlu/bottom-up-attention), The minimum and maximum number of the extracted boxes are 100 and 100.
* A single training & test data is organized as follows:
```script
question_id, question, answer_label, answer_score, image_w, image_h, number_box, image_loc, image_embeddings
```
_The labels and scores of multiple answers are separated by the character ‘|’._
* Performance: Results of **VQA** task for different scale settings of ERNIE-ViL model
| Models | <strong>test-dev</strong> | <strong>test-std</strong> |
| :-------------------------------- | :-------------------------------: | :------------------------------: |
| ERNIE-ViL _base_ | 73.18 | 73.36 |
| ERNIE-ViL _large_ | 73.78 | 73.96 |
| ERNIE-ViL-Out&in-domain _large_ | 74.95 | 75.10 |
### IR&TR
* datasets
* The images and captions of Flickr30k datasets can be obtailed from [**here**](https://www.kaggle.com/hsankesara/flickr-image-dataset).
* Visual features are extracted by using tools in [bottom-up attention](https://github.com/jiasenlu/bottom-up-attention). The minimum and maximum number of the extracted boxes are 0 and 36. The organization of visual features is illstruated as follows:
```script
image_w, image_h, number_box, image_loc, image_embeddings
```
* The organization of text data can refer to our given sample, e.g., data/flickr.flickr.dev.data.
* Performance
* Results of **Image Retrieval** task on **Flickr30k dataset** for different scale settings of ERNIE-ViL model
| Models | <strong>R@1</strong> | <strong>R@5</strong> | <strong>R@10</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 74.44 | 92.72 | 95.94 |
| ERNIE-ViL _large_ | 75.10 | 93.42 | 96.26 |
| ERNIE-ViL-Out&in-domain _large_ | 76.66 | 94.16 | 96.76 |
* Results of **Text Retrieval** task on **Flickr30k dataset** for different scale settings of ERNIE-ViL model
| Models | <strong>R@1</strong> | <strong>R@5</strong> | <strong>R@10</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 86.70 | 97.80 | 99.00 |
| ERNIE-ViL _large_ | 88.70 | 97.30 | 99.10 |
| ERNIE-ViL-Out&in-domain _large_ | 89.20 | 98.50 | 99.20 |
### RefCOCO+
* datasets
* Organization of visual features is modified from [MAttNet](https://github.com/lichengunc/MAttNet).
* A single training & test data is organized as follows:
```script
expressions, image_w, image_h, number_box, number_boxes_gt, image_loc, image_embeddings, box_label, label
```
* Performance
* Results of **RefCOCO+** task for different scale settings of ERNIE-ViL model
| Models | <strong>val</strong> | <strong>testA</strong> | <strong>testB</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 74.02 | 80.33 | 64.74 |
| ERNIE-ViL _large_ | 74.24 | 80.97 | 64.70 |
| ERNIE-ViL-Out&in-domain _large_ | 75.89 | 82.39 | 66.91 |
## Usage ## Usage
...@@ -92,32 +161,61 @@ This code has been tested with Paddle Fluid 1.8 with Python 2.7. Other dependenc ...@@ -92,32 +161,61 @@ This code has been tested with Paddle Fluid 1.8 with Python 2.7. Other dependenc
### Fine-tuning on ERNIE-ViL ### Fine-tuning on ERNIE-ViL
Please update LD_LIBRARY_PATH about CUDA, cuDNN, NCCL2 before fine-tuning. You can easily run fine-tuning through Please update LD_LIBRARY_PATH about CUDA, cuDNN, NCCL2 before fine-tuning. You can easily run fine-tuning through
configuration files. For example, you can finetune ERNIE-ViL model on VCR task by configuration files. You can finetune ERNIE-ViL model on different downstream tasks by the following command:
```script ```script
sh run_finetuning.sh vcr conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $pretrain_models sh run_finetuning.sh $task_name(vqa/flickr/refcoco_plus/vcr) conf/${task_name}/model_conf_${task_name} $vocab_file $ernie_vil_config $pretrain_models_params
``` ```
Files which are needed by fine-tuning can be found in our given download links, incluing vocabulary dictionary, configuration Files which are needed by fine-tuning can be found in our given download links, incluing vocabulary dictionary, configuration
file and pre-trained parameters. Note that our fine-tuning experiments on VCR are carried on 4 NVIDIA V100 (32GB) GPUs. file and pre-trained parameters. Training details of different downstream tasks (large scale) are illstruated in the table below.
| Tasks | Batch Size | Learning Rate | # of Epochs | GPUs | Layer Decay rate | Hidden dropout |
| ----- | ----------:| -------------:| -----------:| --------:| ----------------:| --------------:|
| VCR | 16(x4) | 1e-4 | 6 | 4x V100 | 0.9 | 0.1 |
| VQA 2.0 | 64(x4) | 1e-4 | 15 | 4x V100 | 0.9 | 0.1 |
| RefCOCO+ | 64(x2) | 1e-4 | 30 | 2x V100 | 0.9 | 0.2 |
| Flickr | 8(x8) | 2e-5 | 40 | 8x V100 | 0.0 | 0.1 |
Our fine-tuning experiments on downstream tasks are carried on NVIDIA V100 (32GB) GPUs.
If your GPU memory is not enough, you can reduce the batch size in the corresponding configuration file, e.g., "conf/vcr/model_conf_vcr". If your GPU memory is not enough, you can reduce the batch size in the corresponding configuration file, e.g., "conf/vcr/model_conf_vcr".
### Inference ### Inference
You can use the following command to infer fine-tuned models. For example, you can infer VCR models by the following commands for different sub-tasks: You can use the following command to infer fine-tuned models.
**Task Q->A** #### VCR
```script ```script
sh run_inference.sh vcr qa $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file Task Q->A: sh run_inference.sh vcr qa $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file
``` ```
**Task QA->R**
```script
Task Q->AR: sh run_inference.sh vcr qar $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file
```
#### VQA
```script ```script
sh run_inference.sh vcr qar $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file sh run_inference.sh vqa eval $split(val/test_dev/test_std) conf/vqa/model_conf_vqa $vocab_file $ernie_vil_config $model_params $res_file
``` ```
_No test labels are given in the released test samples, you can obtailed the final score by submiting the result file to the [VQA website](https://visualqa.org/)_.
#### RefCOCO+
```script
sh run_inference.sh refcoco_plus eval $split(val/test_A/test_B) conf/refcoco_plus/model_conf_refcoco_plus $vocab_file $ernie_vil_config $model_params $res_file
```
#### Flickr
```script
sh run_inference.sh flickr eval $split(dev/test) conf/flickr/model_conf_flickr $vocab_file $ernie_vil_config $model_params $res_file
```
_Get the accuray score by using the given tools of tools/get_recall.py._
## Citation ## Citation
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
- [模型框架](#模型框架) - [模型框架](#模型框架)
- [预训练模型](#预训练模型) - [预训练模型](#预训练模型)
- [下游任务](#下游任务) - [下游任务](#下游任务)
* [视觉推理](#视觉推理) * [视觉常识推理](#视觉常识推理)
* [视觉问答](#视觉问答)
* [跨模态检索](#跨模态检索)
* [引用表达式理解](#引用表达式理解)
- [使用说明](#使用说明) - [使用说明](#使用说明)
* [安装飞桨](#安装飞桨) * [安装飞桨](#安装飞桨)
* [运行微调](#运行微调) * [运行微调](#运行微调)
...@@ -41,45 +44,110 @@ ERNIE-ViL 场景图预训练任务结构 ...@@ -41,45 +44,110 @@ ERNIE-ViL 场景图预训练任务结构
## 预训练模型 ## 预训练模型
ERNIE-ViL使用大规模图文对齐数据作为预训练数据,基于[**Conceptual ERNIE-ViL使用大规模图文对齐数据作为预训练数据,基于[**Conceptual
Captions**](https://www.aclweb.org/anthology/P18-1238.pdf)[**SBU Captions**](https://www.aclweb.org/anthology/P18-1238.pdf)[**SBU
Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio)数据集,训练和发布了两种参数规模的模型 Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio)两个out-of-domain数据集,训练两种参数规模模型如下
- [**ERNIE-ViL _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-en.1.tar.gz) (_lowercased | 12-text-stream-layer, 6-visual-stream-layer_) - [**ERNIE-ViL _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-en.1.tar.gz) (_lowercased | 12-text-stream-layer, 6-visual-stream-layer_)
- [**ERNIE-ViL _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_) - [**ERNIE-ViL _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_)
基于两个out-of-domian数据集([**Conceptual
Captions**](https://www.aclweb.org/anthology/P18-1238.pdf)[**SBU
Captions**](http://papers.nips.cc/paper/4470-im2text-describing-images-using-1-million-captio))和两个in-domain数据集([**MS-COCO**](https://arxiv.org/abs/1405.0312)[**Visual-Genome**](https://arxiv.org/abs/1602.07332))训练了large参数规模的模型:
- [**ERNIE-ViL-Out&in-domain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-all-domain-large-en.1.tar.gz) (_lowercased | 24-text-stream-layer, 6-visual-stream-layer_)
## 下游任务 ## 下游任务
ERNIE-ViL在五个视觉语言下游任务进行了实验,包括[**视觉常识推理**](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zellers_From_Recognition_to_Cognition_Visual_Commonsense_Reasoning_CVPR_2019_paper.pdf) ERNIE-ViL在五个视觉语言下游任务进行了实验,包括[**视觉常识推理**](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zellers_From_Recognition_to_Cognition_Visual_Commonsense_Reasoning_CVPR_2019_paper.pdf)
[**视觉问答**](https://openaccess.thecvf.com/content_iccv_2015/papers/Antol_VQA_Visual_Question_ICCV_2015_paper.pdf) [**视觉问答**](https://openaccess.thecvf.com/content_iccv_2015/papers/Antol_VQA_Visual_Question_ICCV_2015_paper.pdf)
[**跨模态图片检索**](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00166) [**跨模态图片检索**](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00166)
[**跨模态文本检索**](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00166) [**跨模态文本检索**](https://www.mitpressjournals.org/doi/abs/10.1162/tacl_a_00166)
[**引用式理解**](https://www.aclweb.org/anthology/D14-1086.pdf) [**引用表达式理解**](https://www.aclweb.org/anthology/D14-1086.pdf),与主流模型的效果对比可以参考开源论文。
_当前仅开源视觉常识推理任务相关模型和代码,后续计划开源更多下游任务的模型和代码。_
### **视觉常识推理** ### **视觉常识推理**
* 数据集合 * 数据集合
* 训练、验证和测试集合相关数据[**视觉常识推理官网**](http://visualcommonsense.com/download/)提供 * 训练、验证和测试集合相关数据可以由[**视觉常识推理官网**](http://visualcommonsense.com/download/)获取
* 视觉端特征的组织方式借鉴[**ViLBERT**](https://github.com/jiasenlu/vilbert_beta), 因此项目直接使用**ViLBERT**中的数据,数据[下载地址](https://github.com/jiasenlu/vilbert_beta/tree/master/data); * 视觉端特征的组织方式借鉴[**ViLBERT**](https://github.com/jiasenlu/vilbert_beta), 因此项目直接使用**ViLBERT**中的数据,数据[下载地址](https://github.com/jiasenlu/vilbert_beta/tree/master/data);
* 将所有获取的文件放在 data/vcr 目录下; * 将所有获取的文件放在 data/vcr 目录下;
* 任务预训练: 基于ERNIE-ViL的out-of-domain模型,在视觉推理任务中进行了任务预训练,预训练获得模型如下
* 任务预训练: 在视觉推理任务中进行了任务预训练,预训练获得模型如下
* [**ERNIE-ViL-VCR-task-pretrain _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-VCR-task-pre-en.1.tar.gz) * [**ERNIE-ViL-VCR-task-pretrain _base_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-base-VCR-task-pre-en.1.tar.gz)
* [**ERNIE-ViL-VCR-task-pretrain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-VCR-task-pre-en.1.tar.gz) * [**ERNIE-ViL-VCR-task-pretrain _large_**](https://ernie-github.cdn.bcebos.com/model-ernie-vil-large-VCR-task-pre-en.1.tar.gz)
* 效果: ERNIE-ViL与之前最优预训练模型[**VILLA**](https://arxiv.org/pdf/2006.06195.pdf)在视觉常识推理任务上的效果对比如下:
* 效果: ERNIE-ViL在视觉常识推理任务上的效果对比如下:
| 模型 | <strong>Q->A</strong> | <strong>QA->R</strong> | <strong>Q->AR</strong> | | 模型 | <strong>Q->A</strong> | <strong>QA->R</strong> | <strong>Q->AR</strong> |
| :---------------------------------- | :---------------------------: | :----------------------------: | :---------------------------: | | :---------------------------------- | :---------------------------: | :----------------------------: | :---------------------------: |
| VILLA (task-pretrain) _base_ | 75.54(76.4) | 78.78(79.1) | 59.75(60.6) |
| ERNIE-ViL (task-pretrain) _base_ | 76.37(77.0) | 79.65(80.3) | 61.24(62.1) | | ERNIE-ViL (task-pretrain) _base_ | 76.37(77.0) | 79.65(80.3) | 61.24(62.1) |
| VILLA (task-pretrain) _large_ | 78.45(78.9) | 82.57(82.8) | 65.18(65.7) |
| ERNIE-ViL (task-pretrain) _large_ | <strong>78.52(79.2)</strong> | <strong>83.37(83.5)</strong> | <strong/>65.81(66.3) </strong> | | ERNIE-ViL (task-pretrain) _large_ | <strong>78.52(79.2)</strong> | <strong>83.37(83.5)</strong> | <strong/>65.81(66.3) </strong> |
_注:括号外表示验证集效果,括号内表示测试集效果,测试集效果由[VCR榜单](https://visualcommonsense.com/leaderboard/)提供。_ _注:括号外表示验证集效果,括号内表示测试集效果,测试集效果提交到[VCR榜单](https://visualcommonsense.com/leaderboard/)获得。_
### **视觉问答**
* 数据集合
* 原始图片、问题和答案可以由[**视觉问答官网**](https://visualqa.org/)获取。
* 视觉端特征使用[**bottom-up attention**](https://github.com/jiasenlu/bottom-up-attention)中的工具提取,提取的box动态值为100-100。
* 训练 & 测试数据按照如下方式组织:
```script
question_id, question, answer_label, answer_score, image_w, image_h, number_box, image_loc, image_embeddings
```
_多个答案的label和score用 ‘|’ 分隔,和image相关的项均可以从bottom up attention的工具提取。_
* 效果:ERNIE-ViL的三种预训练模型在**视觉问答**任务下的效果如下表
| 模型 | <strong>test-dev</strong> | <strong>test-std</strong> |
| :-------------------------------- | :-------------------------------: | :------------------------------: |
| ERNIE-ViL _base_ | 73.18 | 73.36 |
| ERNIE-ViL _large_ | 73.78 | 73.96 |
| ERNIE-ViL-Out&in-domain _large_ | 74.95 | 75.10 |
### **跨模态检索**
* 数据集合
* 原始图片和文本描述相关的数据,可以从[**这里**](https://www.kaggle.com/hsankesara/flickr-image-dataset)获取。
* 视觉端特征使用[**bottom-up attention**](https://github.com/jiasenlu/bottom-up-attention)提取,提取的box动态值为0-36。
* 文本相关的数据可以参见data/flickr给出的示例 flickr.dev.data,图片端特征组织方式为
```script
image_w, image_h, number_box, image_loc, image_embeddings
```
* 效果
* ERNIE-ViL的三种预训练模型在**跨模态图片检索(Flickr30k 数据集)**上的效果如下表
| 模型 | <strong>R@1</strong> | <strong>R@5</strong> | <strong>R@10</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 74.44 | 92.72 | 95.94 |
| ERNIE-ViL _large_ | 75.10 | 93.42 | 96.26 |
| ERNIE-ViL-Out&in-domain _large_ | 76.66 | 94.16 | 96.76 |
* ERNIE-ViL的三种预训练模型在**跨模态文本检索(Flickr30k 数据集)**任务上的效果如下表
| 模型 | <strong>R@1</strong> | <strong>R@5</strong> | <strong>R@10</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 86.70 | 97.80 | 99.00 |
| ERNIE-ViL _large_ | 88.70 | 97.30 | 99.10 |
| ERNIE-ViL-Out&in-domain _large_ | 89.20 | 98.50 | 99.20 |
### **引用表达式理解**
* 数据集合
* 视觉端特征参考了[MAttNet](https://github.com/lichengunc/MAttNet)的提取方式。
* 单条训练 & 验证 数据的组织方式为
```script
expressions, image_w, image_h, number_box, number_boxes_gt, image_loc, image_embeddings, box_label, label
```
* 效果
* ERNIE-ViL的三种预训练模型在**引用表达式理解**任务上的效果如下表:
| 模型 | <strong>val</strong> | <strong>testA</strong> | <strong>testB</strong> |
| :-------------------------------- | :---------------------: | :----------------------: | :----------------------: |
| ERNIE-ViL _base_ | 74.02 | 80.33 | 64.74 |
| ERNIE-ViL _large_ | 74.24 | 80.97 | 64.70 |
| ERNIE-ViL-Out&in-domain _large_ | 75.89 | 82.39 | 66.91 |
## 使用说明 ## 使用说明
...@@ -90,32 +158,62 @@ ERNIE-ViL代码基于Paddle Fluid 1.8 和 Python 2.7, 依赖的其他模块也 ...@@ -90,32 +158,62 @@ ERNIE-ViL代码基于Paddle Fluid 1.8 和 Python 2.7, 依赖的其他模块也
pip install -r requirements.txt pip install -r requirements.txt
``` ```
### 运行微调 ### 运行微调
在运行 ERNIE-ViL 前,需要将 CUDA 、cuDNN 、NCCL2 的动态库路径添加到 LD_LIBRARY_PATH 。 我们把下游任务的参数配置文件放到了 conf/ ,可以简单地通过配置文件运行。 例如,您可以通过下面的指令在VCR上任务上进行微调: 在运行 ERNIE-ViL 微调前,需要将 CUDA 、cuDNN 、NCCL2 的动态库路径添加到 LD_LIBRARY_PATH 。 我们把下游任务的参数配置文件放到了 conf/ ,可以简单地通过配置文件运行。 例如,您可以通过下面的指令在各个下游任务上进行微调:
```script ```script
sh run_finetuning.sh vcr conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $pretrain_models_params sh run_finetuning.sh $task_name(vqa/flickr/refcoco_plus/vcr) conf/${task_name}/model_conf_${task_name} $vocab_file $ernie_vil_config $pretrain_models_params
``` ```
前面提供的模型链接中包含了所有需要的文件, 包含词表文件,配置文件和预训练参数。VCR任务的微调实验是在 4 张32 GB 的英伟达V100 GPU上运行,如果您的GPU显存不够,可以考虑八张卡运行或者减小配置中的batch_size。
_我们目前开放了预训练模型和VCR的任务代码,其他的下游任务可以参考任务自主尝试。_ 前面提供的模型链接中包含了所有需要的文件, 包含词表文件,配置文件和预训练参数。微调相关的模型配置和参数配置可以通过conf/ 目录下的文件找到,这里对论文最优结果(large模型)的一些关键参数进行汇总:
| Tasks | Batch Size | Learning Rate | # of Epochs | GPUs | Layer Decay rate | Hidden dropout |
| ----- | ----------:| -------------:| -----------:| --------:| ----------------:| --------------:|
| VCR | 16(x4) | 1e-4 | 6 | 4x V100 | 0.9 | 0.1 |
| VQA 2.0 | 64(x4) | 1e-4 | 15 | 4x V100 | 0.9 | 0.1 |
| RefCOCO+ | 64(x2) | 1e-4 | 30 | 2x V100 | 0.9 | 0.2 |
| Flickr | 8(x8) | 2e-5 | 40 | 8x V100 | 0.0 | 0.1 |
所有的下游任务的微调实验是在 32 GB 的英伟达V100 GPU上运行,如果您的GPU显存不够,可以考虑更多张卡运行或者减小配置中的batch_size。
### 预测 ### 预测
基于已经训练的模型,您可以通过下面的命令测试VCR的效果: 基于已经训练的模型,您可以通过下面的命令测试下游任务的效果(相关的配置文件可以从之前下载的包获得)
**Task Q->A** #### VCR
```script ```script
sh run_inference.sh vcr qa $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file Task Q->A: sh run_inference.sh vcr qa $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file
``` ```
**Task QA->R**
```script ```script
sh run_inference.sh vcr qar $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file Task Q->AR: sh run_inference.sh vcr qar $split(val/test) conf/vcr/model_conf_vcr $vocab_file $ernie_vil_config $model_params $res_file
``` ```
VCR的测试可以在一张32GB的英伟达V100 GPU上运行,测试的结果包含Q->A 任务、QA->R任务和Q->AR任务,其中Q->AR任务由前两个任务结果合并所得。 _VCR的测试可以在一张32GB的英伟达V100 GPU上运行,测试的结果包含Q->A 任务、QA->R任务和Q->AR任务,其中Q->AR任务由前两个任务结果合并所得._
#### VQA
```script
sh run_inference.sh vqa eval $split(val/test_dev/test_std) conf/vqa/model_conf_vqa $vocab_file $ernie_vil_config $model_params $res_file
```
注:_VQA的测试样本没有label信息,需要将结果文件提交到[**VQA网站**](https://visualqa.org/)查看结果。_
#### RefCOCO+
```script
sh run_inference.sh refcoco_plus eval $split(val/test_A/test_B) conf/refcoco_plus/model_conf_refcoco_plus $vocab_file $ernie_vil_config $model_params $res_file
```
#### Flickr
```script
sh run_inference.sh flickr eval $split(dev/test) conf/flickr/model_conf_flickr $vocab_file $ernie_vil_config $model_params $res_file
```
注:_Flickr的结果是一个预测结果文件,可以参考 tools/get_recall.py 统计一下最终结果。_
## 引用 ## 引用
可以按下面的格式引用我们的论文: 可以按下面的格式引用我们的论文:
......
...@@ -35,8 +35,13 @@ model_g.add_arg("task_name", str, "vcr", "Task to finetune on ERNIE-ViL") ...@@ -35,8 +35,13 @@ model_g.add_arg("task_name", str, "vcr", "Task to finetune on ERNIE-ViL")
train_g = ArgumentGroup(parser, "training", "training options.") train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 100, "Number of epoches for training.") train_g.add_arg("epoch", int, 100, "Number of epoches for training.")
train_g.add_arg("learning_rate", float, 0.0001, "Learning rate used to train with warmup.") train_g.add_arg("learning_rate", float, 0.0001, "Learning rate used to train with warmup.")
train_g.add_arg("seq_dropout", float, 0.0, "dropout rate after the sequence output.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay", train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay', 'manual_warmup_decay']) "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay', 'manual_warmup_decay'])
train_g.add_arg("layer_decay_rate", float, 0.0, "layer wise decay, 0.0 denote no layer decay")
train_g.add_arg("text_init_layers", int, 18, "diff from text and image layer, base:12-6=6, large:24-6=18")
train_g.add_arg("n_layers", int, 30, "max layers of text and image, base:12 + 6 , large:24 + 6")
train_g.add_arg("decay_steps", str, "", "learning rate decay steps, list with ;") train_g.add_arg("decay_steps", str, "", "learning rate decay steps, list with ;")
train_g.add_arg("lr_decay_ratio", float, 0.1, "learning rate decay ratio, used with manual_warmup_decay") train_g.add_arg("lr_decay_ratio", float, 0.1, "learning rate decay ratio, used with manual_warmup_decay")
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.") train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
...@@ -68,6 +73,9 @@ data_g.add_arg("feature_size", int, 2048, "Number of roi feature size of image." ...@@ -68,6 +73,9 @@ data_g.add_arg("feature_size", int, 2048, "Number of roi feature size of image."
data_g.add_arg("fusion_method", str, "sum", "Number of roi feature size of image.") data_g.add_arg("fusion_method", str, "sum", "Number of roi feature size of image.")
data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.") data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("task_group_json", str, "", "Path to task json") data_g.add_arg("task_group_json", str, "", "Path to task json")
data_g.add_arg("scale_circle", float, "1.0", "The scale factor in circle loss function, only use in circle loss mode")
data_g.add_arg("use_sigmoid", bool, False, "Whether to use sigmoid to match score, use for explode problem")
data_g.add_arg("margin", float, "0.2", "The margin value in triplet loss function")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.") run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.") run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
......
...@@ -56,7 +56,6 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num): ...@@ -56,7 +56,6 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num):
src_pos = np.array(batch_input_pos).astype("int64").reshape([num_choice * num_sample, max_len, 1]) src_pos = np.array(batch_input_pos).astype("int64").reshape([num_choice * num_sample, max_len, 1])
src_seg = np.array(batch_seg_ids).astype("int64").reshape([num_choice * num_sample, max_len, 1]) src_seg = np.array(batch_seg_ids).astype("int64").reshape([num_choice * num_sample, max_len, 1])
src_masks = np.array(batch_input_masks).astype("float32").reshape([num_choice * num_sample, max_len, 1]) src_masks = np.array(batch_input_masks).astype("float32").reshape([num_choice * num_sample, max_len, 1])
src_task = np.zeros(src_ids.shape, dtype="int64")
batch, seq_len, fea_len = image_embedding.shape batch, seq_len, fea_len = image_embedding.shape
image_embedding = np.tile(np.expand_dims(image_embedding, axis=1), \ image_embedding = np.tile(np.expand_dims(image_embedding, axis=1), \
(1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, fea_len]) (1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, fea_len])
...@@ -64,7 +63,7 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num): ...@@ -64,7 +63,7 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num):
(1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, 1]) (1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, 1])
image_loc = np.tile(np.expand_dims(image_loc, axis=1), \ image_loc = np.tile(np.expand_dims(image_loc, axis=1), \
(1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, 5]) (1, num_choice, 1, 1)).reshape([num_choice * batch, seq_len, 5])
return_list = [src_ids, src_pos, src_seg, src_task, src_masks, \ return_list = [src_ids, src_pos, src_seg, src_masks, \
image_embedding, image_loc, image_mask, labels, batch_anno_ids] image_embedding, image_loc, image_mask, labels, batch_anno_ids]
return_list.append(np.array([task_index]).astype('int64')) return_list.append(np.array([task_index]).astype('int64'))
return_list.append(binary_labels) return_list.append(binary_labels)
...@@ -76,14 +75,187 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num): ...@@ -76,14 +75,187 @@ def prepare_batch_data(batch_records, num_choice, pad_id, task_index, task_num):
return return_list return return_list
def prepare_vqa_batch_data(insts,
total_token_num,
task_index,
task_num,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
prepare batch data for vqa tasks
"""
batch_src_ids = [inst["token_ids"] for inst in insts]
batch_sent_ids = [inst["sent_ids"] for inst in insts]
batch_pos_ids = [inst["pos_ids"] for inst in insts]
batch_image_embedding = [inst["image_embeddings"] for inst in insts]
batch_image_loc = [inst["image_loc"] for inst in insts]
batch_weight_label = [inst["weight_labels"] for inst in insts]
q_ids = np.array([inst["question_id"] for inst in insts])
#pad and trans to numpy array
src_id, self_input_mask, seq_lens = pad_batch_data(
batch_src_ids, pad_idx=pad_id, return_input_mask=True, return_seq_lens = True)
pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
weight_labels = np.array(batch_weight_label).astype("float32")
#image_embedding_ori = copy.deepcopy(batch_image_embedding)
image_embedding, image_mask = pad_feature_data(batch_image_embedding, return_mask = True)
#image_embedding_ori = pad_feature_data(image_embedding_ori)
image_loc = pad_feature_data(batch_image_loc)
return_list = [
src_id, pos_id, sent_id, self_input_mask, \
image_embedding, image_loc, image_mask, weight_labels, q_ids
]
return return_list
def prepare_flickr_data(insts,
total_token_num,
task_index,
task_num,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
outs=4,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
prepare flickr data for finetuning tasks
"""
if outs > 1:
batch_src_ids = [inst["token_ids"][out] for inst in insts for out in range(outs)]
batch_sent_ids = [inst["sent_ids"][out] for inst in insts for out in range(outs)]
batch_pos_ids = [inst["pos_ids"][out] for inst in insts for out in range(outs)]
batch_image_embedding = [inst["image_embeddings"][out] for inst in insts for out in range(outs)]
batch_image_loc = [inst["image_loc"][out] for inst in insts for out in range(outs)]
else:
batch_src_ids = [inst["token_ids"] for inst in insts]
batch_sent_ids = [inst["sent_ids"] for inst in insts]
batch_pos_ids = [inst["pos_ids"] for inst in insts]
batch_image_embedding = [inst["image_embeddings"] for inst in insts ]
batch_image_loc = [inst["image_loc"] for inst in insts ]
batch_ids = [inst["ids"] for inst in insts for out in range(outs)]
batch_size = int(len(batch_src_ids) / outs)
label = np.array([[0] for i in range(batch_size)], dtype = "int64")
src_id, self_input_mask, seq_lens = pad_batch_data(
batch_src_ids, pad_idx=pad_id, return_input_mask=True, return_seq_lens = True)
pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
image_embeddings, image_mask = pad_feature_data(batch_image_embedding, return_mask = True)
image_loc = pad_feature_data(batch_image_loc)
ids = np.array(batch_ids, dtype = "int64")
return_list = [
src_id, pos_id, sent_id, self_input_mask, image_embeddings, image_loc, image_mask, label, ids]
return return_list
def prepare_refcoco_plus_batch_data(insts,
total_token_num,
task_index,
task_num,
voc_size=0,
pad_id=None,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
prepare batch data for refcoco_plus tasks
"""
batch_src_ids = [inst["token_ids"] for inst in insts]
batch_sent_ids = [inst["sent_ids"] for inst in insts]
batch_pos_ids = [inst["pos_ids"] for inst in insts]
batch_image_embedding = [inst["image_embeddings"] for inst in insts]
batch_image_loc = [inst["image_loc"] for inst in insts]
batch_image_label = [inst["label"] for inst in insts]
add_items = np.array([inst["add_item"] for inst in insts], dtype="float32")
src_id, self_input_mask, seq_lens = pad_batch_data(
batch_src_ids, pad_idx=pad_id, return_input_mask=True, return_seq_lens = True)
pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
image_embedding, image_mask = pad_feature_data(batch_image_embedding, return_mask = True)
image_loc = pad_feature_data(batch_image_loc)
image_label = pad_feature_data(batch_image_label)
return_list = [
src_id, pos_id, sent_id, self_input_mask, seq_lens, \
image_embedding, image_loc, image_mask, image_label, add_items
]
return return_list
def pad_batch_data(insts,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False,
return_seq_lens=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
if return_seq_lens:
seq_lens = np.array([len(inst) for inst in insts])
return_list += [seq_lens.astype("int64").reshape([-1, 1])]
return return_list if len(return_list) > 1 else return_list[0]
def pad_feature_data(data, pad_value=0.0, dtype="float32", return_mask=False): def pad_feature_data(data, pad_value=0.0, dtype="float32", return_mask=False):
""" """
pad visual features with given pad value pad visual features with given pad value
""" """
max_lenth=max([len(item) for item in data]) max_length=max([len(item) for item in data])
data_width = len(data[0][0]) data_width = len(data[0][0])
out_data = np.ones((len(data), max_lenth, data_width), dtype=dtype) * pad_value out_data = np.ones((len(data), max_length, data_width), dtype=dtype) * pad_value
out_mask = np.zeros((len(data), max_lenth, 1), dtype=dtype) out_mask = np.zeros((len(data), max_length, 1), dtype=dtype)
for i in range(len(data)): for i in range(len(data)):
out_data[i, 0: len(data[i]), :] = data[i] out_data[i, 0: len(data[i]), :] = data[i]
if return_mask: if return_mask:
......
output_model_path="./output_flickr"
lr_scheduler="manual_warmup_decay"
decay_steps="54360;72480"
num_train_steps=90600
SAVE_STEPS=4530
WARMUP_STEPS=9060
BATCH_SIZE=4
LR_RATE=1e-5
WEIGHT_DECAY=0.01
MAX_LEN=48
hardest=False
meansum=False
use_circle_loss=True
scale_circle=32.0
use_sigmoid=True
margin=0.3
[
{
"prob": 1.0,
"data_func": "image_text_match",
"task_name": "image_text_match",
"train_filelist": "./conf/flickr/flickr_retrieval_train.filelist",
"train_image_path":"./data/flickr/flickr_bottom_up_10_36.csv.0",
"train_caption_path":"./data/flickr/flickr.train.data",
"hardest_setting_path":"./data/flickr/hard_negative.pkl",
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt",
"negative_schema":[ "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei", "ei"]
}
]
[
{
"prob": 1.0,
"data_func": "image_text_match",
"task_name": "image_text_match",
"test_filelist": "./conf/flickr/flickr_retrieval_test.filelist",
"dev_filelist": "./conf/flickr/flickr_retrieval_dev.filelist",
"train_filelist": "./conf/flickr/flickr_retrieval_train.filelist",
"train_image_path":"./data/flickr/flickr_bottom_up_10_36.csv.0",
"dev_image_path": "data/flickr/flickr_dev_10_36.csv.0",
"test_image_path": "data/flickr/flickr_test_10_36.csv.0",
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt"
}
]
output_model_path="output_refcoco_plus"
lr_scheduler="manual_warmup_decay"
decay_steps="3290;4700;7050"
lr_decay_ratio=0.2
num_train_steps=26640
SAVE_STEPS=470
WARMUP_STEPS=940
BATCH_SIZE=32
VALID_STEPS=20000
LR_RATE=2e-5
WEIGHT_DECAY=0.01
MAX_LEN=80
layer_decay_rate=0.9
./data/refcoco_plus/refer_testA.part
./data/refcoco_plus/refer_testB.part
./data/refcoco_plus/train_part_sample
./data/refcoco_plus/refer_val.part
[
{
"prob": 1.0,
"data_func": "refcoco+",
"task_name": "refcoco+",
"train_filelist": "./conf/refcoco_plus/refer_train.filelist",
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt"
}
]
[
{
"prob": 1.0,
"data_func": "refcoco+",
"task_name": "refcoco+",
"val_filelist": "./conf/refcoco_plus/refer_val.filelist",
"testA_filelist": "./conf/refcoco_plus/refer_testA.filelist",
"testB_filelist": "./conf/refcoco_plus/refer_testB.filelist",
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt"
}
]
lr_decay_dict_file="./conf/vqa/vqa_finetune_decay.list"
output_model_path="output_20_mask"
lr_scheduler="manual_warmup_decay"
num_train_steps=50200
SAVE_STEPS=2510
WARMUP_STEPS=3710
BATCH_SIZE=16
LR_RATE=1e-4
decay_steps="15100;22590"
WEIGHT_DECAY=0.01
layer_decay_rate=0.9
text_init_layers=6
n_layers=18
MAX_LEN=16
task_group_json=./conf/vqa/task_vqa.json
[
{
"prob": 1.0,
"valid_filelist": "./conf/vqa/vqa_valid.filelist",
"vg_train_filelist": "./conf/vqa/vg_train.filelist",
"train_filelist": "./conf/vqa/vqa_train.filelist",
"num_class": 3129,
"classifier_hid_size": 2048,
"vg_init_epochs": 2,
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt"
}
]
[
{
"prob": 1.0,
"data_func": "image_text_match",
"task_name": "image_text_match",
"val_filelist": "./conf/vqa/vqa_val.filelist",
"test_dev_filelist": "./conf/vqa/vqa_test_dev.filelist",
"test_std_filelist": "./conf/vqa/vqa_test_std.filelist",
"pickle_file": "./data/vqa/trainval_label2ans.pkl",
"num_class": 3129,
"classifier_hid_size": 2048,
"Proprocessor": "PreprocessorBasic",
"tokenizer_name" : "FullTokenizer",
"vocab_path" : "./package/vocab.txt"
}
]
./data/vqa/vg_train_part_sample
vqa_fc_w_0 2.5
vqa_fc_w_1 2.5
vqa_fc_b_0 2.5
vqa_fc_b_1 2.5
./data/vqa/test_dev_part_sample
./data/vqa/test_std_part_sample
./data/vqa/train_part_sample
67 0 A group of people stand in the back of a truck filled with cotton .
(lp1
S''
aVwoods
p2
aVtrash can
p3
aVhanging
p4
aVwooden
p5
aVcooking
p6
aVchina
p7
aVkids
p8
aVbike rack
p9
aVon phone
p10
aVmusic
p11
aVtravel
p12
aVtulip
p13
aVarrow
p14
aVbranch
p15
aVchevron
p16
aVmouth
p17
aVon right
p18
aVrice
p19
aVplate
p20
aVlots
p21
aVnature
p22
aVfruits
p23
aVthrowing frisbee
p24
aVblonde
p25
aVlife jacket
p26
aVham
p27
aVhay
p28
aVhat
p29
aVto get to other side
p30
aV12:35
p31
aVshadow
p32
aV12:30
p33
aVcrown
p34
aVcrows
p35
aVbottom
p36
aVfish
p37
aVbenches
p38
aVfabric
p39
aVcongratulations
p40
aVpassenger
p41
aVana
p42
aVtriangles
p43
aVroll
p44
aVchain
p45
aVchair
p46
aVcrates
p47
aV11:45
p48
aVshirts
p49
aVmirrors
p50
aVparachute
p51
aVsurfing
p52
aVmango
p53
aVyears
p54
aVapron
p55
aVentering
p56
aV6:20
p57
aV6:25
p58
aVin her hand
p59
aVnight time
p60
aVzoo
p61
aVprinter
p62
aVsuitcases
p63
aVs
aV4:20
p64
aVwine tasting
p65
aVwest
p66
aVnewspaper
p67
aVwires
p68
aVsingapore
p69
aVadvertisement
p70
aVfor photo
p71
aVmulti colored
p72
aVsteamed
p73
aVtraffic
p74
aVbrushing
p75
aVstagecoach
p76
aVthailand
p77
aVpelicans
p78
aVpull
p79
aVdirty
p80
aVrust
p81
aVtennis shoes
p82
aVcream
p83
aVpuppy
p84
aVin box
p85
aVwaving
p86
aVceramic
p87
aVskate park
p88
aVsand
p89
aVbathing suit
p90
aVon tower
p91
aV193
p92
aVfast food
p93
aVclock
p94
aVfull
p95
aVadidas
p96
aVmore
p97
aV100 feet
p98
aVdoor
p99
aVseagull
p100
aVhuge
p101
aVcity bus
p102
aVpaper
p103
aVsigns
p104
aVsmiling
p105
aVsauce
p106
aVweeds
p107
aVvery tall
p108
aVright side
p109
aVtheater
p110
aVemirates
p111
aVgraffiti
p112
aVyellow and red
p113
aVback left
p114
aVvisilab
p115
aVsingles
p116
aVbarrel
p117
aVblender
p118
aVsunny
p119
aVcaramel
p120
aVanalog
p121
aVsnowboarder
p122
aVbell
p123
aVbelt
p124
aVawake
p125
aV37
p126
aVcheddar
p127
aVemergency
p128
aVcouple
p129
aVharry potter
p130
aVchina airlines
p131
aVhawaiian
p132
aVlufthansa
p133
aVland
p134
aVwindowsill
p135
aVvideo
p136
aVscania
p137
aVlistening to music
p138
aVpencil
p139
aVbaby
p140
aVpocket
p141
aVrobot
p142
aVspatula
p143
aVmutt
p144
aVsnake
p145
aV1:55
p146
aVvalentine's day
p147
aV1:50
p148
aVglasses
p149
aVbooks
p150
aVin cabbage town
p151
aVfestival
p152
aVbig ben
p153
aVstomach
p154
aVschool bus
p155
aVmountainous
p156
aVdishes
p157
aVhot dog
p158
aVthanksgiving
p159
aVteddy bear
p160
aVwatching tv
p161
aVjar
p162
aVpolice officer
p163
aVjal
p164
aVtape
p165
aVriding
p166
aVstyrofoam
p167
aVframe
p168
aVroman numerals
p169
aVskateboarding
p170
aVstaring
p171
aVdoorway
p172
aVcommuter
p173
aVtattoo
p174
aVoff
p175
aVclocks
p176
aVwet
p177
aVstreet sign
p178
aVpier
p179
aVflying kites
p180
aVpomeranian
p181
aVtaking selfie
p182
aVswimming
p183
aVletters
p184
aVcat and dog
p185
aVponytail
p186
aVdon't know
p187
aVby window
p188
aVwater ski
p189
aVpanda
p190
aVholding umbrella
p191
aVski slope
p192
aVplow
p193
aVsweater
p194
aVcoins
p195
aVcomcast
p196
aV1990
p197
aVnuts
p198
aVladder
p199
aVcell phone
p200
aVarch
p201
aV1 foot
p202
aVcoffee
p203
aV1 in back
p204
aVwhite and blue
p205
aVsafe
p206
aVl
aVturn right
p207
aVsled
p208
aVbarrier
p209
aVtree branch
p210
aVon chair
p211
aVjohn
p212
aVdogs
p213
aVon grass
p214
aVcereal
p215
aVcrosstown
p216
aVgiraffes
p217
aVgravy
p218
aVgermany
p219
aVrectangle
p220
aVt shirt
p221
aVaway
p222
aVplaying game
p223
aVskull and crossbones
p224
aVkitchen
p225
aV3 feet
p226
aVonion rings
p227
aVcylinder
p228
aVtissue
p229
aVcone
p230
aVwheel
p231
aVhand
p232
aVyamaha
p233
aVcooler
p234
aVnight
p235
aVcirrus
p236
aVbottom right
p237
aVsailboats
p238
aVsilverware
p239
aVgun
p240
aVhorseback riding
p241
aVslow down
p242
aV2:30
p243
aV2:35
p244
aVteacher
p245
aV3:25
p246
aV3:20
p247
aVpillow
p248
aVgreen and red
p249
aVuphill
p250
aVflat screen
p251
aVproduce
p252
aVvases
p253
aVcorona
p254
aVserving
p255
aVstill
p256
aVtiara
p257
aVtop left
p258
aVslacks
p259
aVnot
p260
aVnow
p261
aV5 star
p262
aVdo not enter
p263
aVstir fry
p264
aVring
p265
aVbutton up
p266
aVcoffee pot
p267
aVhsbc
p268
aVapples
p269
aVwindsor
p270
aVamerica
p271
aV7:35
p272
aVcauliflower
p273
aVliving room
p274
aVjeans
p275
aVtropical
p276
aVskiers
p277
aVplaying baseball
p278
aVfoam
p279
aVwestin
p280
aVtaxi
p281
aVneon
p282
aVfoot
p283
aVswim
p284
aVblackberry
p285
aVapartments
p286
aVkite flying
p287
aVstars
p288
aVpitcher
p289
aVomelet
p290
aV600
p291
aVcabinets
p292
aVnot at all
p293
aVpoor
p294
aVpoop
p295
aVpool
p296
aVred bull
p297
aVtracks
p298
aVall way
p299
aVgrill
p300
aVus airways
p301
aVhorizontal
p302
aVchopsticks
p303
aVcomputers
p304
aVevening
p305
aVtalking on phone
p306
aVstarbucks
p307
aVheavy
p308
aVsafety
p309
aV7
aVhouses
p310
aVdishwasher
p311
aVamerican
p312
aVhorse
p313
aVstation
p314
aVon track
p315
aVboredom
p316
aVtoward
p317
aVsilver and red
p318
aVin grass
p319
aVgray and black
p320
aVtongs
p321
aVclose
p322
aVpictures
p323
aVzipper
p324
aVgmc
p325
aVempty
p326
aVjuice
p327
aV000
p328
aVcherries
p329
aVmotorola
p330
aVvests
p331
aVcones
p332
aVno man
p333
aVtowards
p334
aVblue and white
p335
aVsquirrel
p336
aVat&t
p337
aVworking
p338
aVon car
p339
aVe
aVpizza hut
p340
aVwave
p341
aVsnowboards
p342
aVjump
p343
aVpancake
p344
aVon sidewalk
p345
aVcamera
p346
aVvisibility
p347
aVi don't know
p348
aVon woman
p349
aVtoaster oven
p350
aVbandana
p351
aVfarm
p352
aV10:35
p353
aV10:45
p354
aVtaking pictures
p355
aVcostume
p356
aVslide
p357
aVankle
p358
aV9:50
p359
aVottoman
p360
aVbaggage claim
p361
aVtrucks
p362
aVtent
p363
aVkicking
p364
aVlooking at phone
p365
aVcontroller
p366
aVsmoking
p367
aVtoothbrushes
p368
aVspeaker
p369
aVparty
p370
aV42
p371
aVballoons
p372
aVriding elephant
p373
aVpropeller
p374
aVintersection
p375
aVlibrary
p376
aVhome
p377
aVgrinding
p378
aVblue and gray
p379
aVnorth
p380
aVstrawberries
p381
aVdisplay
p382
aVfinch
p383
aVstar
p384
aVfoil
p385
aVsamsung
p386
aVhelmets
p387
aVwakeboard
p388
aVmen
p389
aVjoshua
p390
aVsliced
p391
aVjackets
p392
aVunder
p393
aVfor fun
p394
aVroom
p395
aVroof
p396
aVcheckerboard
p397
aVbefore
p398
aVglazed
p399
aVascending
p400
aVhoodie
p401
aVblinders
p402
aVin street
p403
aVlettuce
p404
aVfrosted
p405
aVginger
p406
aVsan diego
p407
aVblue
p408
aVmario
p409
aVno cat
p410
aVhotel room
p411
aVcelery
p412
aVwatermelon
p413
aVears
p414
aVlotion
p415
aVparsley
p416
aVmozzarella
p417
aVbasil
p418
aVhot sauce
p419
aVcane
p420
aVradiator
p421
aVyes
p422
aVstrap
p423
aVkitesurfing
p424
aVdead
p425
aV10 years
p426
aVham and cheese
p427
aVskateboarder
p428
aVmagazine
p429
aVafternoon
p430
aVselfie
p431
aVdown
p432
aVtennis
p433
aVbatman
p434
aVlanding
p435
aVmuffins
p436
aVhandicap
p437
aVjacket
p438
aVriding bikes
p439
aVfather
p440
aV0
aVround
p441
aVfrosting
p442
aVbox
p443
aVboy
p444
aVbow
p445
aVbob
p446
aVsun hat
p447
aVpizza box
p448
aVman in middle
p449
aVtablecloth
p450
aVbasket
p451
aVshoes
p452
aVpolice
p453
aVmonitor
p454
aVlunch
p455
aVman made
p456
aVelephants
p457
aVfirst
p458
aV4 ft
p459
aVpastry
p460
aVjungle
p461
aV200
p462
aVsheets
p463
aVnot high
p464
aVmorning
p465
aVseat
p466
aVboundaries
p467
aV2:00
p468
aVtour
p469
aVpacifier
p470
aVsquash
p471
aVlaptops
p472
aVriding motorcycle
p473
aVsleeping
p474
aVwestjet
p475
aVoutdoor
p476
aVbow tie
p477
aVwii remotes
p478
aVgenetics
p479
aVknife
p480
aVpockets
p481
aVharness
p482
aVtraffic lights
p483
aVon water
p484
aVon road
p485
aVsitting
p486
aVwashing
p487
aVafrica
p488
aVdachshund
p489
aVbrushing her teeth
p490
aVnumbers
p491
aVcomfort
p492
aVcoming
p493
aVtoiletries
p494
aVdragon
p495
aVfaucet
p496
aVtop
p497
aVfork and spoon
p498
aVurban
p499
aV1:05
p500
aV1:00
p501
aVcouch
p502
aVeggs
p503
aV11:10
p504
aV11:15
p505
aVopaque
p506
aVmotorcycles
p507
aVdoubles
p508
aVmouthwash
p509
aVmailbox
p510
aVrestaurant
p511
aVbaseball bat
p512
aVvery big
p513
aVpeppers
p514
aVkodak
p515
aVwindmill
p516
aVstripe
p517
aVshape
p518
aVcut
p519
aVcup
p520
aVeaster
p521
aVcandle
p522
aV5:18
p523
aV5:15
p524
aV5:10
p525
aVtennis court
p526
aVplanes
p527
aVcurb
p528
aVcafe
p529
aVmayonnaise
p530
aVcooked
p531
aVpink and blue
p532
aVdessert
p533
aVlegos
p534
aVpacking
p535
aVorange and yellow
p536
aVketchup and mustard
p537
aVmustache
p538
aVhorns
p539
aVpotato salad
p540
aVbackwards
p541
aVyard
p542
aVskateboard
p543
aVyarn
p544
aVbulldog
p545
aVnobody
p546
aVbeets
p547
aVmercedes
p548
aVdaffodils
p549
aVracket
p550
aVto dry
p551
aVhammock
p552
aVstatues
p553
aVdenim
p554
aVflashlight
p555
aVcarnations
p556
aVhair
p557
aVshaking hands
p558
aVsocks
p559
aVfemale
p560
aVlooking at camera
p561
aVcoke
p562
aVflip
p563
aVcircus
p564
aVdresser
p565
aVcocker spaniel
p566
aVschool
p567
aVfeathers
p568
aVumbrella
p569
aVluggage
p570
aVleaves
p571
aVto left
p572
aV2:05
p573
aV400
p574
aV3:15
p575
aV3:10
p576
aVdetroit
p577
aVpublic market center
p578
aVsmall
p579
aVsoda
p580
aVfire truck
p581
aV700
p582
aVschnauzer
p583
aVsparrow
p584
aVfork and knife
p585
aVwedding
p586
aVbeads
p587
aVmud
p588
aVmug
p589
aVfinger
p590
aVherding
p591
aVnews
p592
aVbehind fence
p593
aVelm
p594
aVt shirt and jeans
p595
aVconference
p596
aVmonkey
p597
aVn
aVnike
p598
aVocean
p599
aVcherry
p600
aVteddy bears
p601
aVwii controllers
p602
aVdescending
p603
aVcoffee maker
p604
aVkhaki
p605
aVdinner
p606
aVtoothpicks
p607
aVfern
p608
aVbats
p609
aVsunlight
p610
aVtowing
p611
aVkiting
p612
aVsetting
p613
aVpapers
p614
aVpicture
p615
aVfootball
p616
aVlong time
p617
aVposts
p618
aVcloudy
p619
aVpork
p620
aVtuxedo
p621
aV1 in middle
p622
aVpickle
p623
aVnursing
p624
aVblack and brown
p625
aVsailboat
p626
aV2 years
p627
aVplatform
p628
aVfarmer
p629
aVcutting hair
p630
aVcatching
p631
aV100 year party ct
p632
aVturn
p633
aVmen's
p634
aVbatting
p635
aVsurf
p636
aVequestrian
p637
aVwii remote
p638
aV1 hour
p639
aVguitar
p640
aVturkey
p641
aVdirection
p642
aVpilot
p643
aVcase
p644
aVstatue
p645
aVpeanut butter
p646
aVtowel
p647
aVtower
p648
aVcompetition
p649
aVmoon
p650
aVburton
p651
aVon floor
p652
aVbarber shop
p653
aVflats
p654
aVgrass
p655
aVroses
p656
aVpillows
p657
aVqantas
p658
aVapartment
p659
aVdairy
p660
aVcrest
p661
aVsub
p662
aVsun
p663
aVsuv
p664
aVchristian
p665
aVdonuts
p666
aVhorses
p667
aVflat
p668
aVflag
p669
aVlighting
p670
aVshort
p671
aVshore
p672
aV10:50
p673
aV10:55
p674
aVsoccer
p675
aVclouds
p676
aValcohol
p677
aVhill
p678
aVsnowboarding
p679
aVurinal
p680
aVstork
p681
aVstorm
p682
aVstore
p683
aVsurfboard
p684
aVking
p685
aV1 on right
p686
aV8:35
p687
aVskyscrapers
p688
aVwilson
p689
aVelectric
p690
aVpatterned
p691
aVnational express
p692
aVopponent
p693
aVtriangle
p694
aV9:05
p695
aVsweet
p696
aVshirt
p697
aV9
aVcement
p698
aVprince
p699
aVping pong
p700
aVphone
p701
aVplaying video game
p702
aVsports
p703
aVsepia
p704
aVrainbow
p705
aVpeach
p706
aVpeace
p707
aVparking meter
p708
aVwindy
p709
aVvery high
p710
aVgray and white
p711
aVcompaq
p712
aVkangaroo
p713
aVleaving
p714
aVcell phones
p715
aVduck
p716
aVfrog
p717
aVparrots
p718
aVhammer time
p719
aVwrist
p720
aVstuffed animal
p721
aVon bed
p722
aVout
p723
aVbrown and white
p724
aVorganic
p725
aVg
aVtennis racket
p726
aVumbrellas
p727
aVunknown
p728
aVblue and black
p729
aVdreadlocks
p730
aVclip
p731
aVpipe
p732
aVcubs
p733
aVmarried
p734
aVidentification
p735
aVon his face
p736
aVstucco
p737
aVunclear
p738
aVmotorbike
p739
aVrooster
p740
aVcamel
p741
aVname
p742
aVcan't tell
p743
aVin middle
p744
aVpasta
p745
aVspiral
p746
aVcoffee cup
p747
aVcake
p748
aVmaroon
p749
aVspace shuttle
p750
aVleft and right
p751
aVyellow
p752
aVelectricity
p753
aVdeli
p754
aVdell
p755
aVeagle
p756
aVbehind woman
p757
aVmilitary
p758
aVon dresser
p759
aVmachine
p760
aVgaming
p761
aVin sky
p762
aVwing
p763
aVwind
p764
aVwine
p765
aVbaseball uniform
p766
aVnew orleans
p767
aVsilver
p768
aVbowls
p769
aVlego
p770
aVguitar hero
p771
aVlegs
p772
aVman on left
p773
aVperson
p774
aVin stands
p775
aVzebras
p776
aVvictoria
p777
aVlittle girl
p778
aVrecliner
p779
aVmenu
p780
aVfair
p781
aVlooking out window
p782
aVpirate
p783
aVvaio
p784
aVchrome
p785
aVlife
p786
aVlift
p787
aVchild
p788
aVchili
p789
aVabove toilet
p790
aVgray and red
p791
aVbabies
p792
aVbird feeder
p793
aVroast beef
p794
aVmain street
p795
aVtrain car
p796
aVpavement
p797
aVsteps
p798
aVpeople
p799
aVfox
p800
aVfog
p801
aVhappy
p802
aVtigers
p803
aVlays
p804
aV4 way
p805
aVpeacock
p806
aVvenice
p807
aVollie
p808
aVsubway
p809
aVteal
p810
aVteam
p811
aVcurrent
p812
aVno smoking
p813
aVlove
p814
aVsunbathing
p815
aVheineken
p816
aVwinter
p817
aVelephant
p818
aVwheelchair
p819
aV8 feet
p820
aVcolorado
p821
aVbagels
p822
aVpolka dot
p823
aV2 men
p824
aVgatorade
p825
aVarriving
p826
aVbowling
p827
aVnot there
p828
aVcontainer
p829
aVdodgers
p830
aV12:10
p831
aV12:15
p832
aVostrich
p833
aV2
aVtyping
p834
aVgive way
p835
aVhockey
p836
aVrobe
p837
aVbridge
p838
aVbarbed wire
p839
aVcantaloupe
p840
aVjets
p841
aV1:35
p842
aV1:30
p843
aVtwin
p844
aVteddy
p845
aVtoothpaste
p846
aVvacation
p847
aVorange juice
p848
aVno dog
p849
aVskier
p850
aVorange and black
p851
aVclay
p852
aVtying tie
p853
aVtube
p854
aV6:40
p855
aV6:45
p856
aVtake off
p857
aVbundt
p858
aVaa
p859
aVgone
p860
aVam
p861
aV5:45
p862
aV5:40
p863
aVfor sale
p864
aVdinosaur
p865
aV2 hours
p866
aVdead end
p867
aVbottom left
p868
aVusing laptop
p869
aVyoung
p870
aVindoors
p871
aVparking garage
p872
aVpitching
p873
aV4:00
p874
aV4:05
p875
aVinformation
p876
aVnot very
p877
aVno parking
p878
aVcountryside
p879
aVbackyard
p880
aVhauling
p881
aVpans
p882
aVrottweiler
p883
aVmexican
p884
aVswimsuit
p885
aVharbor
p886
aVair canada
p887
aVhighway
p888
aVporcelain
p889
aVhydrant
p890
aVw
aV2 people
p891
aVcartoon
p892
aVobama
p893
aVrefrigerators
p894
aVplay
p895
aVcover
p896
aVsleeve
p897
aVcrosswalk
p898
aV1 year
p899
aVwhole
p900
aVlanyard
p901
aVhitting
p902
aVfire
p903
aVfarmers market
p904
aVwhite and brown
p905
aVsesame
p906
aVknee pads
p907
aVaudi
p908
aVcucumber
p909
aVstainless steel
p910
aVbaseball glove
p911
aVsanta
p912
aVcobblestone
p913
aVtaking picture
p914
aVgarden
p915
aVindex
p916
aVtwins
p917
aVbird
p918
aVleg
p919
aVbanana split
p920
aVmariners
p921
aVstanding
p922
aVcasserole
p923
aVhigh
p924
aVanimal
p925
aVstop
p926
aVwallpaper
p927
aVnotebook
p928
aVlogitech
p929
aVnext to toilet
p930
aVcatch frisbee
p931
aVjesus
p932
aVowner
p933
aVsteel
p934
aV2:55
p935
aV2:50
p936
aVsoap
p937
aVbiker
p938
aVbikes
p939
aVnot possible
p940
aVamtrak
p941
aVemail
p942
aVblack and red
p943
aVcd
p944
aVripe
p945
aVnorth face
p946
aVball
p947
aVdusk
p948
aVbald
p949
aVoveralls
p950
aVbananas
p951
aV7:55
p952
aVfire hydrant
p953
aVconference room
p954
aVresidential
p955
aVto see
p956
aVonion
p957
aVtransport
p958
aVfedex
p959
aVseeds
p960
aVbud light
p961
aVdelivery
p962
aVconstruction
p963
aVsmooth
p964
aVvolvo
p965
aVblack and yellow
p966
aVmushroom
p967
aVfootprints
p968
aVrural
p969
aVbuoy
p970
aVall
p971
aVwreath
p972
aVblanket
p973
aVwomen's
p974
aVhome plate
p975
aVlizard
p976
aVabove
p977
aVsmoke
p978
aVindians
p979
aVboats
p980
aVgreen
p981
aVgiraffe
p982
aVon tray
p983
aVlying down
p984
aVin vase
p985
aVsliding
p986
aVpooping
p987
aVtrain tracks
p988
aVgoalie
p989
aVdc
p990
aVlemonade
p991
aVcow
p992
aVplates
p993
aVrail
p994
aVrain
p995
aVcondiments
p996
aVhills
p997
aVhilly
p998
aVski pole
p999
aVgrizzly
p1000
aVsecurity
p1001
aVantique
p1002
aVtie dye
p1003
aVpurple
p1004
aVmantle
p1005
aV20
p1006
aVtoothpick
p1007
aVgrape
p1008
aVp
aVspace needle
p1009
aVhiking
p1010
aVsnowboard
p1011
aVplaying tennis
p1012
aVcandles
p1013
aVhaircut
p1014
aVcold
p1015
aVbirds
p1016
aVwindow
p1017
aVhalf
p1018
aVgoat
p1019
aVgoose
p1020
aVfires
p1021
aVspace
p1022
aVsize
p1023
aVsheep
p1024
aVsheet
p1025
aVfriend
p1026
aV10:08
p1027
aV10:05
p1028
aV10:00
p1029
aVfork
p1030
aVangel
p1031
aVbreakfast
p1032
aVsunset
p1033
aVcat food
p1034
aVground
p1035
aVremotes
p1036
aVred and yellow
p1037
aVconcert
p1038
aVacross street
p1039
aVforehand
p1040
aVmattress
p1041
aVpigtails
p1042
aVwhite and orange
p1043
aV3:45
p1044
aVheels
p1045
aVfruit
p1046
aVburrito
p1047
aVpitbull
p1048
aVin back
p1049
aVdesktop
p1050
aV9:35
p1051
aVred and silver
p1052
aV9:30
p1053
aVreal
p1054
aVlady
p1055
aVrear
p1056
aVfell
p1057
aVheater
p1058
aVsmiley face
p1059
aVbucket
p1060
aVantenna
p1061
aVstorage
p1062
aVpolar
p1063
aVat table
p1064
aVpointing
p1065
aVvery
p1066
aVbathtub
p1067
aVcolored
p1068
aVahead
p1069
aVsoldier
p1070
aVbroke
p1071
aVcruise ship
p1072
aVhitting ball
p1073
aVfood
p1074
aVtrailer
p1075
aVwhite and gray
p1076
aVdirt
p1077
aVpug
p1078
aVbase
p1079
aVbnsf
p1080
aVcoleslaw
p1081
aVdog food
p1082
aVeaten
p1083
aVcopyright
p1084
aVtrees
p1085
aVon horse
p1086
aVgloves
p1087
aVswitzerland
p1088
aVcorgi
p1089
aVbadminton
p1090
aVsingle engine
p1091
aVused
p1092
aVgrind
p1093
aVfly kite
p1094
aVqatar
p1095
aVgolden retriever
p1096
aVrun
p1097
aVrug
p1098
aVcigarette
p1099
aV72
p1100
aVhimself
p1101
aVno sign
p1102
aVrussian
p1103
aVadults
p1104
aVcotton
p1105
aV1 way
p1106
aVspectators
p1107
aVge
p1108
aVgo
p1109
aVasparagus
p1110
aVcupcakes
p1111
aVright 1
p1112
aVcloud
p1113
aVplaying frisbee
p1114
aVmitsubishi
p1115
aVdonut
p1116
aVhalloween
p1117
aVbathing
p1118
aVhigh chair
p1119
aVherself
p1120
aVphotograph
p1121
aVbed
p1122
aVdefense
p1123
aVcarrots
p1124
aVasics
p1125
aVsprinkles
p1126
aVvisor
p1127
aVblack and gray
p1128
aVparked
p1129
aV120
p1130
aVneither
p1131
aVflickr
p1132
aVtank top
p1133
aVreflection
p1134
aVi
aVgerman shepherd
p1135
aVtaller
p1136
aVcandy
p1137
aVcontrollers
p1138
aVmushrooms
p1139
aVmississippi
p1140
aVsuzuki
p1141
aVbackhand
p1142
aVhp
p1143
aVcarriage
p1144
aVon elephant
p1145
aVbalcony
p1146
aVnintendo
p1147
aVblue and pink
p1148
aVsnowflakes
p1149
aVups
p1150
aVclothes
p1151
aVnew
p1152
aVnet
p1153
aVnever
p1154
aVremote control
p1155
aVcardboard
p1156
aVkeyboard
p1157
aVdawn
p1158
aVenclosure
p1159
aVsurprise
p1160
aVbehind clouds
p1161
aVhood
p1162
aVgirls
p1163
aVblue jay
p1164
aVno left turn
p1165
aVleft side
p1166
aVplaystation
p1167
aVlaying down
p1168
aVtablet
p1169
aVprotection
p1170
aVorange and blue
p1171
aVtank
p1172
aVin
p1173
aVbottles
p1174
aV50 feet
p1175
aVsunflower
p1176
aVdining room
p1177
aVdaytime
p1178
aVwhite and yellow
p1179
aVyellow and white
p1180
aVford
p1181
aVsticks
p1182
aVclassic
p1183
aVship
p1184
aVdigital
p1185
aVred sox
p1186
aVskis
p1187
aVstring
p1188
aVmagnet
p1189
aVtransportation
p1190
aVeaston
p1191
aVwallet
p1192
aVpink
p1193
aVrays
p1194
aVto catch frisbee
p1195
aVpine
p1196
aVtile
p1197
aVbiplane
p1198
aVcute
p1199
aVstability
p1200
aVmarshmallows
p1201
aVsmoothie
p1202
aVtexas
p1203
aVtropicana
p1204
aVgreen and brown
p1205
aVmonster
p1206
aVlong
p1207
aVvictorian
p1208
aVknives
p1209
aVtractor
p1210
aVcoconut
p1211
aVwhite and green
p1212
aVshirt and tie
p1213
aVnapkin
p1214
aVtennis player
p1215
aVblack
p1216
aVcalendar
p1217
aVpuma
p1218
aVbehind
p1219
aVreading
p1220
aVtub
p1221
aVno plate
p1222
aVrugby
p1223
aVbmw
p1224
aVcaptivity
p1225
aVdrain
p1226
aVkrispy kreme
p1227
aV4
aVblankets
p1228
aVtoilet brush
p1229
aVnose
p1230
aVdoll
p1231
aVdole
p1232
aVmustard
p1233
aVbroken
p1234
aVdrums
p1235
aVsoccer field
p1236
aVisland
p1237
aVstrawberry
p1238
aVsuitcase
p1239
aVwhite house
p1240
aVfield
p1241
aVstudents
p1242
aVremote
p1243
aVbuffalo
p1244
aVfall
p1245
aVgrilled cheese
p1246
aVmother and child
p1247
aVstool
p1248
aVsurfer
p1249
aVfirst base
p1250
aVairport
p1251
aVtown
p1252
aVmiddle
p1253
aVmarble
p1254
aVblood
p1255
aValps
p1256
aVcoat
p1257
aVpaw
p1258
aVcoal
p1259
aVdough
p1260
aVpens
p1261
aVmales
p1262
aVclams
p1263
aVpigeon
p1264
aVpaint
p1265
aVsteering wheel
p1266
aVb
aVcatch ball
p1267
aVtraveling
p1268
aVthomas
p1269
aVplay tennis
p1270
aVbatter
p1271
aVthrow frisbee
p1272
aVaround neck
p1273
aV1 on left
p1274
aVon rock
p1275
aVireland
p1276
aVto hit ball
p1277
aVvery old
p1278
aVfur
p1279
aVfun
p1280
aVsandals
p1281
aVbrushing hair
p1282
aV12:40
p1283
aV12:45
p1284
aVglove
p1285
aVback
p1286
aVwater skis
p1287
aVpen
p1288
aVsneakers
p1289
aVpee
p1290
aVin basket
p1291
aVbeans
p1292
aVwine bottle
p1293
aVjockey
p1294
aVhomemade
p1295
aVon table
p1296
aVforward
p1297
aVbored
p1298
aVlifeguard
p1299
aVcan't see it
p1300
aVstroller
p1301
aVgravel
p1302
aV11:30
p1303
aV11:35
p1304
aVchoppy
p1305
aVcrow
p1306
aVfisheye
p1307
aVholding phone
p1308
aVmouse pad
p1309
aVbaskets
p1310
aVriding horse
p1311
aVoval
p1312
aVbuildings
p1313
aVroundabout
p1314
aVmeter
p1315
aVlicense plate
p1316
aVla
p1317
aVdr pepper
p1318
aVelectronics
p1319
aVred
p1320
aVno number
p1321
aVusing computer
p1322
aVupside down
p1323
aVpots
p1324
aVshampoo
p1325
aVon sink
p1326
aV4:50
p1327
aV4:55
p1328
aVin snow
p1329
aVairplanes
p1330
aVarabic
p1331
aVstained glass
p1332
aVcenter
p1333
aVcoaster
p1334
aVbench
p1335
aVcitizen
p1336
aVshelter
p1337
aV1980
p1338
aVfirefighter
p1339
aVbus
p1340
aVbun
p1341
aVbookshelf
p1342
aVshearing
p1343
aVgranite
p1344
aVgas station
p1345
aVbaker
p1346
aVbaked
p1347
aVhats
p1348
aVtrolley
p1349
aVdog show
p1350
aVstreet
p1351
aVdisney
p1352
aVhundreds
p1353
aVme
p1354
aVblack and blue
p1355
aVend
p1356
aVcharging
p1357
aVin bowl
p1358
aVcroissant
p1359
aVplaster
p1360
aVmercedes benz
p1361
aVceiling
p1362
aVserve
p1363
aVwestern
p1364
aVburger
p1365
aVfrench fries
p1366
aVdonkey
p1367
aVfashion
p1368
aVtalking
p1369
aVsunglasses
p1370
aVliquor
p1371
aVthey aren't
p1372
aVtowels
p1373
aVbeef
p1374
aVbeer
p1375
aVcarnation
p1376
aVseveral
p1377
aVsoccer ball
p1378
aV4th of july
p1379
aV2:25
p1380
aV2:20
p1381
aVice
p1382
aV3:30
p1383
aVchristmas
p1384
aVcord
p1385
aVcorn
p1386
aVhead
p1387
aVmedium
p1388
aVheat
p1389
aVno
p1390
aVny
p1391
aVevergreen
p1392
aVbackward
p1393
aVdouble decker
p1394
aVon wall
p1395
aVheinz
p1396
aVcinnamon
p1397
aVshark
p1398
aV1st
p1399
aVhamburger
p1400
aVchocolate
p1401
aVtaking photo
p1402
aV7:00
p1403
aV7:05
p1404
aVzebra
p1405
aV1000
p1406
aVbudweiser
p1407
aVmillions
p1408
aVr
aVbaseball game
p1409
aVdining
p1410
aVon rack
p1411
aVfar right
p1412
aVblue and yellow
p1413
aVcolor
p1414
aVpole
p1415
aVpolo
p1416
aVcleats
p1417
aVspiderman
p1418
aVgood
p1419
aVhelmet
p1420
aVon skateboard
p1421
aVmodel
p1422
aVpacific
p1423
aVeiffel tower
p1424
aVparmesan
p1425
aVbroccoli and carrots
p1426
aVwoman
p1427
aVgrandfather
p1428
aVtea
p1429
aVdesign
p1430
aVdirections
p1431
aVbehind bus
p1432
aVdump truck
p1433
aVairplane
p1434
aVswim trunks
p1435
aVhandlebars
p1436
aVathletics
p1437
aVscratching
p1438
aVcaucasian
p1439
aVwisconsin
p1440
aVcollar
p1441
aVyellow and orange
p1442
aVstones
p1443
aVasphalt
p1444
aVrecently
p1445
aVbronze
p1446
aVpm
p1447
aVkorean air
p1448
aVdancing
p1449
aVpigeons
p1450
aVloading
p1451
aVmiami
p1452
aVdunkin donuts
p1453
aVfast
p1454
aVmountains
p1455
aVfries
p1456
aVfried
p1457
aVwarning
p1458
aVlemon
p1459
aVtrain station
p1460
aVpottery
p1461
aVnasa
p1462
aVdrinking
p1463
aVibm
p1464
aVair france
p1465
aVcheckers
p1466
aVoars
p1467
aVreds
p1468
aVusa
p1469
aVrackets
p1470
aVfew
p1471
aVst patrick's day
p1472
aVcarrot
p1473
aVunited
p1474
aVwetsuits
p1475
aVferris wheel
p1476
aVorange
p1477
aVon runway
p1478
aVfood processor
p1479
aVon motorcycle
p1480
aVagainst wall
p1481
aVit isn't
p1482
aVshallow
p1483
aVyellow and blue
p1484
aVfeeding giraffe
p1485
aVsunflowers
p1486
aVbraid
p1487
aV12 feet
p1488
aVmeatballs
p1489
aVblue and orange
p1490
aVout of focus
p1491
aVegg salad
p1492
aVpainted
p1493
aVcross
p1494
aVnathan's
p1495
aVyellow and green
p1496
aVfighting
p1497
aVbending
p1498
aVfishing
p1499
aVclydesdale
p1500
aVgreen and yellow
p1501
aVearrings
p1502
aVsaddle
p1503
aVowl
p1504
aVstop sign
p1505
aVcontinental
p1506
aVboardwalk
p1507
aVon toilet
p1508
aVraining
p1509
aVk
aVriding horses
p1510
aVstrike
p1511
aVdecoration
p1512
aVplunger
p1513
aVgreen beans
p1514
aVcaution
p1515
aV6 inches
p1516
aVitaly
p1517
aVcigarettes
p1518
aVrv
p1519
aVelmo
p1520
aVon pole
p1521
aVbasketball
p1522
aVbus driver
p1523
aVtree
p1524
aVshower
p1525
aVbleachers
p1526
aVy
aVcamo
p1527
aVhonda
p1528
aVrelish
p1529
aVtv stand
p1530
aVyogurt
p1531
aVcountry
p1532
aVgrazing
p1533
aVstadium
p1534
aVdots
p1535
aVbanana peel
p1536
aVnot sure
p1537
aVspeed limit
p1538
aVamerican flag
p1539
aVmagazines
p1540
aVinside
p1541
aV150
p1542
aVpink and white
p1543
aVskate
p1544
aVskiing
p1545
aVgerman
p1546
aVbicycles
p1547
aVicing
p1548
aVcollage
p1549
aVsign
p1550
aVpassengers
p1551
aVpetting horse
p1552
aVfake
p1553
aVwicker
p1554
aVangry
p1555
aVpumpkin
p1556
aVunder sink
p1557
aVnatural
p1558
aVpizza cutter
p1559
aVwater skiing
p1560
aVsquare
p1561
aVair force
p1562
aVsquares
p1563
aVcarrot cake
p1564
aVopen
p1565
aVcity
p1566
aVstuffed
p1567
aVcoats
p1568
aVdrywall
p1569
aVrussia
p1570
aVzebra and giraffe
p1571
aVkenmore
p1572
aVgreyhound
p1573
aVpalm tree
p1574
aVdrive
p1575
aVpeaches
p1576
aVbright
p1577
aVstanding still
p1578
aVcheckered
p1579
aVvertical
p1580
aVscreen
p1581
aVmany
p1582
aVmane
p1583
aVto right
p1584
aVpurse
p1585
aV3 inches
p1586
aVivy
p1587
aVhummingbird
p1588
aVmohawk
p1589
aVcilantro
p1590
aVdark
p1591
aVlobster
p1592
aVpickles
p1593
aVcross country
p1594
aVpedestal
p1595
aVthrow ball
p1596
aVmask
p1597
aVtv
p1598
aVtail
p1599
aVlarge
p1600
aVpearl
p1601
aV6
aVfrisbees
p1602
aV1:10
p1603
aV1:15
p1604
aVbidet
p1605
aV11:00
p1606
aV11:05
p1607
aVred and gray
p1608
aVwaves
p1609
aVvolleyball
p1610
aVring finger
p1611
aVcastle
p1612
aVfeeder
p1613
aVjet
p1614
aVmural
p1615
aVpedestrian
p1616
aVmicrosoft
p1617
aVgolf
p1618
aVgold
p1619
aVfreight
p1620
aVgrocery store
p1621
aVtaking off
p1622
aV5:25
p1623
aVpickup
p1624
aVon shore
p1625
aVchurch
p1626
aVtop hat
p1627
aVresting
p1628
aVstraight
p1629
aVleaning
p1630
aV870
p1631
aVconcrete
p1632
aVsydney
p1633
aVaustralia
p1634
aVskateboards
p1635
aVformal
p1636
aVd
aVspring
p1637
aVpalm
p1638
aVcurious
p1639
aVsprint
p1640
aVsuit
p1641
aVposter
p1642
aVup
p1643
aVus
p1644
aVuk
p1645
aVpedestrians
p1646
aVsesame seeds
p1647
aVfresh
p1648
aVrace
p1649
aVrack
p1650
aVspaghetti
p1651
aVsauerkraut
p1652
aVwindow sill
p1653
aVblack and pink
p1654
aVhalf full
p1655
aVred light
p1656
aVcabbage
p1657
aVhawk
p1658
aVtomato
p1659
aVcounter
p1660
aVantelope
p1661
aV2015
p1662
aV2016
p1663
aV2011
p1664
aV2010
p1665
aV2013
p1666
aV2012
p1667
aVfeta
p1668
aVdock
p1669
aVpowdered sugar
p1670
aVsniffing
p1671
aVmets
p1672
aVtents
p1673
aVcuddling
p1674
aVshorts
p1675
aV7 eleven
p1676
aVlily
p1677
aVface
p1678
aVpainting
p1679
aVbaseball player
p1680
aVbuttons
p1681
aVhandle
p1682
aVminnie mouse
p1683
aVlg
p1684
aVparking meters
p1685
aVrunway
p1686
aVwinnie pooh
p1687
aVdrink
p1688
aV350
p1689
aVdrawing
p1690
aVchairs
p1691
aViris
p1692
aVskatepark
p1693
aVfrisbee
p1694
aVcannot tell
p1695
aVghost
p1696
aVfruit salad
p1697
aVbike
p1698
aVoriental
p1699
aVtriumph
p1700
aVchef
p1701
aVuniforms
p1702
aValligator
p1703
aVstreet name
p1704
aVveggies
p1705
aVhappy birthday
p1706
aVbeach
p1707
aVpizza
p1708
aVafter
p1709
aVin air
p1710
aVsouth
p1711
aVsalon
p1712
aVwalgreens
p1713
aVjapan
p1714
aVavocado
p1715
aVsavory
p1716
aVband
p1717
aVbank
p1718
aVrocky
p1719
aVrocks
p1720
aVcrocs
p1721
aVlogs
p1722
aVlogo
p1723
aVthick
p1724
aVcardinals
p1725
aVin corner
p1726
aVon bench
p1727
aVwarmth
p1728
aVair
p1729
aVbunny
p1730
aVhispanic
p1731
aVphoto
p1732
aVboard
p1733
aVboxes
p1734
aVboxer
p1735
aVpost
p1736
aVoctopus
p1737
aVwax
p1738
aVwar
p1739
aVconverse
p1740
aVwelcome
p1741
aVsouthwest
p1742
aVbanana bread
p1743
aVkitten
p1744
aVkitchenaid
p1745
aVcargo
p1746
aVwildebeest
p1747
aVuniform
p1748
aVfuton
p1749
aVcar
p1750
aVcap
p1751
aVcat
p1752
aVcan
p1753
aVheart
p1754
aVclothing
p1755
aVfreezer
p1756
aVstripes
p1757
aVducati
p1758
aVstate farm
p1759
aVtoilet paper
p1760
aVblue team
p1761
aVwillow
p1762
aV1950
p1763
aVparmesan cheese
p1764
aVlooking
p1765
aVcars
p1766
aVcart
p1767
aVbritish
p1768
aVtelevision
p1769
aVsilver and black
p1770
aVphotoshop
p1771
aVonly
p1772
aVasleep
p1773
aVcolgate
p1774
aVholding baby
p1775
aVnikon
p1776
aVtrick
p1777
aVit's raining
p1778
aVlistening
p1779
aVt
aVwheat
p1780
aVvolkswagen
p1781
aVthrow
p1782
aVearring
p1783
aVlog
p1784
aVlow
p1785
aVlot
p1786
aVpaddle boarding
p1787
aVchristmas tree
p1788
aVpear
p1789
aVpeas
p1790
aV10:25
p1791
aV10:20
p1792
aVskeleton
p1793
aVpicnic table
p1794
aVwii controller
p1795
aVcookies
p1796
aVvegetarian
p1797
aVfamily
p1798
aVmickey mouse
p1799
aVtoys
p1800
aVcows
p1801
aVcrafts
p1802
aVhello kitty
p1803
aVsoon
p1804
aVcupcake
p1805
aVpeeing
p1806
aVstopped
p1807
aV9:12
p1808
aV9:15
p1809
aVdoughnuts
p1810
aVmaple
p1811
aVbanana
p1812
aVselling
p1813
aVcamouflage
p1814
aVbridle
p1815
aVmack
p1816
aV2.00
p1817
aVscooter
p1818
aVchiquita
p1819
aVchimney
p1820
aVcatcher
p1821
aVpedestrian crossing
p1822
aVrope
p1823
aVbikini
p1824
aVbiking
p1825
aVfedora
p1826
aVolder
p1827
aVjp morgan
p1828
aVgame
p1829
aVbrown and black
p1830
aVstew
p1831
aVblack white
p1832
aVfor balance
p1833
aVdog and cat
p1834
aVto eat
p1835
aVberries
p1836
aVdesserts
p1837
aVsuits
p1838
aVcluttered
p1839
aVjet ski
p1840
aVboots
p1841
aVsausage
p1842
aVcommercial
p1843
aVwii
p1844
aVcosmo
p1845
aVbeijing
p1846
aVskirt
p1847
aVchevrolet
p1848
aVbakery
p1849
aVamerican airlines
p1850
aVnokia
p1851
aVhands
p1852
aVon plane
p1853
aVin motion
p1854
aVlaundry
p1855
aVwatermark
p1856
aVcranes
p1857
aVsewing
p1858
aVcamper
p1859
aVtire
p1860
aVgray
p1861
aVspotted
p1862
aVj
aVhumans
p1863
aVdriveway
p1864
aVon beach
p1865
aVbest buy
p1866
aVswinging
p1867
aVno water
p1868
aVgoatee
p1869
aVno shirt
p1870
aVtell time
p1871
aVbeige
p1872
aVtongue
p1873
aVpastries
p1874
aVwashington
p1875
aVmountain dew
p1876
aVear
p1877
aVeat
p1878
aVtissues
p1879
aVutensils
p1880
aVfriends
p1881
aVasian
p1882
aVflip flops
p1883
aVvase
p1884
aVbaking
p1885
aVcan't see
p1886
aVhazy
p1887
aV100
p1888
aV101
p1889
aV106
p1890
aVdry
p1891
aVwarm
p1892
aVadult
p1893
aVchihuahua
p1894
aVmeat
p1895
aVside
p1896
aVbone
p1897
aVunder table
p1898
aVnavy
p1899
aVcomforter
p1900
aVhollywood
p1901
aVgym
p1902
aVboarding
p1903
aVdistance
p1904
aVmodern
p1905
aVmint
p1906
aVregular
p1907
aVpolar bear
p1908
aVm
aVdog
p1909
aVunder tree
p1910
aVtiles
p1911
aVtiled
p1912
aVgrilled
p1913
aVfluffy
p1914
aVwetsuit
p1915
aVon top
p1916
aVwebsite
p1917
aVmouse
p1918
aVgarlic
p1919
aVkia
p1920
aVkid
p1921
aVbutter
p1922
aVhuman
p1923
aVin front
p1924
aVflying kite
p1925
aVwings
p1926
aVshoulder
p1927
aVhe's not
p1928
aVover easy
p1929
aVpotato
p1930
aVteeth
p1931
aVcessna
p1932
aVwalking
p1933
aV12:28
p1934
aV12:25
p1935
aV12:20
p1936
aVkite string
p1937
aVgazebo
p1938
aVmakeup
p1939
aVeasyjet
p1940
aVkayak
p1941
aVlondon
p1942
aVconfused
p1943
aVmap
p1944
aVmat
p1945
aVmac
p1946
aVman
p1947
aVneck
p1948
aVtall
p1949
aVpitch
p1950
aVbunt
p1951
aVsafari
p1952
aVarrows
p1953
aVrock
p1954
aVcanada
p1955
aV3rd
p1956
aVsideways
p1957
aV11:55
p1958
aV11:50
p1959
aVheadband
p1960
aVheadphones
p1961
aVcheese
p1962
aVcrib
p1963
aVlittle
p1964
aVeyes
p1965
aVbutt
p1966
aV11
p1967
aV10
p1968
aV13
p1969
aV15
p1970
aV14
p1971
aVsailing
p1972
aV16
p1973
aV19
p1974
aV18
p1975
aVspeakers
p1976
aVcheesecake
p1977
aVshop
p1978
aVshow
p1979
aVshoe
p1980
aVcorner
p1981
aVpink and black
p1982
aV5 years
p1983
aVoven
p1984
aVcamping
p1985
aVdiamonds
p1986
aV6:35
p1987
aV6:30
p1988
aVglaze
p1989
aV5:50
p1990
aV5:55
p1991
aVoil
p1992
aVbillabong
p1993
aVclimbing
p1994
aVflip phone
p1995
aVmoney
p1996
aVshingles
p1997
aVlinoleum
p1998
aVmixer
p1999
aVshells
p2000
aV4:35
p2001
aV4:30
p2002
aVmagnets
p2003
aVdeer
p2004
aVdeep
p2005
aVhusky
p2006
aVworms
p2007
aVpublic
p2008
aVbowtie
p2009
aVbible
p2010
aVtexting
p2011
aVraft
p2012
aVlemons
p2013
aVon boat
p2014
aVwavy
p2015
aVplaying
p2016
aV24
p2017
aV25
p2018
aV26
p2019
aV27
p2020
aV21
p2021
aV22
p2022
aV23
p2023
aV28
p2024
aV29
p2025
aVon napkin
p2026
aVhard
p2027
aVpower lines
p2028
aVbrewers
p2029
aVprint
p2030
aVcopper
p2031
aVwine glasses
p2032
aVleash
p2033
aVpare
p2034
aVpark
p2035
aVzucchini
p2036
aVtrains
p2037
aVsalmon
p2038
aVmoss
p2039
aVred and black
p2040
aVdirt bike
p2041
aV8
aVwhite and pink
p2042
aVsilk
p2043
aVlion
p2044
aVfans
p2045
aVchampagne
p2046
aVpepsi
p2047
aVvans
p2048
aVgas
p2049
aVswiss
p2050
aVsemi
p2051
aVbamboo
p2052
aVred white blue
p2053
aVmirror
p2054
aVat camera
p2055
aVcameraman
p2056
aV300
p2057
aV5 feet
p2058
aVboys
p2059
aVski boots
p2060
aVgeese
p2061
aVsingle
p2062
aVgreen and blue
p2063
aVpipes
p2064
aVpaddle
p2065
aV39
p2066
aV38
p2067
aV33
p2068
aV32
p2069
aV31
p2070
aV30
p2071
aV36
p2072
aV35
p2073
aV34
p2074
aVvideo game
p2075
aV2:45
p2076
aV2:40
p2077
aVtrain
p2078
aVf
aVlamb
p2079
aVboxing
p2080
aVlamp
p2081
aVforest
p2082
aVterrier
p2083
aVtim hortons
p2084
aVsedan
p2085
aVlos angeles
p2086
aVnot likely
p2087
aVafrican american
p2088
aV1:40
p2089
aV1:45
p2090
aVcurtain
p2091
aVbull
p2092
aVmulti
p2093
aVstreet light
p2094
aVflamingo
p2095
aVitalian
p2096
aVpepper
p2097
aVcrossing
p2098
aVvegetables
p2099
aVmulticolored
p2100
aV15 feet
p2101
aVskating
p2102
aV48
p2103
aV49
p2104
aV46
p2105
aV47
p2106
aV44
p2107
aV45
p2108
aVoak
p2109
aV43
p2110
aV40
p2111
aV41
p2112
aVoar
p2113
aVformica
p2114
aVbicycle
p2115
aVstraight ahead
p2116
aVstraw
p2117
aVhospital
p2118
aVjumping
p2119
aVharley
p2120
aVtourist
p2121
aVshadows
p2122
aVpotatoes
p2123
aVwaffle
p2124
aVmotel
p2125
aVtoy
p2126
aVtow
p2127
aVcurtains
p2128
aVrolex
p2129
aVpaved
p2130
aVsnow
p2131
aVpineapple
p2132
aVradio
p2133
aVbus stop
p2134
aVgirl on right
p2135
aVwatch
p2136
aVamazon
p2137
aVeach other
p2138
aVnoodles
p2139
aV1 4
p2140
aVsprite
p2141
aVcoffee table
p2142
aVstopping
p2143
aV9:55
p2144
aVmother
p2145
aV59
p2146
aV55
p2147
aV54
p2148
aV56
p2149
aV51
p2150
aV50
p2151
aV53
p2152
aV52
p2153
aVpetting
p2154
aVplaying soccer
p2155
aVlamps
p2156
aVdalmatian
p2157
aVcowboy
p2158
aVvirgin
p2159
aVpurple and white
p2160
aVyacht
p2161
aVcoach
p2162
aVfocus
p2163
aVcook
p2164
aVcool
p2165
aVhawaii
p2166
aVhealthy
p2167
aVno light
p2168
aV10:15
p2169
aV10:10
p2170
aViphone
p2171
aV1
aVcardinal
p2172
aVpc
p2173
aVwagon
p2174
aVtrunks
p2175
aVice cream
p2176
aVboogie board
p2177
aVtulips
p2178
aVon counter
p2179
aVswing
p2180
aV3:50
p2181
aV3:55
p2182
aVhungry
p2183
aVyellow and black
p2184
aVfarmers
p2185
aVoutside
p2186
aVbricks
p2187
aVpot
p2188
aV64
p2189
aV65
p2190
aV66
p2191
aV68
p2192
aVhelicopter
p2193
aV500
p2194
aVengine
p2195
aVtiger
p2196
aVmound
p2197
aVwax paper
p2198
aVslippers
p2199
aVvest
p2200
aVnest
p2201
aVranch
p2202
aVnot long
p2203
aVrose
p2204
aV7:25
p2205
aVfurniture
p2206
aVnissan
p2207
aVbracelet
p2208
aVbagel
p2209
aVred white and blue
p2210
aVjeep
p2211
aVshrimp
p2212
aVstand
p2213
aVunited states
p2214
aVpolka dots
p2215
aVacer
p2216
aVred and white
p2217
aVus air force
p2218
aVbrass
p2219
aVswirls
p2220
aVchinese
p2221
aVreins
p2222
aVracing
p2223
aVtags
p2224
aVfreightliner
p2225
aVv
aVpony
p2226
aVpond
p2227
aVcourt
p2228
aVgoal
p2229
aVsoftball
p2230
aVshade
p2231
aVstyle
p2232
aVthoroughbred
p2233
aVcalm
p2234
aVday time
p2235
aVtennis ball
p2236
aVremodeling
p2237
aVcookie
p2238
aVbushes
p2239
aVfeet
p2240
aVwatch tv
p2241
aVtennis rackets
p2242
aVhotel
p2243
aV60
p2244
aV61
p2245
aVoutdoors
p2246
aV75
p2247
aV70
p2248
aVnothing
p2249
aVno flag
p2250
aVcds
p2251
aVtarmac
p2252
aVshepherd
p2253
aV1 in front
p2254
aVair conditioner
p2255
aVnot in service
p2256
aVroman
p2257
aV1 world
p2258
aVgoats
p2259
aVreferee
p2260
aVplants
p2261
aVnightstand
p2262
aVlexus
p2263
aVtables
p2264
aVholding it
p2265
aVlife vest
p2266
aVmcdonald's
p2267
aVsurfers
p2268
aVin cup
p2269
aVvery long
p2270
aVcloth
p2271
aVdelta
p2272
aVpenguin
p2273
aVrainy
p2274
aVfood truck
p2275
aVspoon
p2276
aVsmaller
p2277
aVman's
p2278
aVnapkins
p2279
aVlight
p2280
aVnecklace
p2281
aVhorizontally
p2282
aVoranges
p2283
aVdiet coke
p2284
aVlilac
p2285
aV88
p2286
aV80
p2287
aVipod
p2288
aVbirthday
p2289
aVshell
p2290
aVshelf
p2291
aVcatholic
p2292
aVstove
p2293
aVto catch ball
p2294
aVpiano
p2295
aVwatching
p2296
aVchips
p2297
aVcutting board
p2298
aVblack and white
p2299
aVwall
p2300
aVwalk
p2301
aVin water
p2302
aVgothic
p2303
aVtomatoes
p2304
aVhair dryer
p2305
aVstudent
p2306
aVwhale
p2307
aV9:45
p2308
aVtoilets
p2309
aVdoughnut
p2310
aVkawasaki
p2311
aVfluorescent
p2312
aVred velvet
p2313
aVsweat
p2314
aVblack and orange
p2315
aVvan
p2316
aVbelow
p2317
aVrailing
p2318
aVsnowing
p2319
aVtired
p2320
aVbacon
p2321
aVtires
p2322
aVsecond
p2323
aV5 ft
p2324
aVherd
p2325
aVswan
p2326
aV99
p2327
aVglass
p2328
aV90
p2329
aVit's not
p2330
aVdrying
p2331
aVhot
p2332
aVplastic wrap
p2333
aVbackpack
p2334
aVbrushing teeth
p2335
aVspider
p2336
aVgrapes
p2337
aVshelves
p2338
aVatv
p2339
aVveggie
p2340
aVboston
p2341
aVbusy
p2342
aVbush
p2343
aVcushion
p2344
aVgame controller
p2345
aVsalt and pepper
p2346
aVaccident
p2347
aVsheepdog
p2348
aV2nd
p2349
aVblending
p2350
aVlighthouse
p2351
aVplayer
p2352
aVtoaster
p2353
aVtoasted
p2354
aVtuna
p2355
aVburgers
p2356
aVeasy
p2357
aVeast
p2358
aV6:05
p2359
aV6:00
p2360
aVprivate
p2361
aVright
p2362
aVold
p2363
aVo
aVon sign
p2364
aVfloor
p2365
aVflood
p2366
aVwhipped cream
p2367
aVtime
p2368
aVslope
p2369
aVon plate
p2370
aVdandelions
p2371
aVfalling
p2372
aVdon't walk
p2373
aVbus station
p2374
aVradish
p2375
aVon tracks
p2376
aVoregon
p2377
aVplacemat
p2378
aVtruck
p2379
aVthumb
p2380
aVketchup
p2381
aValuminum
p2382
aVturtle
p2383
aVblt
p2384
aVchevy
p2385
aVspanish
p2386
aVpicnic
p2387
aVsad
p2388
aVsas
p2389
aVbutterfly
p2390
aVsalt
p2391
aVshower head
p2392
aVslow
p2393
aVgoing
p2394
aVblinds
p2395
aVwheelie
p2396
aVsnowsuit
p2397
aVjumped
p2398
aVsurfboards
p2399
aVlong sleeve
p2400
aVdome
p2401
aVboat
p2402
aVgolden gate
p2403
aVlanes
p2404
aVcanoe
p2405
aVcanon
p2406
aVsummer
p2407
aVwristband
p2408
aVnoon
p2409
aVexit
p2410
aVpower
p2411
aVstone
p2412
aVmoped
p2413
aVracquet
p2414
aVphiladelphia
p2415
aVfighter
p2416
aVbaseball cap
p2417
aVcollie
p2418
aVmotocross
p2419
aVhe isn't
p2420
aVsmile
p2421
aVlaying
p2422
aVstuffed animals
p2423
aVswinging bat
p2424
aVpoodle
p2425
aVklm
p2426
aVred and green
p2427
aVmall
p2428
aVclock tower
p2429
aVmale
p2430
aVdress
p2431
aVplant
p2432
aVplane
p2433
aVpatio
p2434
aVbroadway
p2435
aVsanta hat
p2436
aVengland
p2437
aVdriver
p2438
aVright hand
p2439
aVwoman's
p2440
aVmain st
p2441
aVstairs
p2442
aV10 feet
p2443
aVpajamas
p2444
aVbrace
p2445
aV12:55
p2446
aV12:50
p2447
aVcoca cola
p2448
aVgiants
p2449
aVon stove
p2450
aVpaddling
p2451
aVtabby
p2452
aVriver
p2453
aVmovie
p2454
aVeurope
p2455
aVbuoys
p2456
aVdrinking water
p2457
aVkissing
p2458
aVrobin
p2459
aVducks
p2460
aVchase
p2461
aVshorter
p2462
aVwhite and red
p2463
aV11:25
p2464
aV11:20
p2465
aVgrapefruit
p2466
aVsurprised
p2467
aVnowhere
p2468
aVgarbage
p2469
aVlabrador
p2470
aVparakeet
p2471
aVtarp
p2472
aVage
p2473
aVballs
p2474
aVindian
p2475
aVscenery
p2476
aV5:00
p2477
aV5:05
p2478
aVsugar
p2479
aVtools
p2480
aVon
p2481
aVharley davidson
p2482
aVanimals
p2483
aVon tree
p2484
aVthin
p2485
aVleft 1
p2486
aVties
p2487
aVmotorcycle
p2488
aVcage
p2489
aV4:45
p2490
aV4:40
p2491
aVshut
p2492
aVscarf
p2493
aVknife and fork
p2494
aVtater tots
p2495
aVbedroom
p2496
aVrough
p2497
aVwall st
p2498
aVwire
p2499
aVramp
p2500
aVtoilet
p2501
aVbritain
p2502
aVbroccoli
p2503
aVmosaic
p2504
aVovercast
p2505
aVrubber
p2506
aVdifferent teams
p2507
aVdesk
p2508
aVgarage
p2509
aValmonds
p2510
aVon laptop
p2511
aVblurry
p2512
aVhiding
p2513
aVus airways express
p2514
aVspots
p2515
aVpink and yellow
p2516
aVzig zag
p2517
aVphones
p2518
aVblueberry
p2519
aVwork
p2520
aVindia
p2521
aVlab
p2522
aVlap
p2523
aVambulance
p2524
aVoffice
p2525
aVmuffin
p2526
aV4 inches
p2527
aVbread
p2528
aVdown street
p2529
aVcameras
p2530
aVdiesel
p2531
aVkiwi
p2532
aVtoronto
p2533
aVtarget
p2534
aViron
p2535
aVcircles
p2536
aVanniversary
p2537
aVchandelier
p2538
aVtoast
p2539
aVon bike
p2540
aVon desk
p2541
aV2:10
p2542
aV2:15
p2543
aVsteak
p2544
aVsteam
p2545
aV3:00
p2546
aVno hat
p2547
aVon couch
p2548
aVrelaxing
p2549
aVski lift
p2550
aVsticker
p2551
aVgoggles
p2552
aVphotographer
p2553
aVmayo
p2554
aVxbox
p2555
aVorchid
p2556
aVwireless
p2557
aVsnowy
p2558
aVbraves
p2559
aV7:10
p2560
aVladybug
p2561
aVumpire
p2562
aVholding
p2563
aVthumbs up
p2564
aVtugboat
p2565
aVstick
p2566
aVmicrowave
p2567
aVbrown
p2568
aVsteeple
p2569
aVsidewalk
p2570
aVhtc
p2571
aVplanter
p2572
aVplaying wii
p2573
aVmarket
p2574
aVangels
p2575
aVclub
p2576
aVgrassy
p2577
aVposing
p2578
aVbaseball field
p2579
aVcurly
p2580
aVh
aVparking lot
p2581
aVmarina
p2582
aVflying
p2583
aVtusks
p2584
aVthai
p2585
aVflowers
p2586
aVbirthday cake
p2587
aVprofessional
p2588
aVtennis racquet
p2589
aVlilies
p2590
aVleather
p2591
aVon suitcase
p2592
aVblue and red
p2593
aV3
aVskyscraper
p2594
aVdairy queen
p2595
aVwheels
p2596
aVlearning
p2597
aVolives
p2598
aVcycling
p2599
aVfoggy
p2600
aVstickers
p2601
aVon fridge
p2602
aVapple and banana
p2603
aVin car
p2604
aVforks
p2605
aVbusiness
p2606
aVriding bike
p2607
aVcats
p2608
aVorange and white
p2609
aVbuilding
p2610
aVvines
p2611
aVmessy
p2612
aVcarpet
p2613
aVblueberries
p2614
aVfence
p2615
aVcasual
p2616
aVtrunk
p2617
aVbrushing his teeth
p2618
aVsmartphone
p2619
aVegg
p2620
aVpeanuts
p2621
aVfancy
p2622
aVkayaking
p2623
aVfountain
p2624
aVdaisy
p2625
aVwashington dc
p2626
aVgrocery
p2627
aVtop right
p2628
aVbride
p2629
aVconcentration
p2630
aVabove stove
p2631
aVlid
p2632
aVclear
p2633
aVclean
p2634
aVchain link
p2635
aVvery deep
p2636
aVcircle
p2637
aVnot here
p2638
aV10:40
p2639
aVx
aVthrowing
p2640
aVunion station
p2641
aVboth
p2642
aV8:05
p2643
aV8:00
p2644
aVhallway
p2645
aVcleaning
p2646
aVon pizza
p2647
aVyankees
p2648
aVtoyota
p2649
aV20 feet
p2650
aVparade
p2651
aVeating
p2652
aV30 mph
p2653
aVraspberry
p2654
aVgiraffe and zebra
p2655
aVabove door
p2656
aVskull
p2657
aVwine glass
p2658
aVfacebook
p2659
aVon building
p2660
aVbutton
p2661
aVcell
p2662
aVpoles
p2663
aVraspberries
p2664
aVhexagon
p2665
aVsan francisco
p2666
aVsinging
p2667
aVwhite and black
p2668
aVsalad
p2669
aVride
p2670
aVshower curtain
p2671
aVcontrol
p2672
aVin his hand
p2673
aVbroom
p2674
aVspray paint
p2675
aVfront
p2676
aVglobe
p2677
aVpartly cloudy
p2678
aVmt airy
p2679
aVflorida
p2680
aVprivacy
p2681
aVbears
p2682
aVnew york
p2683
aVbeard
p2684
aVrecessed
p2685
aVmicrophone
p2686
aVkickstand
p2687
aVscrambled
p2688
aVindoor
p2689
aVsoldiers
p2690
aVdodge
p2691
aVmichigan
p2692
aVhardwood
p2693
aVcomputer
p2694
aVon bus
p2695
aVhigh heels
p2696
aVchalk
p2697
aV10 inches
p2698
aVballoon
p2699
aVfireman
p2700
aVon left
p2701
aVsavannah
p2702
aVshed
p2703
aVmonday
p2704
aVkettle
p2705
aVtelling time
p2706
aVrodeo
p2707
aVon shelf
p2708
aVmuseum
p2709
aVbrick
p2710
aVpenne
p2711
aV2 feet
p2712
aVquilt
p2713
aVjapanese
p2714
aVcactus
p2715
aVtow truck
p2716
aVasia
p2717
aVlights
p2718
aVpancakes
p2719
aVon train
p2720
aVlandscape
p2721
aVarmy
p2722
aVvery fast
p2723
aVsony ericsson
p2724
aVhoney
p2725
aV4 feet
p2726
aVbriefcase
p2727
aVcadillac
p2728
aVman on right
p2729
aVabove sink
p2730
aVdeck
p2731
aVside of road
p2732
aVturning
p2733
aVspinach
p2734
aVboating
p2735
aVhorse racing
p2736
aVin field
p2737
aVchest
p2738
aVnighttime
p2739
aV12
p2740
aVbat
p2741
aVbar
p2742
aV17
p2743
aVbag
p2744
aVbad
p2745
aVbrazil
p2746
aVsail
p2747
aVshaved
p2748
aVblack and silver
p2749
aVolympics
p2750
aVbalance
p2751
aVmexico
p2752
aVsushi
p2753
aVseattle
p2754
aVvegetable
p2755
aVleft
p2756
aVbackground
p2757
aVmeow
p2758
aVtinkerbell
p2759
aVdonut shop
p2760
aVfloating
p2761
aVjones
p2762
aVmilk
p2763
aVdim
p2764
aVlimes
p2765
aVbikers
p2766
aVdip
p2767
aVfrench
p2768
aVfrance
p2769
aVphillies
p2770
aVdowntown
p2771
aVfly
p2772
aVtokyo
p2773
aVsoup
p2774
aVdrawer
p2775
aVred and blue
p2776
aVski poles
p2777
aVafrican
p2778
aVus open
p2779
aVmain
p2780
aVgirl
p2781
aVliving
p2782
aVpalm trees
p2783
aVcalifornia
p2784
aVpaper towels
p2785
aVpepperoni
p2786
aVno clock
p2787
aVnormal
p2788
aVtrack
p2789
aVin hand
p2790
aVfanta
p2791
aVyield
p2792
aVdog bed
p2793
aV12:05
p2794
aV12:00
p2795
aVshopping
p2796
aVcutting cake
p2797
aVbags
p2798
aVpan
p2799
aVexhaust
p2800
aVseagulls
p2801
aVrunning
p2802
aVrailroad crossing
p2803
aVbottle
p2804
aVbeagle
p2805
aV1.00
p2806
aVmom
p2807
aVlaughing
p2808
aVroad
p2809
aV1:25
p2810
aV1:20
p2811
aVclose up
p2812
aVtripod
p2813
aVoutfield
p2814
aVchildren
p2815
aVearbuds
p2816
aVribbon
p2817
aVmovement
p2818
aVtattoos
p2819
aVstretching
p2820
aVbuses
p2821
aVdiamond
p2822
aVgreen and black
p2823
aVcast iron
p2824
aVrye
p2825
aVcsx
p2826
aV5:30
p2827
aVsomeone
p2828
aVhouse
p2829
aVflower
p2830
aVcups
p2831
aVsweet potato
p2832
aVmountain
p2833
aV4:15
p2834
aVrefrigerator
p2835
aVhugging
p2836
aVbaby's breath
p2837
aVmeeting
p2838
aVsolid
p2839
aVkeys
p2840
aVleopard
p2841
aVbeanie
p2842
aVflags
p2843
aVkites
p2844
aVbin
p2845
aVbig
p2846
aVbib
p2847
aVgoogle
p2848
aVscale
p2849
aVbritish airways
p2850
aVin background
p2851
aV1 inch
p2852
aVgraduation
p2853
aVdaisies
p2854
aVdownhill
p2855
aVcalico
p2856
aVpasture
p2857
aVdesert
p2858
aVvehicles
p2859
aV2000
p2860
aV2007
p2861
aV2008
p2862
aV2009
p2863
aVchicago
p2864
aVsink
p2865
aValaska
p2866
aVlime
p2867
aVtunnel
p2868
aVdb
p2869
aVlines
p2870
aVsquatting
p2871
aVdad
p2872
aVtoshiba
p2873
aVcutting
p2874
aVday
p2875
aVsweatshirt
p2876
aVwindsurfing
p2877
aVlacoste
p2878
aVno train
p2879
aVcurved
p2880
aVsidecar
p2881
aVplaying video games
p2882
aVkite
p2883
aVphotography
p2884
aVbear
p2885
aVstriped
p2886
aVgreen and white
p2887
aVski resort
p2888
aVvent
p2889
aVplaid
p2890
aVplain
p2891
aVdomestic
p2892
aVon ground
p2893
aVlake
p2894
aVsofa
p2895
aVsoft
p2896
aValive
p2897
aVchains
p2898
aVhose
p2899
aVivory
p2900
aVbrand
p2901
aVsupreme
p2902
aVpie
p2903
aVpig
p2904
aVbirthday party
p2905
aVno 1
p2906
aVsleep
p2907
aVfeeding
p2908
aVparis
p2909
aV7:45
p2910
aVfireplace
p2911
aVparrot
p2912
aVbank of america
p2913
aVstar wars
p2914
aVdecorative
p2915
aVwoodpecker
p2916
aVin suitcase
p2917
aVbunk
p2918
aVgate
p2919
aVover
p2920
aVwriting
p2921
aVhanger
p2922
aVdriving
p2923
aVgreen and orange
p2924
aVarizona
p2925
aVblue and green
p2926
aVfree
p2927
aVscissors
p2928
aVram
p2929
aVraw
p2930
aVmetal
p2931
aVqueen
p2932
aVnorth america
p2933
aVunder armour
p2934
aVrowing
p2935
aVlicking
p2936
aVtoothbrush
p2937
aVjelly
p2938
aVplayers
p2939
aVbathroom
p2940
aVcatch
p2941
aVwashington monument
p2942
aVhearts
p2943
aVlaptop
p2944
aVcatching frisbee
p2945
aVcnn
p2946
aVtie
p2947
aVyounger
p2948
aVserious
p2949
aVprotest
p2950
aVhit ball
p2951
aVdrinks
p2952
aVstands
p2953
aVon man
p2954
aVwater
p2955
aVbaseball
p2956
aVname tag
p2957
aVmiddle 1
p2958
aVconductor
p2959
aVferry
p2960
aVwiimote
p2961
aVjetblue
p2962
aVon his head
p2963
aVfloral
p2964
aVclassroom
p2965
aVbranches
p2966
aVswans
p2967
aVwhirlpool
p2968
aVpersian
p2969
aVsunrise
p2970
aVfactory
p2971
aVfishing boat
p2972
aVmotion
p2973
aVpelican
p2974
aVplastic
p2975
aVall of them
p2976
aVwhite
p2977
aVwide
p2978
aVmultiple
p2979
aVseafood
p2980
aVwood
p2981
aVwool
p2982
aVlighter
p2983
aVparking
p2984
aVstop light
p2985
aVspoons
p2986
aVin sink
p2987
aVabstract
p2988
aVtelephone
p2989
aV10:30
p2990
aVbowl
p2991
aVmacaroni
p2992
aVshih tzu
p2993
aVhot dogs
p2994
aVkneeling
p2995
aVfridge
p2996
aVtable
p2997
aV8:55
p2998
aV8:50
p2999
aVcabinet
p3000
aVverizon
p3001
aV5
aVtrash
p3002
aVlace
p3003
aVsiblings
p3004
aVfar
p3005
aVfan
p3006
aVsony
p3007
aV9:25
p3008
aV9:20
p3009
aVforeground
p3010
aVrectangles
p3011
aVsiamese
p3012
aVcucumbers
p3013
aVurinals
p3014
aVsandwich
p3015
aVwater bottle
p3016
aVpirates
p3017
aVdugout
p3018
aVweather vane
p3019
aVsyrup
p3020
aVdouble
p3021
aVcleaner
p3022
aVsandwiches
p3023
aVwindows
p3024
aVhit
p3025
aVbars
p3026
aVart
p3027
aVdump
p3028
aVunsure
p3029
aVbark
p3030
aVarm
p3031
aVbarn
p3032
aVpaisley
p3033
aVfire extinguisher
p3034
aVc
aVsuspenders
p3035
aVorioles
p3036
aV1950s
p3037
aVcumulus
p3038
aVmoving
p3039
aVno grass
p3040
aVmarker
p3041
aVprom
p3042
aV20 ft
p3043
aVbehind bench
p3044
aVupright
p3045
aVcrane
p3046
aVnotes
p3047
aVwaiting
p3048
aVmitt
p3049
aVmetro
p3050
aVspices
p3051
aVapple
p3052
aVmotor
p3053
aVporch
p3054
aVrabbit
p3055
aVwomen
p3056
aVtag
p3057
aVtam
p3058
aVtan
p3059
aVonions
p3060
aVshopping cart
p3061
aVflour
p3062
aVpractice
p3063
aVmaple leaf
p3064
aVtray
p3065
aVparasailing
p3066
aV6 feet
p3067
aVpowdered
p3068
aVseaweed
p3069
aVchicken
p3070
aVthousands
p3071
aVwords
p3072
aVcloset
p3073
aVvirgin atlantic
p3074
aVclosed
p3075
aVpants
p3076
aVcanopy
p3077
aVtraffic light
p3078
aVvanilla
p3079
aVwild
p3080
aVon street
p3081
aVoctagon
p3082
aVstar alliance
p3083
aVenglish
p3084
aVcakes
p3085
aVhappiness
p3086
aVsky
p3087
aVski
p3088
aVleaf
p3089
aVbrush
p3090
aVsweatband
p3091
aVpolar bears
p3092
aVdaffodil
p3093
aVboot
p3094
aVillinois
p3095
aVbook
p3096
aVjunk
p3097
a.
\ No newline at end of file
...@@ -23,12 +23,18 @@ import argparse ...@@ -23,12 +23,18 @@ import argparse
import numpy as np import numpy as np
import multiprocessing import multiprocessing
import json import json
import math
import pickle
from reader.vcr_finetuning import VCRDataJointReader from reader.vcr_finetuning import VCRDataJointReader
from reader.refcoco_plus_finetuning import RefcocoPlusDataReader
from reader.flickr_finetuning import FlickrDataReader
from reader.vqa_finetuning import VQADataReader
from model.ernie_vil import ErnieVilModel, ErnieVilConfig from model.ernie_vil import ErnieVilModel, ErnieVilConfig
from optim.optimization import optimization from optim.optimization import optimization
from utils.args import print_arguments from utils.args import print_arguments
from utils.init import init_checkpoint, init_pretraining_params from utils.init import init_checkpoint, init_pretraining_params
from utils.loss import circle_loss
from args.finetune_args import parser from args.finetune_args import parser
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -37,8 +43,23 @@ args = parser.parse_args() ...@@ -37,8 +43,23 @@ args = parser.parse_args()
# yapf: enable. # yapf: enable.
#READERS = {"vcr": VCRDataJointReader, "vqa": VQADataReader, "refcoco+": RefcocoReader, "flickr": FlickrReader} #READERS = {"vcr": VCRDataJointReader, "vqa": VQADataReader, "refcoco_plus": RefcocoPlusReader, "flickr": FlickrReader}
READERS = {"vcr": VCRDataJointReader} READERS = {"vcr": VCRDataJointReader, "refcoco_plus": RefcocoPlusDataReader,
"flickr": FlickrDataReader, "vqa": VQADataReader}
def write_result_file(res_arr, qids, labels, ans_arr):
""" trans batch results into json format (for VQA test)
"""
for i in range(len(qids)):
#print(int(qids[i]))
res = {
'question_id': int(qids[i]),
'answer': ans_arr[labels[i]]
}
res_arr.append(res)
return res_arr
def format_result(res_arr, qids, pred, labels, scores): def format_result(res_arr, qids, pred, labels, scores):
""" """
...@@ -50,6 +71,249 @@ def format_result(res_arr, qids, pred, labels, scores): ...@@ -50,6 +71,249 @@ def format_result(res_arr, qids, pred, labels, scores):
return res_arr return res_arr
def vqa_classifier_loss(emb_fuse, hidden_size, label, is_test):
"""
classifier loss for vqa
"""
co_emb_size = emb_fuse.shape[-1]
num_class = label.shape[-1]
weight_init_0 = fluid.initializer.UniformInitializer(
low = - math.sqrt(3 / (co_emb_size + hidden_size)), high = math.sqrt(3 / (co_emb_size + hidden_size)))
weight_init_1 = fluid.initializer.UniformInitializer(
low = - math.sqrt(3 / (hidden_size + num_class)), high = math.sqrt(3 / (hidden_size + num_class)))
hidden_emb = fluid.layers.fc(input=emb_fuse, size=hidden_size,
param_attr = fluid.ParamAttr(
initializer = weight_init_0, name = "vqa_fc_w_0"),
bias_attr = "vqa_fc_b_0", act='relu')
hidden_emb = fluid.layers.dropout(hidden_emb, 0.5, dropout_implementation="upscale_in_train")
pred = fluid.layers.fc(input=hidden_emb,
param_attr = fluid.ParamAttr(
initializer = weight_init_1, name = "vqa_fc_w_1"),
bias_attr = "vqa_fc_b_1", size=num_class)
pred = fluid.layers.cast(x=pred, dtype='float32')
cost = fluid.layers.sigmoid_cross_entropy_with_logits(pred, label, name="cross_entropy_loss")
cost = fluid.layers.reduce_sum(cost, -1)
max_conf_label = fluid.layers.argmax(pred, axis=1)
max_conf_label_re = fluid.layers.reshape(max_conf_label, [-1, 1])
one_hot_label = fluid.layers.one_hot(input=max_conf_label_re, depth=num_class)
acc = fluid.layers.reduce_sum(one_hot_label * label, -1)
return max_conf_label, fluid.layers.reduce_mean(cost), fluid.layers.reduce_mean(acc)
def create_vqa_model(pyreader_name, ernie_config, task_group, is_prediction=False):
"""
detail model arch for vqa task
"""
num_class = task_group[0]["num_class"]
classifier_hid_size = task_group[0]["classifier_hid_size"]
shapes=[[-1, args.max_seq_len, 1], #src_id
[-1, args.max_seq_len, 1], #pos_id
[-1, args.max_seq_len, 1], #sent_id
[-1, args.max_seq_len, 1], #input_mask
[-1, args.max_img_len, args.feature_size], #image_embedding
[-1, args.max_img_len, 5], #image_loc
[-1, args.max_img_len, 1], #image_mask
[-1, num_class], #soft_labels
[-1], #q_id
]
dtypes = ['int64', 'int64', 'int64', 'float32', 'float32', 'float32', 'float32', 'float32', 'int64']
#srd_id pos_id sent_id input_mask image_emb image_loc image_mask, labels
lod_levels = [0] * len(dtypes)
pyreader = fluid.layers.py_reader(
capacity=30,
shapes=shapes,
dtypes=dtypes,
lod_levels=lod_levels,
name=pyreader_name,
use_double_buffer=True)
inputs = fluid.layers.read_file(pyreader)
src_ids, pos_ids, sent_ids, input_mask, image_embeddings, \
image_loc, image_mask, labels, q_ids = inputs[: 11]
ernie_vil = ErnieVilModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
input_mask=input_mask,
image_embeddings=image_embeddings,
image_loc=image_loc,
input_image_mask=image_mask,
config=ernie_config
)
h_cls, h_img = ernie_vil.get_pooled_output()
score = ernie_vil.get_match_score(h_cls, h_img, "mul")
pred_label, loss, acc = vqa_classifier_loss(score, classifier_hid_size, labels, args.do_test)
task_vars = [loss, acc, pred_label, q_ids]
for var in task_vars:
var.persistable = True
return pyreader, task_vars
def create_refcoco_plus_model(pyreader_name, ernie_config, task_group, is_prediction=False):
"""
detail model arch for refcoco_plus task
"""
shapes=[[-1, args.max_seq_len, 1], #src_id
[-1, args.max_seq_len, 1], #pos_id
[-1, args.max_seq_len, 1], #sent_id
[-1, args.max_seq_len, 1], #input_mask
[-1, 1], #seq_lens
[-1, args.max_img_len, args.feature_size], #image_embedding
[-1, args.max_img_len, 5], #image_loc
[-1, args.max_img_len, 1], #image_mask
[-1, args.max_img_len, 1], #labels
[-1, 1], #add_items
]
dtypes = ['int64', 'int64', 'int64', 'float', 'int64', \
'float32', 'float32', 'float32', 'float32', 'float32']
lod_levels = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pyreader = fluid.layers.py_reader(
capacity=30,
shapes=shapes,
dtypes=dtypes,
lod_levels=lod_levels,
name=pyreader_name,
use_double_buffer=True)
inputs = fluid.layers.read_file(pyreader)
src_ids, pos_ids, sent_ids, input_mask, seq_lens, \
image_embeddings, image_loc, image_mask, labels, add_item = inputs[: 10]
ernie_vil = ErnieVilModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
input_mask=input_mask,
image_embeddings=image_embeddings,
image_loc=image_loc,
input_image_mask=image_mask,
config=ernie_config
)
enc_l_out, enc_vl_out = ernie_vil.get_sequence_output()
pred_fc = enc_vl_out
if args.seq_dropout > 0.0:
pred_fc = fluid.layers.dropout(pred_fc, args.seq_dropout, dropout_implementation="upscale_in_train")
logits = fluid.layers.fc(
input=pred_fc, size=1,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name="cls_seq_label_vl_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02),
learning_rate=1.0),
bias_attr=fluid.ParamAttr(
name="cls_seq_label_vl_out_b",
initializer=fluid.initializer.Constant(0.),
learning_rate=1.0))
logits_re = fluid.layers.reduce_mean(logits, -1)
labels_re = fluid.layers.reduce_mean(labels, -1)
input_image_mask = fluid.layers.reduce_mean(image_mask, -1)
ce_loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits_re, labels_re)
ce_loss = ce_loss * input_image_mask
loss = fluid.layers.reduce_sum(ce_loss) / fluid.layers.reduce_sum(input_image_mask)
loss = fluid.layers.mean(x = loss) * args.batch_size
with_mask_loss = fluid.layers.mean(ce_loss) * args.batch_size
if is_prediction:
task_vars = [logits, image_loc, labels, add_item]
else:
task_vars = [loss, with_mask_loss]
for var in task_vars:
var.persistable = True
return pyreader, task_vars
def create_flickr_model(pyreader_name, ernie_config, task_group, is_prediction=False):
"""
detailed model arch for flickr task
"""
shapes=[[-1, args.max_seq_len, 1], #src_id
[-1, args.max_seq_len, 1], #pos_id
[-1, args.max_seq_len, 1], #sent_id
[-1, args.max_seq_len, 1], #input_mask
[-1, args.max_img_len, args.feature_size], #image_embedding
[-1, args.max_img_len, 5], #image_loc
[-1, args.max_img_len, 1], #image_mask
[-1, 1], #labels
[-1, 1], #ids
]
dtypes = ['int64', 'int64', 'int64', 'float', 'float32', 'float32', 'float32', 'int64', 'int64']
#srd_id pos_id sent_id input_mask image_emb image_loc image_mask, labels, ids
#lod_levels = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
lod_levels = [0] * len(dtypes)
pyreader = fluid.layers.py_reader(
capacity=30,
shapes=shapes,
dtypes=dtypes,
lod_levels=lod_levels,
name=pyreader_name,
use_double_buffer=True)
inputs = fluid.layers.read_file(pyreader)
src_ids, pos_ids, sent_ids, input_mask, image_embeddings, \
image_loc, image_mask, labels, ids = inputs[: 9]
ernie = ErnieVilModel(
src_ids=src_ids,
position_ids=pos_ids,
sentence_ids=sent_ids,
input_mask=input_mask,
image_embeddings=image_embeddings,
image_loc=image_loc,
input_image_mask=image_mask,
config=ernie_config
)
h_cls, h_img = ernie.get_pooled_output()
match_emb = ernie.get_match_score(h_cls, h_img)
match_score = fluid.layers.fc(
input=match_emb,
size=1,
act=None,
param_attr=fluid.ParamAttr(
name='match_fc.w_0',
initializer=fluid.initializer.Xavier()),
bias_attr=fluid.ParamAttr(name='match_fc.b_0',
initializer=fluid.initializer.UniformInitializer()))
if not is_prediction:
outs = len(task_group[0]["negative_schema"]) + 1
match_score = fluid.layers.reshape(match_score, [-1, outs])
match_score = fluid.layers.sigmoid(match_score)
positive_score = match_score[:, 0]
image_neg_score = match_score[:, 1:int((outs + 1) / 2)]
caption_neg_score = match_score[:, int((outs + 1) / 2):]
positive_score = fluid.layers.reshape(x=positive_score, shape=[-1, 1])
loss_c = circle_loss(positive_score, caption_neg_score, args.margin, args.scale_circle)
loss_i = circle_loss(positive_score, image_neg_score, args.margin, args.scale_circle)
#total_loss = fluid.layers.mean(loss_c + loss_i)
total_loss = (loss_c + loss_i) / 2
acc = fluid.layers.accuracy(match_score, labels, k=1)
task_vars = [total_loss, acc, match_score, ids]
else:
outs = 1
match_score = fluid.layers.reshape(match_score, [-1, outs])
task_vars = [match_score, ids]
for var in task_vars:
var.persistable = True
return pyreader, task_vars
def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=False): def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=False):
""" """
create model arc for vcr tasks create model arc for vcr tasks
...@@ -57,7 +321,6 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals ...@@ -57,7 +321,6 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals
shapes = [[-1, args.max_seq_len, 1], #src_id shapes = [[-1, args.max_seq_len, 1], #src_id
[-1, args.max_seq_len, 1], #pos_id [-1, args.max_seq_len, 1], #pos_id
[-1, args.max_seq_len, 1], #sent_id [-1, args.max_seq_len, 1], #sent_id
[-1, args.max_seq_len, 1], #task_id
[-1, args.max_seq_len, 1], #input_mask [-1, args.max_seq_len, 1], #input_mask
[-1, args.max_img_len, args.feature_size], #image_embedding [-1, args.max_img_len, args.feature_size], #image_embedding
[-1, args.max_img_len, 5], #image_loc [-1, args.max_img_len, 5], #image_loc
...@@ -67,7 +330,7 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals ...@@ -67,7 +330,7 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals
[], #task_index [], #task_index
[-1, 1], #binary_labels [-1, 1], #binary_labels
] ]
dtypes = ['int64', 'int64', 'int64', 'int64', 'float32', 'float32', 'float32', 'float32', dtypes = ['int64', 'int64', 'int64', 'float32', 'float32', 'float32', 'float32',
'int64', 'int64', 'int64', 'float32'] 'int64', 'int64', 'int64', 'float32']
lod_levels = [0] * len(dtypes) lod_levels = [0] * len(dtypes)
...@@ -85,14 +348,13 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals ...@@ -85,14 +348,13 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals
use_double_buffer=False) use_double_buffer=False)
inputs = fluid.layers.read_file(pyreader) inputs = fluid.layers.read_file(pyreader)
src_ids, pos_ids, sent_ids, task_ids, input_mask, image_embeddings, \ src_ids, pos_ids, sent_ids, input_mask, image_embeddings, \
image_loc, image_mask, labels, q_ids, task_index, binary_labels = inputs[: 12] image_loc, image_mask, labels, q_ids, task_index, binary_labels = inputs[: 11]
ernie_vil = ErnieVilModel( ernie_vil = ErnieVilModel(
src_ids=src_ids, src_ids=src_ids,
position_ids=pos_ids, position_ids=pos_ids,
sentence_ids=sent_ids, sentence_ids=sent_ids,
task_ids=task_ids,
input_mask=input_mask, input_mask=input_mask,
image_embeddings=image_embeddings, image_embeddings=image_embeddings,
image_loc=image_loc, image_loc=image_loc,
...@@ -124,7 +386,7 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals ...@@ -124,7 +386,7 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals
var.persistable = True var.persistable = True
return pyreader, task_vars return pyreader, task_vars
else: else:
start_ind = 12 start_ind = 11
mean_loss = fluid.layers.zeros(shape = [1], dtype = 'float32') mean_loss = fluid.layers.zeros(shape = [1], dtype = 'float32')
mean_acc = fluid.layers.zeros(shape = [1], dtype = 'float32') mean_acc = fluid.layers.zeros(shape = [1], dtype = 'float32')
for task_conf in task_group: for task_conf in task_group:
...@@ -151,9 +413,9 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals ...@@ -151,9 +413,9 @@ def create_vcr_model(pyreader_name, ernie_config, task_group, is_prediction=Fals
return pyreader, task_vars return pyreader, task_vars
#MODELS = {"vcr": create_vcr_model, "vqa": create_vqa_model, "refcoco+": create_refcoco_model} #MODELS = {"vcr": create_vcr_model, "vqa": create_vqa_model, "refcoco+": create_refcoco_model}
MODELS = {"vcr": create_vcr_model} MODELS = {"vcr": create_vcr_model, "refcoco_plus": create_refcoco_plus_model,
"flickr": create_flickr_model, "vqa": create_vqa_model}
def predict_wrapper(args, def predict_wrapper(args,
exe, exe,
...@@ -170,7 +432,6 @@ def predict_wrapper(args, ...@@ -170,7 +432,6 @@ def predict_wrapper(args,
split=args.test_split, split=args.test_split,
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
is_test=True, is_test=True,
shuffle=False,
batch_size=args.batch_size, batch_size=args.batch_size,
epoch=args.epoch) epoch=args.epoch)
if args.do_test: if args.do_test:
...@@ -180,20 +441,16 @@ def predict_wrapper(args, ...@@ -180,20 +441,16 @@ def predict_wrapper(args,
init_pretraining_params(exe, args.init_checkpoint, test_prog) init_pretraining_params(exe, args.init_checkpoint, test_prog)
print(("testing on %s %s split") % (args.task_name, args.test_split)) print(("testing on %s %s split") % (args.task_name, args.test_split))
def predict(exe=exe, pyreader=pyreader): def predict_vcr(exe=exe, pyreader=pyreader):
""" """
inference for downstream tasks inference for vcr tasks
""" """
pyreader.decorate_tensor_provider(data_reader.data_generator()) pyreader.decorate_tensor_provider(data_reader.data_generator())
pyreader.start() pyreader.start()
cost = 0
appear_step = 0
task_acc = {} task_acc = {}
task_steps = {} task_steps = {}
steps = 0 steps = 0
case_f1 = 0
appear_f1 = 0
time_begin = time.time() time_begin = time.time()
task_name_list = [v.name for v in graph_vars] task_name_list = [v.name for v in graph_vars]
fetch_list = task_name_list fetch_list = task_name_list
...@@ -233,7 +490,145 @@ def predict_wrapper(args, ...@@ -233,7 +490,145 @@ def predict_wrapper(args,
except: except:
pass pass
return ret return ret
return predict
def predict_flickr(exe=exe, pyreader=pyreader):
"""
inference for flickr tasks
"""
pyreader.decorate_tensor_provider(data_reader.data_generator())
pyreader.start()
task_acc = {}
task_steps = {}
steps = 0
time_begin = time.time()
task_name_list = [v.name for v in graph_vars]
fetch_list = task_name_list
print('task name list : ', task_name_list)
out_file = open(args.result_file, 'w')
sum_acc = 0
res_arr = []
while True:
try:
outputs = exe.run(fetch_list=fetch_list, program=test_prog)
score = outputs[0]
ids = outputs[1]
for i in range(len(score)):
out_list = [str(score[i][0]), str(ids[i][0]), str(ids[i][1])]
out_file.write('\t'.join(out_list) + '\n')
steps += 1
except fluid.core.EOFException:
pyreader.reset()
break
out_file.close()
used_time = time.time() - time_begin
return None
def predict_vqa(exe=exe, pyreader=pyreader):
"""
inference for vqa tasks
"""
pyreader.decorate_tensor_provider(data_reader.data_generator())
pyreader.start()
appear_step = 0
task_acc = {}
task_steps = {}
steps = 0
time_begin = time.time()
task_name_list = [v.name for v in graph_vars]
fetch_list = task_name_list
print('task name list : ', task_name_list)
sum_acc = 0
total_data = 0
res_arr = []
pickle_file = task_group[0]["pickle_file"]
pkl_file = open(pickle_file)
ans_arr = pickle.load(pkl_file)
while True:
try:
outputs = exe.run(fetch_list=fetch_list, program=test_prog)
each_acc = outputs[1][0]
labels = outputs[2]
qids = outputs[3]
total_data += len(qids.tolist())
sum_acc += each_acc * len(qids.tolist())
steps += 1
if steps % 10 == 0:
print('cur_step:', steps, 'cur_acc:', sum_acc / total_data)
write_result_file(res_arr, qids.tolist(), labels.tolist(), ans_arr)
except fluid.core.EOFException:
pyreader.reset()
break
used_time = time.time() - time_begin
with open(args.result_file, "w") as f:
json.dump(res_arr, f)
print("step:", steps)
print("average_acc:", sum_acc / total_data)
ret = {}
ret["acc"] = "acc: %f" % (sum_acc / total_data)
for item in ret:
try:
ret[item] = ret[item].split(':')[-1]
except:
pass
return ret
def predict_refcoco_plus(exe=exe, pyreader=pyreader):
"""
inference for refcoco_plus tasks
"""
pyreader.decorate_tensor_provider(data_reader.data_generator())
pyreader.start()
task_acc = {}
task_steps = {}
steps = 0
time_begin = time.time()
task_name_list = [v.name for v in graph_vars]
fetch_list = task_name_list
print('task name list : ', task_name_list)
res_arr = []
acc_all = 0
sample_all = 0
while True:
try:
outputs = exe.run(fetch_list=fetch_list, program=test_prog)
logits, image_locs, labels, items = outputs[0:]
for i in range(len(items)):
acc = 0
logit, loc, label, item = logits[i], image_locs[i], labels[i], items[i]
number_box, width, height, w1, h1, w2, h2 = item
start_idx = 1
list_label = list(label[:, 0])[start_idx: int(number_box)]
list_logit = list(logit[:, 0])[start_idx: int(number_box)]
logit = logit[start_idx:int(number_box), 0]
pred = np.argmax(logit)
if label[pred + start_idx, 0] >= 0.5:
acc = 1
acc_all += acc
sample_all += 1
print(acc_all * 1.0 / sample_all)
steps += 1
except fluid.core.EOFException:
pyreader.reset()
break
print('all', sample_all, acc_all * 1.0 / sample_all)
if args.task_name == "vcr":
return predict_vcr
elif args.task_name == "refcoco_plus":
return predict_refcoco_plus
elif args.task_name == "flickr":
return predict_flickr
else:
return predict_vqa
def get_optimizer(total_loss, train_program, startup_prog, args): def get_optimizer(total_loss, train_program, startup_prog, args):
...@@ -255,7 +650,10 @@ def get_optimizer(total_loss, train_program, startup_prog, args): ...@@ -255,7 +650,10 @@ def get_optimizer(total_loss, train_program, startup_prog, args):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
scheduler=args.lr_scheduler, scheduler=args.lr_scheduler,
decay_steps=decay_steps, decay_steps=decay_steps,
lr_decay_ratio=args.lr_decay_ratio) lr_decay_ratio=args.lr_decay_ratio,
layer_decay_rate=args.layer_decay_rate,
text_init_layers=args.text_init_layers,
n_layers=args.n_layers)
return scheduled_lr return scheduled_lr
...@@ -368,7 +766,7 @@ def main(args): ...@@ -368,7 +766,7 @@ def main(args):
split="train", split="train",
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
batch_size=args.batch_size, batch_size=args.batch_size,
epoch=args.epoch,) epoch=args.epoch)
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
if args.use_fast_executor: if args.use_fast_executor:
...@@ -411,6 +809,12 @@ def main(args): ...@@ -411,6 +809,12 @@ def main(args):
time_begin = time.time() time_begin = time.time()
node_nums = 1 #int(os.getenv("PADDLE_NODES_NUM")) node_nums = 1 #int(os.getenv("PADDLE_NODES_NUM"))
used_time_all = 0 used_time_all = 0
if args.task_name == "refcoco_plus":
metr = "all image loss"
else:
metr = "acc"
while steps < args.num_train_steps: while steps < args.num_train_steps:
try: try:
steps += node_nums steps += node_nums
...@@ -423,6 +827,7 @@ def main(args): ...@@ -423,6 +827,7 @@ def main(args):
time_begin = time.time() time_begin = time.time()
outputs = train_exe.run(fetch_list=fetch_list) outputs = train_exe.run(fetch_list=fetch_list)
if outputs: if outputs:
print("feed_queue size", train_pyreader.queue.size()) print("feed_queue size", train_pyreader.queue.size())
progress_file = data_reader.get_progress() progress_file = data_reader.get_progress()
...@@ -432,12 +837,11 @@ def main(args): ...@@ -432,12 +837,11 @@ def main(args):
current_file = progress_file["current_file"] current_file = progress_file["current_file"]
print( print(
"epoch: %d, progress: %d/%d, step: %d, loss: %f, " "epoch: %d, progress: %d/%d, step: %d, loss: %f, "
"acc: %f" "%s : %f"
% (epoch, current_file_index, total_file, steps, % (epoch, current_file_index, total_file, steps,
outputs[0][0], outputs[0][0],
metr,
outputs[1][0])) outputs[1][0]))
print("steps:", steps)
print("save_steps:", args.save_steps)
np_lr = outputs[-1:] np_lr = outputs[-1:]
......
...@@ -63,7 +63,6 @@ class ErnieVilModel(object): ...@@ -63,7 +63,6 @@ class ErnieVilModel(object):
src_ids, src_ids,
position_ids, position_ids,
sentence_ids, sentence_ids,
task_ids,
input_mask, input_mask,
image_embeddings, image_embeddings,
image_loc, image_loc,
...@@ -115,10 +114,10 @@ class ErnieVilModel(object): ...@@ -115,10 +114,10 @@ class ErnieVilModel(object):
self._param_initializer = fluid.initializer.TruncatedNormal( self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range']) scale=config['initializer_range'])
self._build_model(src_ids, position_ids, sentence_ids, task_ids, input_mask, \ self._build_model(src_ids, position_ids, sentence_ids, input_mask, \
image_embeddings, image_loc, input_image_mask) image_embeddings, image_loc, input_image_mask)
def _build_model(self, src_ids, position_ids, sentence_ids, task_ids, input_mask, \ def _build_model(self, src_ids, position_ids, sentence_ids, input_mask, \
image_embeddings, image_loc, input_image_mask): image_embeddings, image_loc, input_image_mask):
# padding id in vocabulary must be set to 0 # padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding( emb_out = fluid.layers.embedding(
......
...@@ -20,9 +20,7 @@ import numpy as np ...@@ -20,9 +20,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
def manual_warmup_decay(learning_rate, warmup_steps, num_train_steps, decay_steps=[], lr_decay_ratio=0.1): def manual_warmup_decay(learning_rate, warmup_steps, num_train_steps, decay_steps=[], lr_decay_ratio=0.1):
""" """ Applies linear warmup of learning rate from 0 and keep constant."""
Applies linear warmup of learning rate from 0 and keep constant.
"""
with fluid.default_main_program()._lr_schedule_guard(): with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var( lr = fluid.layers.tensor.create_global_var(
shape=[1], shape=[1],
...@@ -49,9 +47,7 @@ def manual_warmup_decay(learning_rate, warmup_steps, num_train_steps, decay_step ...@@ -49,9 +47,7 @@ def manual_warmup_decay(learning_rate, warmup_steps, num_train_steps, decay_step
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps): def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
""" """ Applies linear warmup of learning rate from 0 and decay to 0."""
Applies linear warmup of learning rate from 0 and decay to 0.
"""
with fluid.default_main_program()._lr_schedule_guard(): with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var( lr = fluid.layers.tensor.create_global_var(
shape=[1], shape=[1],
...@@ -78,6 +74,41 @@ def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps): ...@@ -78,6 +74,41 @@ def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
return lr return lr
def layer_decay(param, param_last, learning_rate, decay_rate, text_layers, n_layers):
""" layer_decay implementation """
delta = param - param_last
if "encoder_layer" in param.name and param.name.index("encoder_layer")==0:
layer = int(param.name.split("_")[2])
if layer >= text_layers:
cur_layer = text_layers + (layer - text_layers) * 2 + 1
else:
cur_layer = layer
ratio = decay_rate ** (n_layers - cur_layer)
print("text_layer_name:", param.name, "\t", "ratio:", ratio)
param_update = param + (ratio - 1) * delta
elif "encoder_vlayer" in param.name and param.name.index("encoder_vlayer")==0:
layer = int(param.name.split("_")[2])
cur_layer = text_layers + (layer) * 2 + 1
ratio = decay_rate ** (n_layers - cur_layer)
param_update = param + (ratio - 1) * delta
print("image_layer_name:", param.name, "\t", "ratio:", ratio)
elif "encoder_colayer" in param.name and param.name.index("encoder_colayer")==0:
layer = int(param.name.split("_")[2])
cur_layer = text_layers + (layer) * 2
ratio = decay_rate ** (n_layers - cur_layer)
param_update = param + (ratio - 1) * delta
print("co_layer_name:", param.name, "\t", "ratio:", ratio)
elif "embedding" in param.name:
ratio = decay_rate ** (n_layers + 1)
param_update = param + (ratio - 1) * delta
elif "image_emb" in param.name or "image_loc" in param.name:
ratio = decay_rate ** (n_layers - text_layers + 1)
param_update = param + (ratio - 1) * delta
else:
param_update = None
return param_update
def optimization(loss, def optimization(loss,
warmup_steps, warmup_steps,
num_train_steps, num_train_steps,
...@@ -88,10 +119,11 @@ def optimization(loss, ...@@ -88,10 +119,11 @@ def optimization(loss,
scheduler='linear_warmup_decay', scheduler='linear_warmup_decay',
decay_steps=[], decay_steps=[],
lr_decay_dict_file="", lr_decay_dict_file="",
lr_decay_ratio=0.1): lr_decay_ratio=0.1,
""" layer_decay_rate=0.0,
optimization implementation text_init_layers=18,
""" n_layers=30):
""" optimization implementation """
if warmup_steps > 0: if warmup_steps > 0:
if scheduler == 'noam_decay': if scheduler == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler \ scheduled_lr = fluid.layers.learning_rate_scheduler \
...@@ -135,8 +167,7 @@ def optimization(loss, ...@@ -135,8 +167,7 @@ def optimization(loss,
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)) clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
def exclude_from_weight_decay(name): def exclude_from_weight_decay(name):
""" """ parameters not use weight decay
Parameters not use weight decay
""" """
if name.find("layer_norm") > -1: if name.find("layer_norm") > -1:
return True return True
...@@ -154,6 +185,16 @@ def optimization(loss, ...@@ -154,6 +185,16 @@ def optimization(loss,
_, param_grads = optimizer.minimize(loss) _, param_grads = optimizer.minimize(loss)
if layer_decay_rate > 0:
for param, grad in param_grads:
with param.block.program._optimized_guard(
[param, grad]), fluid.framework.name_scope("layer_decay"):
param_decay = layer_decay(param, param_list[param.name], scheduled_lr,
layer_decay_rate, text_init_layers, n_layers)
if param_decay:
fluid.layers.assign(output=param, input=param_decay)
if weight_decay > 0: if weight_decay > 0:
for param, grad in param_grads: for param, grad in param_grads:
if exclude_from_weight_decay(param.name): if exclude_from_weight_decay(param.name):
......
...@@ -17,7 +17,6 @@ import sys ...@@ -17,7 +17,6 @@ import sys
import os import os
import base64 import base64
import numpy as np import numpy as np
reload(sys) reload(sys)
sys.setdefaultencoding("utf-8") sys.setdefaultencoding("utf-8")
...@@ -25,7 +24,7 @@ from preprocess import tokenization ...@@ -25,7 +24,7 @@ from preprocess import tokenization
class PreprocessorBasic(object): class PreprocessorBasic(object):
""" """
Main class for text preprocess parent class for preprocess
""" """
def __init__(self, def __init__(self,
tokenizer_name, tokenizer_name,
...@@ -39,7 +38,7 @@ class PreprocessorBasic(object): ...@@ -39,7 +38,7 @@ class PreprocessorBasic(object):
def convert_sentence_to_ids_without_cls(self, sentence): def convert_sentence_to_ids_without_cls(self, sentence):
""" """
Convert sentence to ids without cls convert sentence to ids without cls
""" """
tokens = self.tokenizer.tokenize(sentence) tokens = self.tokenizer.tokenize(sentence)
ids = self.tokenizer.convert_tokens_to_ids(tokens) ids = self.tokenizer.convert_tokens_to_ids(tokens)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VQA Data Reader implementation """
from __future__ import print_function
from __future__ import division
import os
import base64
import functools
import numpy as np
import types
import gzip
import logging
import re
import six
import collections
import copy
import random
import pickle
import paddle
import paddle.fluid as fluid
from batching.finetune_batching import prepare_flickr_data
from preprocess import preprocessor
class FlickrDataReader(object):
"""
data reader task for flickr
"""
def __init__(self,
task_group,
vocab_path,
split,
batch_size=4096,
max_seq_len=512,
shuffle_files=True,
epoch=100,
voc_size=0,
cls_size=0,
is_test=False):
self.vocab = self.load_vocab(vocab_path)
self.task_group = task_group
self.max_seq_len = max_seq_len
self.processor = getattr(preprocessor, task_group[0]["Proprocessor"])(
tokenizer_name =self.task_group[0]["tokenizer_name"],
vocab_path = vocab_path)
self.batch_size = batch_size
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.current_file = None
self.voc_size = voc_size
self.cls_size = cls_size
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
if self.is_test:
self.epoch = 1
self.shuffle_files = False
self._test_image_list = []
if split == "dev":
image_path = self.task_group[0]["dev_image_path"]
else:
image_path = self.task_group[0]["test_image_path"]
else:
caption_path = self.task_group[0]["train_caption_path"]
self._load_caption_dict(caption_path)
image_path = self.task_group[0]["train_image_path"]
self._get_hardest_setting(self.task_group[0]["hardest_setting_path"])
self._negative_schema=self.task_group[0]["negative_schema"]
self._load_image_dict(image_path)
def decode_all(self, image_id, width, height, number_box, boxes, image_embeddings):
""" decode all data """
def decode_feature(base64_str, size):
""" decode feature from base64 """
size = int(size)
fea_base64 = base64.b64decode(base64_str)
fea_decode = np.frombuffer(fea_base64, dtype=np.float32)
shape = size, int(fea_decode.shape[0] / size)
features = np.resize(fea_decode, shape)
return features
image_embeddings = decode_feature(image_embeddings, number_box)
image_embeddings_cls = np.mean(image_embeddings, axis = 0, keepdims = True)
image_embeddings = np.concatenate([image_embeddings_cls, image_embeddings], 0)
boxes = decode_feature(boxes, number_box)
shape = np.repeat(np.array([[float(width), float(height), float(width), float(height)]]), \
number_box, axis=0)
boxes = boxes / shape
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
image_loc = np.concatenate((boxes, np.expand_dims(area, 1)), axis = 1)
loc_cls = np.array([[0.0, 0.0, 1.0, 1.0, 1.0]], dtype = "float32")
image_loc = np.concatenate([loc_cls, image_loc], 0)
return int(number_box) + 1, image_loc, image_embeddings
def _load_image_dict(self, image_path):
self._image_feature_dict = {}
image_items = image_path.split(',')
cnt = 0
for image_item in image_items:
with open(image_item) as f:
for line in f:
cnt += 1
if cnt % 1000 == 0:
print('precessing image feature:', cnt)
image_id, width, height, number_box, image_loc, image_embeddings \
= line.strip().split('\t')
number_box, image_loc, image_embeddings = self.decode_all( \
image_id, width, height, number_box, image_loc, image_embeddings)
self._image_feature_dict[int(image_id)] = (width, height, image_embeddings, number_box, image_loc)
if self.is_test:
self._test_image_list.append(int(image_id))
def _load_caption_dict(self, image_caption):
"""
Load caption dict for flickr
"""
self._caption_ids_dict = {}
self._image_sent_map = {}
with open(image_caption) as f:
cnt = 0
for line in f:
cnt += 1
line = line.strip().split("\t")
image_id, sent_id, text = line
token_ids = []
raw_ids = self.processor.convert_sentence_to_ids_without_cls(text)
token_ids.append(self.vocab["[CLS]"])
token_ids.extend(raw_ids)
token_ids.append(self.vocab["[SEP]"])
sent_ids = [0] * len(token_ids)
pos_ids = range(0, len(token_ids))
if cnt % 5000 == 0:
print(cnt)
if len(token_ids) > self.max_seq_len:
token_ids = token_ids[0: self.max_seq_len - 1] + [token_ids[-1]]
sent_ids = sent_ids[0: self.max_seq_len - 1] + [sent_ids[-1]]
pos_ids = pos_ids[0: self.max_seq_len]
assert len(token_ids) == len(sent_ids) == len(pos_ids), \
"[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)"
self._caption_ids_dict[int(sent_id)] = \
[token_ids, sent_ids, pos_ids, int(image_id)]
self._image_sent_map.setdefault(int(image_id), [])
self._image_sent_map[int(image_id)].append(int(sent_id))
self._train_caption_ids = self._caption_ids_dict.keys()
def _get_hardest_setting(self, hardest_setting_path):
"""
Get the training metrix
"""
with open(hardest_setting_path, 'rb') as f:
data = pickle.load(f)
self._train_hard_pool = data['train_hard_pool']
self._train_image_list = data['train_image_list']
self._train_imgId2pool = {imageId:i for i, imageId in enumerate(self._train_image_list)}
def get_progress(self):
"""
Return current progress of traning data
"""
progress_dict = {"current_epoch": self.current_epoch,
"current_file_index": self.current_file_index,
"total_file": self.total_file,
"current_file": self.current_file
}
return progress_dict
def process_vl(self, line, max_seq_len):
"""
Process single v+l data
"""
if self.is_test:
line = line.strip().split("\t")
image_id, sent_id, text = line
token_ids = []
raw_ids = self.processor.convert_sentence_to_ids_without_cls(text)
token_ids.append(self.vocab["[CLS]"])
token_ids.extend(raw_ids)
token_ids.append(self.vocab["[SEP]"])
sent_ids = [0] * len(token_ids)
pos_ids = range(0, len(token_ids))
if len(token_ids) > self.max_seq_len:
token_ids = token_ids[0: self.max_seq_len - 1] + [token_ids[-1]]
sent_ids = sent_ids[0: self.max_seq_len - 1] + [sent_ids[-1]]
pos_ids = pos_ids[0: self.max_seq_len]
width, height, image_embeddings, number_box, image_loc = self._image_feature_dict[int(image_id)]
else:
sent_id = line
captions_pos = self._caption_ids_dict[sent_id]
image_id = captions_pos[-1]
captions = [captions_pos]
_, _, features, number_box, box = self._image_feature_dict[image_id]
images = [[features, number_box, box]]
for item in self._negative_schema:
if item[0] == "h":
rand_img_id_pool = self._train_hard_pool[self._train_imgId2pool[image_id]]
rand_idx = rand_img_id_pool[random.randint(1, len(rand_img_id_pool) - 1)]
image_id_neg = self._train_image_list[int(rand_idx)]
elif item[0] == "e":
while True:
image_id_neg = random.choice(self._train_image_list)
if image_id_neg != image_id:
break
else:
print("error negative schema")
exit()
if item[1] == "i":
_, _, features_neg, number_box_neg, box_neg = self._image_feature_dict[image_id_neg]
captions.append(self._caption_ids_dict[sent_id])
images.append([features_neg, number_box_neg, box_neg])
elif item[1] == "c":
sent_id_neg = random.choice(self._image_sent_map[image_id_neg])
captions.append(self._caption_ids_dict[sent_id_neg])
images.append([features, number_box, box])
else:
print("error negative schema")
exit()
token_ids, sent_ids, pos_ids, _ = zip(*captions)
image_embeddings, number_box, image_loc = zip(*images)
sample_json = {
"token_ids": token_ids,
"sent_ids": sent_ids,
"pos_ids": pos_ids,
"image_loc": image_loc,
"image_embeddings": image_embeddings,
"image_id": int(image_id),
"sent_id": int(sent_id),
"ids": [image_id, sent_id]
}
return sample_json
def parse_line(self, line, max_seq_len=512, task_index=None):
""" parse one line to token_ids, sentence_ids, pos_ids, label """
sample_json = self.process_vl(line, max_seq_len)
token_ids = sample_json["token_ids"]
return sample_json
def read_file(self, file, task_index):
"""
read line data from file
"""
if self.is_test:
with open(file) as f:
lines = f.readlines()
for line in lines:
yield line
else:
random.shuffle(self._train_caption_ids)
for item in self._train_caption_ids:
yield item
def convert_to_unicode(self, text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = self.convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def data_generator(self):
""" data_generator """
filelist_key = "train_filelist"
if self.is_test:
filelist_key = "dev_filelist"
all_files = []
task_probs = []
sum = 0.0
for task in self.task_group:
all_files.append(open(task[filelist_key]).readlines())
task_probs.append(task["prob"])
sum += task["prob"]
for i in xrange(len(task_probs)):
task_probs[i] = task_probs[i] / sum
task_probs = np.array(task_probs).ravel()
def wrapper():
""" wrapper """
def reader(task_index):
""" reader """
files = all_files[task_index]
for epoch in range(self.epoch):
if self.shuffle_files:
np.random.shuffle(files)
for index, file in enumerate(files):
file = file.strip()
sample_generator = paddle.reader.xmap_readers(self.parse_line, \
functools.partial(self.read_file, file=file, task_index=task_index), 8, 2000)
for sample in sample_generator():
if not self.is_test:
self.current_epoch = epoch + 1
self.current_file_index = index + 1
self.current_file = file
self.total_file = len(files)
yield sample
else:
cap_id = sample["ids"][1]
for image_id in self._test_image_list:
line_json = copy.deepcopy(sample)
_, _, image_embeddings, number_box, image_loc = self._image_feature_dict[image_id]
line_json["image_embeddings"] = image_embeddings
line_json["image_loc"] = image_loc
line_json["ids"][0] = image_id
yield line_json
def batch_reader(reader, batch_size):
""" batch reader """
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_count = 1
buff = []
readers = []
for i in xrange(len(task_probs)):
buff.append(None)
readers.append(reader(i))
task_indices = range(len(task_probs))
end_times = 0
while end_times < 50:
task_index = np.random.choice(task_indices, p=task_probs)
dev_num = 0
cur_reader = readers[task_index]
while dev_num < dev_count:
if buff[task_index] is not None:
cur_len = len(buff[task_index]["token_ids"])
max_len = max(max_len, cur_len)
batch.append(buff[task_index])
total_token_num += cur_len
buff[task_index] = None
cur_size += 1
parsed_line = next(cur_reader, None)
if parsed_line is None:
end_times += 1
dev_num += 1
if len(batch) > 0:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
continue
end_times = 0
cur_len = len(parsed_line["token_ids"])
max_len = max(max_len, cur_len)
if cur_size >= batch_size:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_num += 1
buff[task_index] = parsed_line
else:
batch.append(parsed_line)
cur_size += 1
total_token_num += cur_len
for batch_data, total_token_num, task_index in batch_reader(reader, self.batch_size):
if self.is_test:
outs = 1
else:
outs = len(self._negative_schema)+1
yield prepare_flickr_data(
batch_data,
total_token_num,
task_index,
len(self.task_group),
voc_size=self.voc_size,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=self.mask_id,
outs=outs,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
if __name__ == "__main__":
pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RefcocoPlus DataReader implementation """
from __future__ import print_function
from __future__ import division
import os
import base64
import numpy as np
import types
import gzip
import logging
import re
import six
import collections
import random
import paddle
import paddle.fluid as fluid
from batching.finetune_batching import prepare_refcoco_plus_batch_data
from preprocess import preprocessor
class RefcocoPlusDataReader(object):
"""
data reader task for refcoco plus
"""
def __init__(self,
task_group,
split,
vocab_path,
batch_size=4096,
max_seq_len=512,
shuffle_files=True,
epoch=100,
voc_size=0,
is_test=False):
self.vocab = self.load_vocab(vocab_path)
self.task_group = task_group
self.processor = getattr(preprocessor, task_group[0]["Proprocessor"])(
tokenizer_name =self.task_group[0]["tokenizer_name"],
vocab_path = vocab_path)
self.batch_size = batch_size
self.shuffle_files = shuffle_files
self.epoch = epoch
self.split = split
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.current_file = None
self.voc_size = voc_size
self.max_seq_len = max_seq_len
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.input_slots = 9
self.is_test = is_test
if is_test:
self.epoch = 1
self.shuffle_files = False
def get_progress(self):
"""
return current progress of traning data
"""
self.progress_dict = {"current_epoch": self.current_epoch,
"current_file_index": self.current_file_index,
"total_file": self.total_file,
"current_file": self.current_file
}
return self.progress_dict
def process_vl(self, line, max_seq_len):
"""
process single v+l data
"""
def decode_feature(base64_str, size):
"""
decode feature from base64
"""
fea_base64 = base64.b64decode(base64_str)
fea_decode = np.frombuffer(fea_base64, dtype=np.float32)
shape = size, int(fea_decode.shape[0] / size)
features = np.resize(fea_decode, shape)
return features
text, image_w, image_h, number_boxes, number_boxes_gl, image_loc, \
image_embeddings, box_label, label = line
token_ids = []
raw_ids = self.processor.convert_sentence_to_ids_without_cls(text)
token_ids.append(self.vocab["[CLS]"])
token_ids.extend(raw_ids)
token_ids.append(self.vocab["[SEP]"])
sent_ids = [0] * len(token_ids)
pos_ids = range(0, len(token_ids))
#print("sent_ids:", sent_ids)
token_ids = [int(token) for token in token_ids]
sent_ids = [int(token) for token in sent_ids]
pos_ids = [int(token) for token in pos_ids]
assert len(token_ids) == len(sent_ids) == len(pos_ids), \
"[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids)"
if len(token_ids) > self.max_seq_len:
token_ids = token_ids[0: self.max_seq_len - 1] + [token_ids[-1]]
sent_ids = sent_ids[0: self.max_seq_len - 1] + [sent_ids[-1]]
pos_ids = pos_ids[0: self.max_seq_len]
all_number_box = int(number_boxes) + int(number_boxes_gl)
image_loc = decode_feature(image_loc, all_number_box)
shape_np = np.repeat(np.array(\
[[float(image_w), float(image_h), float(image_w), float(image_h)]]), all_number_box, axis=0)
boxes_np = image_loc / shape_np
area = (boxes_np[:, 3] - boxes_np[:, 1]) * (boxes_np[:, 2] - boxes_np[:, 0])
image_loc = np.concatenate((boxes_np, np.expand_dims(area, 1)), axis = 1)
loc_cls = np.array([[0.0, 0.0, 1.0, 1.0, 1.0]], dtype = "float32")
image_loc = np.concatenate([loc_cls, image_loc], 0)
image_embeddings = decode_feature(image_embeddings, all_number_box)
image_embeddings_cls = np.mean(image_embeddings, axis = 0, keepdims = True)
image_embeddings = np.concatenate([image_embeddings_cls, image_embeddings], 0)
x1, y1, x2, y2 = [float(item) for item in box_label.split(" ")]
cls_label = (x2 - x1 + 1) * (y2 - y1 + 1) /(float(image_w) * float(image_h))
score_th = 0.5
if cls_label < score_th:
cls_label = 0.0
label_tmp = label.split(" ")
if not self.is_test:
for i in range(len(label_tmp)):
if float(label_tmp[i]) < score_th:
label_tmp[i] = 0.0
label = [[cls_label]] + [[float(token)] for token in label_tmp]
label = np.array(label, dtype="float32")
add_item = [all_number_box + 1, image_w, image_h] + [float(item) for item in box_label.split(" ")]
sample_json = {
"token_ids": token_ids,
"sent_ids": sent_ids,
"pos_ids": pos_ids,
"label": label,
"image_loc": image_loc,
"image_embeddings": image_embeddings,
"all_number_box": all_number_box,
"add_item": add_item
}
return sample_json
def parse_line(self, line, max_seq_len=512, task_index=None):
""" parse one line to token_ids, sentence_ids, pos_ids, label """
line = line.strip().split("\t")
assert len(line) == self.input_slots, "One sample must have %d fields!" % self.input_slots
sample_json = self.process_vl(line, max_seq_len)
token_ids = sample_json["token_ids"]
return sample_json
def read_file(self, file, task_index):
""" read line data from a file """
try:
assert file.endswith('.gz'), "[ERROR] %s is not a gzip file" % file
with gzip.open(file, "rb") as f:
lines = f.readlines()
except:
with open(file, "rb") as f:
lines = f.readlines()
if not self.is_test:
np.random.shuffle(lines)
for line in lines:
parsed_line = self.parse_line(
line, max_seq_len=self.max_seq_len, task_index=task_index)
if parsed_line is None:
continue
yield parsed_line
def convert_to_unicode(self, text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = self.convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def data_generator(self):
""" data_generator """
if self.split == "train":
filelist_key = "train_filelist"
elif self.split == "val":
filelist_key = "val_filelist"
elif self.split == "testA":
filelist_key = "testA_filelist"
else: filelist_key = "testB_filelist"
all_files = []
task_probs = []
sum = 0.0
for task in self.task_group:
all_files.append(open(task[filelist_key]).readlines())
task_probs.append(task["prob"])
sum += task["prob"]
for i in xrange(len(task_probs)):
task_probs[i] = task_probs[i] / sum
task_probs = np.array(task_probs).ravel()
def wrapper():
"""
wrapper
"""
def reader(task_index):
"""
reader
"""
files = all_files[task_index]
for epoch in range(self.epoch):
if self.shuffle_files:
if epoch < 0:
files = files + open(task["gt_train_filelist"]).readlines()
np.random.shuffle(files)
for index, file in enumerate(files):
file = file.strip()
sample_generator = self.read_file(file, task_index)
for sample in sample_generator:
self.current_epoch = epoch + 1
self.current_file_index = index + 1
self.current_file = file
self.total_file = len(files)
if sample is None:
continue
yield sample
def batch_reader(reader, batch_size):
"""
batch reader
"""
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_count = 1
buff = []
readers = []
for i in xrange(len(task_probs)):
buff.append(None)
readers.append(reader(i))
task_indices = range(len(task_probs))
end_times = 0
while end_times < 50:
task_index = np.random.choice(task_indices, p=task_probs)
dev_num = 0
cur_reader = readers[task_index]
while dev_num < dev_count:
if buff[task_index] is not None:
cur_len = len(buff[task_index]["token_ids"])
max_len = max(max_len, cur_len)
batch.append(buff[task_index])
total_token_num += cur_len
buff[task_index] = None
cur_size += 1
parsed_line = next(cur_reader, None)
if parsed_line is None:
end_times += 1
dev_num += 1
if len(batch) > 0:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
continue
end_times = 0
cur_len = len(parsed_line["token_ids"])
max_len = max(max_len, cur_len)
if cur_size >= batch_size:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_num += 1
buff[task_index] = parsed_line
else:
batch.append(parsed_line)
cur_size += 1
total_token_num += cur_len
for batch_data, total_token_num, task_index in batch_reader(reader, self.batch_size):
yield prepare_refcoco_plus_batch_data(
batch_data,
total_token_num,
task_index,
len(self.task_group),
voc_size=self.voc_size,
pad_id=self.pad_id,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
if __name__ == "__main__":
pass
...@@ -33,9 +33,7 @@ from batching.finetune_batching import prepare_batch_data ...@@ -33,9 +33,7 @@ from batching.finetune_batching import prepare_batch_data
import paddle.fluid as fluid import paddle.fluid as fluid
def _converId(img_id): def _converId(img_id):
""" """ conversion for image ID """
conversion for image ID
"""
img_id = img_id.split('-') img_id = img_id.split('-')
if 'train' in img_id[0]: if 'train' in img_id[0]:
new_id = int(img_id[1]) new_id = int(img_id[1])
...@@ -49,9 +47,7 @@ def _converId(img_id): ...@@ -49,9 +47,7 @@ def _converId(img_id):
def _load_annotationsQ_A(annotations_jsonpath, split): def _load_annotationsQ_A(annotations_jsonpath, split):
""" """Build an index out of FOIL annotations, mapping each image ID with its corresponding captions."""
Build an index out of FOIL annotations, mapping each image ID with its corresponding captions.
"""
entries = [] entries = []
with open(annotations_jsonpath) as f: with open(annotations_jsonpath) as f:
for annotation in json_lines.reader(f): for annotation in json_lines.reader(f):
...@@ -76,9 +72,7 @@ def _load_annotationsQ_A(annotations_jsonpath, split): ...@@ -76,9 +72,7 @@ def _load_annotationsQ_A(annotations_jsonpath, split):
def _load_annotationsQA_R(annotations_jsonpath, split): def _load_annotationsQA_R(annotations_jsonpath, split):
""" """Build an index out of FOIL annotations, mapping each image ID with its corresponding captions."""
Build an index out of FOIL annotations, mapping each image ID with its corresponding captions.
"""
entries = [] entries = []
with open(annotations_jsonpath, 'rb') as f: with open(annotations_jsonpath, 'rb') as f:
for annotation in json_lines.reader(f): for annotation in json_lines.reader(f):
...@@ -117,7 +111,7 @@ def _load_annotationsQA_R(annotations_jsonpath, split): ...@@ -117,7 +111,7 @@ def _load_annotationsQA_R(annotations_jsonpath, split):
class VCRDataReader(object): class VCRDataReader(object):
""" """
Data reader for sub VCR task data reader task for vcr
""" """
def __init__(self, def __init__(self,
task_conf, task_conf,
...@@ -193,7 +187,7 @@ class VCRDataReader(object): ...@@ -193,7 +187,7 @@ class VCRDataReader(object):
def generate_random_name(self, det_names): def generate_random_name(self, det_names):
""" """
Replace "person" with a random name replace "person" with a random name
""" """
random_name = [] random_name = []
for name in det_names: for name in det_names:
...@@ -207,7 +201,7 @@ class VCRDataReader(object): ...@@ -207,7 +201,7 @@ class VCRDataReader(object):
def replace_det_with_name(self, inputs, random_names): def replace_det_with_name(self, inputs, random_names):
""" """
Replace det with name replace det with name
""" """
tokens = [] tokens = []
mask = [] mask = []
...@@ -224,7 +218,7 @@ class VCRDataReader(object): ...@@ -224,7 +218,7 @@ class VCRDataReader(object):
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
""" """
Truncates a sequence pair in place to the maximum length. Truncates a sequence pair in place to the maximum length.
""" """
while True: while True:
total_length = len(tokens_a) + len(tokens_b) total_length = len(tokens_a) + len(tokens_b)
...@@ -237,7 +231,7 @@ class VCRDataReader(object): ...@@ -237,7 +231,7 @@ class VCRDataReader(object):
def get_progress(self): def get_progress(self):
""" """
Return current progress of traning data return current progress of traning data
""" """
progress_dict = {"current_epoch": self.current_epoch, progress_dict = {"current_epoch": self.current_epoch,
"current_file_index": self.current_file_index, "current_file_index": self.current_file_index,
...@@ -248,7 +242,7 @@ class VCRDataReader(object): ...@@ -248,7 +242,7 @@ class VCRDataReader(object):
def tokenize(self): def tokenize(self):
""" """
Tokenizes the captions. Tokenizes the captions.
""" """
# This will add caption_tokens in each entry of the dataset. # This will add caption_tokens in each entry of the dataset.
# -1 represents nil, and should be treated as padding_idx in embedding. # -1 represents nil, and should be treated as padding_idx in embedding.
...@@ -312,7 +306,7 @@ class VCRDataReader(object): ...@@ -312,7 +306,7 @@ class VCRDataReader(object):
def parse_line(self, s_index): def parse_line(self, s_index):
""" """
Form slot info with the line information form the slot info from line
""" """
entry = self._entries[s_index] entry = self._entries[s_index]
image_id = entry["img_id"] image_id = entry["img_id"]
...@@ -367,13 +361,11 @@ class VCRDataReader(object): ...@@ -367,13 +361,11 @@ class VCRDataReader(object):
return record return record
def data_generator(self): def data_generator(self):
""" """ data_generator """
Data_generator
"""
sample_indice = range(len(self._entries)) sample_indice = range(len(self._entries))
def wrapper(): def wrapper():
""" """
Wrapper wrapper
""" """
for epoch_index in range(self.epoch): for epoch_index in range(self.epoch):
if self._split == "train": if self._split == "train":
...@@ -402,9 +394,7 @@ class VCRDataReader(object): ...@@ -402,9 +394,7 @@ class VCRDataReader(object):
class VCRDataJointReader(object): class VCRDataJointReader(object):
""" """ Joint data reader for Q2A task and QA2R task"""
Joint data reader for Q2A task and QA2R task
"""
def __init__(self, def __init__(self,
task_conf_group, task_conf_group,
split, split,
...@@ -435,8 +425,7 @@ class VCRDataJointReader(object): ...@@ -435,8 +425,7 @@ class VCRDataJointReader(object):
self.task_generators = [reader.data_generator() for reader in self.task_readers] self.task_generators = [reader.data_generator() for reader in self.task_readers]
def get_progress(self): def get_progress(self):
""" """return current progress of traning data
Return current progress of traning data
""" """
current_epoch = max([reader.current_epoch for reader in self.task_readers]) current_epoch = max([reader.current_epoch for reader in self.task_readers])
current_file_index = max([reader.current_file_index for reader in self.task_readers]) current_file_index = max([reader.current_file_index for reader in self.task_readers])
...@@ -450,9 +439,7 @@ class VCRDataJointReader(object): ...@@ -450,9 +439,7 @@ class VCRDataJointReader(object):
return self.progress_dict return self.progress_dict
def data_generator(self): def data_generator(self):
""" """ data_generator """
Data_generator
"""
def wrapper(): def wrapper():
""" """
warpper warpper
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VQA Data Reader implementation """
from __future__ import print_function
from __future__ import division
import os
import base64
import functools
import numpy as np
import types
import gzip
import logging
import re
import six
import collections
import random
import paddle
import paddle.fluid as fluid
from batching.finetune_batching import prepare_vqa_batch_data
from preprocess import preprocessor
class VQADataReader(object):
"""
data reader task for vqa
"""
def __init__(self,
task_group,
split,
vocab_path,
batch_size=4096,
num_class=3129,
max_seq_len=512,
shuffle_files=True,
epoch=100,
voc_size=0,
cls_size=0,
is_test=False):
self.vocab = self.load_vocab(vocab_path)
self.task_group = task_group
self.processor = getattr(preprocessor, task_group[0]["Proprocessor"])(
tokenizer_name =self.task_group[0]["tokenizer_name"],
vocab_path = vocab_path)
self.batch_size = batch_size
self.shuffle_files = shuffle_files
self.epoch = epoch
self.current_epoch = 0
self.current_file_index = 0
self.total_file = 0
self.num_class=num_class
self.current_file = None
self.voc_size = voc_size
self.cls_size = cls_size
self.max_seq_len = max_seq_len
self.pad_id = self.vocab["[PAD]"]
self.cls_id = self.vocab["[CLS]"]
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.is_test = is_test
self.split = split
if self.is_test:
self.epoch = 1
self.shuffle_files = False
self.vg_init_epochs = 0
else:
self.vg_init_epochs = int(self.task_group[0]["vg_init_epochs"])
def get_progress(self):
"""
return current progress of traning data
"""
self.progress_dict = {"current_epoch": self.current_epoch,
"current_file_index": self.current_file_index,
"total_file": self.total_file,
"current_file": self.current_file
}
return self.progress_dict
def process_vl(self, line, max_seq_len):
"""
trans the orgin tokens to the wanted tokens
"""
def decode_feature(base64_str, size):
"""
decode feature from base64
"""
fea_base64 = base64.b64decode(base64_str)
fea_decode = np.frombuffer(fea_base64, dtype=np.float32)
shape = size, int(fea_decode.shape[0] / size)
features = np.resize(fea_decode, shape)
return features
question_id, text, match_label, score, image_w, image_h, number_box, \
image_loc, image_embeddings = line
token_ids = []
raw_ids = self.processor.convert_sentence_to_ids_without_cls(text)
token_ids.append(self.vocab["[CLS]"])
token_ids.extend(raw_ids)
token_ids.append(self.vocab["[SEP]"])
sent_ids = [0] * len(token_ids)
pos_ids = range(0, len(token_ids))
token_ids = [int(token) for token in token_ids]
sent_ids = [int(token) for token in sent_ids]
pos_ids = [int(token) for token in pos_ids]
if len(token_ids) > self.max_seq_len:
token_ids = token_ids[0: self.max_seq_len - 1] + [token_ids[-1]]
sent_ids = sent_ids[0: self.max_seq_len - 1] + [sent_ids[-1]]
pos_ids = pos_ids[0: self.max_seq_len]
labels = [int(label_tok) for label_tok in match_label.split("|")]
scores = [float(score_tok) for score_tok in score.split("|")]
number_box = int(number_box)
question_id = int(question_id)
image_loc = decode_feature(image_loc, number_box)
shape_np = np.repeat(np.array(\
[[float(image_w), float(image_h), float(image_w), float(image_h)]]), number_box, axis=0)
boxes_np = image_loc / shape_np
area = (boxes_np[:, 3] - boxes_np[:, 1]) * (boxes_np[:, 2] - boxes_np[:, 0])
image_loc = np.concatenate((boxes_np, np.expand_dims(area, 1)), axis = 1)
loc_cls = np.array([[0.0, 0.0, 1.0, 1.0, 1.0]], dtype = "float32")
image_loc = np.concatenate([loc_cls, image_loc], 0)
try:
image_embeddings = decode_feature(image_embeddings, number_box)
image_embeddings_cls = np.mean(image_embeddings, axis = 0, keepdims = True)
image_embeddings = np.concatenate([image_embeddings_cls, image_embeddings], 0)
self.default_image_emb = image_embeddings
except:
print("error data occur, a random default image emb will be assin to this one")
print("the wrong line occur")
image_embeddings = self.default_image_emb
weight_labels = self.get_weight_label(self.num_class, labels, scores)
sample_json = {
"question_id": question_id,
"token_ids": token_ids,
"sent_ids": sent_ids,
"pos_ids": pos_ids,
"weight_labels": weight_labels,
"image_loc": image_loc,
"image_embeddings": image_embeddings,
}
return sample_json
def get_weight_label(self, num_class, labels, scores):
"""assign the corresponding score for the labels
Input: labels (Indefinite length list, like [1, 2, 3])
scores (Indefinite length list, like [0.1, 0.2, 0.3])
Output: weight_score (list, length equals num_class)
"""
assert len(labels) == len(scores), \
"unequals length with labels has %d number(s) while scores has %d number(s)!" % (len(labels), len(scores))
weight_score = [0] * num_class
for i in range(len(labels)):
weight_score[labels[i]] = scores[i]
return weight_score
def parse_line(self, line, max_seq_len=512, task_index=None):
""" parse one line to token_ids, sentence_ids, pos_ids, label """
line = line.strip().split("\t")
sample_json = self.process_vl(line, max_seq_len)
return sample_json
def read_file(self, file, task_index):
""" read line data from a file """
with open(file, "rb") as f:
lines = f.readlines()
if not self.is_test:
np.random.shuffle(lines)
for line in lines:
yield line
def convert_to_unicode(self, text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(self, vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = self.convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def data_generator(self):
""" data_generator """
filelist_key = "train_filelist"
if self.is_test:
if self.split == "val":
filelist_key = "val_filelist"
elif self.split == "test_dev":
filelist_key = "test_dev_filelist"
elif self.split == "test_std":
filelist_key = "test_std_filelist"
else:
print("*************no split named as :", self.split, "********************")
return None
all_files = []
task_probs = []
sum = 0.0
for task in self.task_group:
all_files.append(open(task[filelist_key]).readlines())
task_probs.append(task["prob"])
sum += task["prob"]
for i in xrange(len(task_probs)):
task_probs[i] = task_probs[i] / sum
task_probs = np.array(task_probs).ravel()
def wrapper():
"""
warpper
"""
def reader(task_index):
"""
reader
"""
files = all_files[task_index]
global_rng = np.random.RandomState(0)
for epoch in range(self.epoch):
if epoch < self.vg_init_epochs:
files = open(task["vg_train_filelist"]).readlines() + all_files[task_index]
if self.shuffle_files:
global_rng.shuffle(files)
for index, file in enumerate(files):
file = file.strip()
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
try:
trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM"))
except:
print("can not get env PADDLE_TRAINERS_NUM, set trainer_nums to 1")
trainers_num = 1
if index % trainers_num != trainer_id:
continue
sample_generator = paddle.reader.xmap_readers(self.parse_line, \
functools.partial(self.read_file, file=file, task_index=task_index), 4, 200)
for sample in sample_generator():
self.current_epoch = epoch + 1
self.current_file_index = index + 1
self.current_file = file
self.total_file = len(files)
if sample is None:
continue
yield sample
def batch_reader(reader, batch_size):
"""
Batch data reader
"""
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_count = 1
buff = []
readers = []
for i in xrange(len(task_probs)):
buff.append(None)
readers.append(reader(i))
task_indices = range(len(task_probs))
end_times = 0
while end_times < 50:
task_index = np.random.choice(task_indices, p=task_probs)
dev_num = 0
cur_reader = readers[task_index]
while dev_num < dev_count:
if buff[task_index] is not None:
cur_len = len(buff[task_index]["token_ids"])
max_len = max(max_len, cur_len)
batch.append(buff[task_index])
total_token_num += cur_len
buff[task_index] = None
cur_size += 1
parsed_line = next(cur_reader, None)
if parsed_line is None:
end_times += 1
dev_num += 1
if len(batch) > 0:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
continue
end_times = 0
cur_len = len(parsed_line["token_ids"])
max_len = max(max_len, cur_len)
if cur_size >= batch_size:
yield batch, total_token_num, task_index
batch, total_token_num, max_len = [], 0, 0
cur_size = 0
dev_num += 1
buff[task_index] = parsed_line
else:
batch.append(parsed_line)
cur_size += 1
total_token_num += cur_len
for batch_data, total_token_num, task_index in batch_reader(reader, self.batch_size):
yield prepare_vqa_batch_data(
batch_data,
total_token_num,
task_index,
len(self.task_group),
voc_size=self.voc_size,
pad_id=self.pad_id,
cls_id=self.cls_id,
sep_id=self.sep_id,
mask_id=self.mask_id,
return_input_mask=True,
return_max_len=False,
return_num_token=False)
return wrapper
if __name__ == "__main__":
pass
...@@ -13,6 +13,8 @@ source $CONF_FILE ...@@ -13,6 +13,8 @@ source $CONF_FILE
#configure your cuda and cudnn #configure your cuda and cudnn
#configure nccl #configure nccl
#export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:/home/work/cudnn/cudnn_v7/cuda/lib64:$LD_LIBRARY_PATH
#export LD_LIBRARY_PATH=./nccl_2.3.5/lib/:$LD_LIBRARY_PATH
export FLAGS_fast_eager_deletion_mode=1 export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
...@@ -44,6 +46,10 @@ python finetune.py --use_cuda "True" \ ...@@ -44,6 +46,10 @@ python finetune.py --use_cuda "True" \
--lr_scheduler ${lr_scheduler} \ --lr_scheduler ${lr_scheduler} \
--decay_steps ${decay_steps-""} \ --decay_steps ${decay_steps-""} \
--lr_decay_ratio ${lr_decay_ratio-0.1} \ --lr_decay_ratio ${lr_decay_ratio-0.1} \
--layer_decay_rate ${layer_decay_rate-0.0} \
--text_init_layers ${text_init_layers-18} \
--n_layers ${n_layers-30} \
--margin ${margin-0.3} \
--num_train_steps ${num_train_steps} \ --num_train_steps ${num_train_steps} \
--checkpoints $output_model_path \ --checkpoints $output_model_path \
--save_steps ${SAVE_STEPS} \ --save_steps ${SAVE_STEPS} \
...@@ -53,7 +59,6 @@ python finetune.py --use_cuda "True" \ ...@@ -53,7 +59,6 @@ python finetune.py --use_cuda "True" \
--warmup_steps ${WARMUP_STEPS} \ --warmup_steps ${WARMUP_STEPS} \
--weight_decay ${WEIGHT_DECAY:-0} \ --weight_decay ${WEIGHT_DECAY:-0} \
--max_seq_len ${MAX_LEN} \ --max_seq_len ${MAX_LEN} \
--validation_steps ${VALID_STEPS} \
--skip_steps 10 --skip_steps 10
...@@ -13,6 +13,8 @@ RES_FILE=$8 ...@@ -13,6 +13,8 @@ RES_FILE=$8
source $CONF_FILE source $CONF_FILE
#export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:/home/work/cudnn/cudnn_v7/cuda/lib64:$LD_LIBRARY_PATH
#export LD_LIBRARY_PATH=./nccl_2.3.5/lib/:$LD_LIBRARY_PATH
#configure your cuda and cudnn #configure your cuda and cudnn
#configure nccl #configure nccl
......
import sys
ans_dict = {}
text_ans_dict = {}
filename = './data/flickr/flickr.dev.data'
with open(filename) as f:
for line in f:
line = line.strip().split('\t')
image_id, sent_id = line[0], line[1]
ans_dict[sent_id.strip(' ')] = image_id.strip(' ')
text_ans_dict.setdefault(image_id.strip(' '), [])
text_ans_dict[image_id.strip(' ')].append(sent_id.strip(' '))
if len(sys.argv) > 1:
res_file = sys.argv[1]
else:
res_file = "./result"
print ('=============== IMAGE RETRIEVAL ==================')
with open(res_file) as f:
r1, r5, r10 = 0, 0, 0
cnt = 0
res_dict = {}
text_res_dict = {}
idx_all = 0.0
for line in f:
line = line.strip().split('\t')
if len(line) != 3:
break
score, image_id, sent_id = float(line[0]), line[1], line[2]
res_dict.setdefault(sent_id, [])
res_dict[sent_id].append((score, image_id))
text_res_dict.setdefault(image_id, [])
text_res_dict[image_id].append((score, sent_id))
if len(res_dict[sent_id]) == 1000:
res_list = res_dict[sent_id]
res_list = sorted(res_list, reverse = True)
ans = ans_dict[sent_id]
image_id_sort = list(zip(*res_list)[1])
ans_idx = image_id_sort.index(ans.strip())
if ans_idx < 1:
r1 += 1.0
if ans_idx < 5:
r5 += 1.0
if ans_idx < 10:
r10 += 1.0
idx_all += (ans_idx + 1)
cnt += 1
if cnt % 100 == 0:
print cnt, round(r1/cnt, 4), round(r5/cnt, 4), round(r10/cnt, 4), round(idx_all/cnt, 4)
print '-----------------------------'
print "instance %d r1:%.4f, r5:%.4f, r10:%.4f, avg_rank:%.4f" % (cnt, r1/cnt, r5/cnt, r10/cnt, idx_all/cnt)
print ('\n=============== TEXT RETRIEVAL ==================')
cnt = 0
r1, r5, r10 = 0, 0, 0
idx_all = 0.0
for image_id in text_res_dict:
res_list = text_res_dict[image_id]
res_list = sorted(res_list, reverse = True)
ans = text_ans_dict[image_id]
text_id_sort = list(zip(*res_list)[1])
ans_idx_all = []
for item in ans:
ans_idx_all.append(text_id_sort.index(item.strip()))
ans_idx = min(ans_idx_all)
if ans_idx < 1:
r1 += 1.0
if ans_idx < 5:
r5 += 1.0
if ans_idx < 10:
r10 += 1.0
idx_all += (ans_idx + 1)
cnt += 1
if cnt % 500 == 0:
print cnt, round(r1/cnt, 4), round(r5/cnt, 4), round(r10/cnt, 4), round(idx_all/cnt, 4)
print '-----------------------------'
print "instance %d r1:%.4f, r5:%.4f, r10:%.4f, avg_rank:%.4f" % (cnt, r1/cnt, r5/cnt, r10/cnt, idx_all/cnt)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""parameters init function implementations"""
from __future__ import print_function
import os
import six
import numpy as np
import paddle.fluid as fluid
def circle_loss(sp, sn, m, scale_circle):
"""
sp: score list of positive samples, shape [B * L]
sn: score list of negative samples, shape [B * K]
m: relaxation factor in circle loss function
scale: scale factor in circle loss function
return: circle loss value, shape [1]
"""
op = 1. + m
on = 0. - m
delta_p = 1 - m
delta_n = m
ap = fluid.layers.relu(op - sp)
ap.stop_gradient = True
an = fluid.layers.relu(sn - on)
an.stop_gradient = True
logit_p = ap * (sp - delta_p)
logit_p = -1. * scale_circle * logit_p
logit_p = fluid.layers.cast(x=logit_p, dtype=np.float64)
loss_p = fluid.layers.reduce_sum(fluid.layers.exp(logit_p), dim=1, keep_dim=False)
logit_n = an * (sn - delta_n)
logit_n = scale_circle * logit_n
logit_n = fluid.layers.cast(x=logit_n, dtype=np.float64)
loss_n = fluid.layers.reduce_sum(fluid.layers.exp(logit_n), dim=1, keep_dim=False)
circle_loss = fluid.layers.log(1 + loss_n * loss_p)
circle_loss = fluid.layers.cast(x=circle_loss, dtype=np.float32)
return fluid.layers.mean(circle_loss)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册