Unverified commit 8e8ce17f, authored by xfcygaocan, committed by GitHub

Repro (#661)

* add ernie-unimo

* add ernie-unimo
Parent 51239d91
# <p align=center>`UNIMO`</p>
Code for the ACL 2021 main-conference long paper [UNIMO: Towards Unified-Modal Understanding and Generation via Cross-Modal Contrastive Learning](https://arxiv.org/pdf/2012.15409.pdf)
## Abstract
Existing pre-training methods focus on either single-modal tasks or multi-modal tasks, and cannot effectively adapt to each other.
They can only utilize single-modal data (i.e., text or image) or limited multi-modal data (i.e., image-text pairs).
In this work, we propose a UNIfied-MOdal pre-training architecture, namely `UNIMO`, which can effectively adapt to both single-modal and multi-modal understanding and generation tasks.
Large-scale free text corpora and image collections are utilized to improve the capability of visual and textual understanding, and cross-modal contrastive learning (CMCL) is leveraged to align the textual and visual information into a unified semantic space over a corpus of image-text pairs augmented with related images and texts.
With the help of rich non-paired single-modal data, our model is able to learn more generalizable representations, by allowing textual knowledge and visual knowledge to enhance each other in the unified semantic space.
The experimental results show that `UNIMO` greatly improves the performance of several single-modal and multi-modal downstream tasks.
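
To make the contrastive objective more concrete, here is a minimal sketch of an InfoNCE-style loss over similarity scores. It is an illustration only, not the authors' implementation: the function name `contrastive_loss`, the temperature value, and the use of precomputed scalar similarity scores are all assumptions, and the real CMCL additionally augments each image-text pair with related images and texts.

```python
import math

def contrastive_loss(pos_scores, neg_scores, temperature=0.1):
    """InfoNCE-style loss: -log(sum exp(pos / t) / sum exp(all / t)).

    pos_scores and neg_scores are hypothetical similarity scores between
    an anchor and its positive / negative counterparts.
    """
    pos = [s / temperature for s in pos_scores]
    neg = [s / temperature for s in neg_scores]
    # Subtract the largest logit for numerical stability.
    m = max(pos + neg)
    pos_sum = sum(math.exp(s - m) for s in pos)
    all_sum = pos_sum + sum(math.exp(s - m) for s in neg)
    return -math.log(pos_sum / all_sum)

# A well-aligned pair (high positive score) yields a lower loss.
aligned = contrastive_loss([0.9], [0.1, 0.2])
misaligned = contrastive_loss([0.1], [0.9, 0.8])
```
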
![UNIMO](images/framework.png#pic_center)
## Performance
Results on multi-modal understanding and generation tasks:
![UNIMO](images/multiple.png#pic_center)
Results on single-modal understanding and generation tasks:
![UNIMO](images/single.png#pic_center)
---
## TODOs
- [ ] Add all downstream tasks
- [ ] Add unimo large model
## Dependencies
python 3.7.4\
paddlepaddle-gpu==1.8.4.post107\
pyrouge==0.1.3
## Pre-trained Models
`UNIMO` uses large-scale text corpora, image collections and image-text aligned datasets as pre-training data.
We currently provide a pre-trained `UNIMO` model at one scale:
[UNIMO base](https://unimo.bj.bcebos.com/model/unimo_base_en.tar.gz) (lowercased | 12 layers)
```
MODEL_SIZE=base
cd /path/to/model_files
wget --no-check-certificate -q https://unimo.bj.bcebos.com/model/unimo_${MODEL_SIZE}_en.tar.gz
tar -zxf unimo_${MODEL_SIZE}_en.tar.gz
```
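
For readers who prefer to script this step in Python, the following sketch mirrors the shell snippet with the standard library. The helper names (`model_archive_url`, `extract_archive`, `download_and_extract`) are hypothetical; only the URL pattern comes from the README. Unlike the `wget --no-check-certificate` call, `urllib` keeps TLS verification on by default. The download helper is defined but not called, and the demo at the end is a local round trip that needs no network.

```python
import os
import tarfile
import tempfile
import urllib.request

BASE_URL = "https://unimo.bj.bcebos.com/model"  # host taken from the README

def model_archive_url(model_size="base"):
    """Build the archive URL the same way the shell snippet does."""
    return "{}/unimo_{}_en.tar.gz".format(BASE_URL, model_size)

def extract_archive(archive_path, dest_dir):
    """Unpack a .tar.gz archive (like `tar -zxf`) and return its member names."""
    with tarfile.open(archive_path, "r:gz") as tar:
        names = tar.getnames()
        tar.extractall(dest_dir)
    return names

def download_and_extract(model_size, dest_dir):
    """Fetch and unpack the model archive; requires network access."""
    url = model_archive_url(model_size)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, archive_path)
    return extract_archive(archive_path, dest_dir)

# Local round trip (no network): pack one file, then unpack it again.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "a.txt")
    with open(src, "w") as f:
        f.write("demo")
    archive = os.path.join(tmp, "demo.tar.gz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="a.txt")
    members = extract_archive(archive, os.path.join(tmp, "out"))
```
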
## Experiments
Our fine-tuning experiments were carried out on V100 GPUs. Here are the results from the `UNIMO` model:
<table>
<tr>
<td><center><strong>Task Type</strong></center></td>
<td><center><strong>Dataset</strong></center></td>
<td><center><strong>Pre-trained Models</strong></center></td>
<td><center><strong>Start Command</strong></center></td>
<td><center><strong>V100 GPU Cards</strong></center></td>
<td><center><strong>Running Time</strong></center></td>
</tr>
<tr>
<td rowspan="1"><center>Text Understanding</center></td>
<td rowspan="1"><center>SST-2</center></td>
<td><center>UNIMO base</center></td>
<td><center>bash ./script/classification/SST-2/run.sh</center></td>
<td><center>8</center></td>
<td><center>9h</center></td>
</tr>
<tr>
<td rowspan="1"><center>Text Generation</center></td>
<td rowspan="1"><center>CoQA</center></td>
<td><center>UNIMO base</center></td>
<td><center>bash ./script/seq2seq/coqa/run.sh</center></td>
<td><center>4</center></td>
<td><center>7h</center></td>
</tr>
<tr>
<td rowspan="1"><center>Multi-Modal Understanding</center></td>
<td rowspan="1"><center>Flickr30k</center></td>
<td><center>UNIMO base</center></td>
<td><center>bash ./script/retrieval/Flickr30k/run.sh</center></td>
<td><center>16</center></td>
<td><center>3d</center></td>
</tr>
</table>
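
As a rough planning aid, the card counts and running times in the table can be turned into GPU-hours. This is a back-of-the-envelope sketch: the figures are copied from the table, "3d" is assumed to mean 72 hours, and the names `RUNS` and `gpu_hours` are illustrative.

```python
# (task, V100 cards, wall-clock hours) copied from the table above.
# "3d" is assumed to mean 3 * 24 = 72 hours.
RUNS = [
    ("SST-2", 8, 9),
    ("CoQA", 4, 7),
    ("Flickr30k", 16, 3 * 24),
]

def gpu_hours(cards, hours):
    """GPU-hours consumed by one fine-tuning run."""
    return cards * hours

budget = {task: gpu_hours(cards, hours) for task, cards, hours in RUNS}
total = sum(budget.values())
```
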
---
## Text Understanding Tasks
### (1) Sentiment Classification
#### Download SST-2 dataset:
```
cd /path/to/data
wget --no-check-certificate -q https://unimo.bj.bcebos.com/data/SST-2.tar.gz
tar -zxf SST-2.tar.gz
```
#### Run the following command to train and evaluate on the SST-2 dataset:
For base model:
```
bash ./script/classification/SST-2/run.sh
```
#### Evaluation Results:
<table>
<tr>
<td><center><strong>Model</strong></center></td>
<td><center><strong>Acc</strong></center></td>
</tr>
<tr>
<td><center>UNIMO-base</center></td>
<td><center>95.1</center></td>
</tr>
</table>
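
The SST-2 run is configured with the `simple_accuracy` metric. The repository's own implementation is not shown here, so the following is only a sketch of what plain classification accuracy computes; the function body is an assumption.

```python
def simple_accuracy(predictions, labels):
    """Fraction of examples whose predicted label equals the gold label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)

# Three of the four predictions match the gold labels.
acc = simple_accuracy([1, 0, 1, 1], [1, 0, 0, 1])
```

A reported score such as 95.1 corresponds to this fraction multiplied by 100.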
## Text Generation Tasks
### (1) Conversational Question Answering
#### Download CoQA dataset:
```
cd /path/to/data
wget --no-check-certificate -q https://unimo.bj.bcebos.com/data/coqa.tar.gz
tar -zxf coqa.tar.gz
```
#### Download evaluation script:
```
cd src/eval/tasks
wget --no-check-certificate -q https://unimo.bj.bcebos.com/eval_script/coqa.tar.gz
tar -zxf coqa.tar.gz
```
#### Run the following command to train and evaluate on the CoQA dataset:
For base model:
```
bash ./script/seq2seq/coqa/run.sh
```
#### Evaluation Results:
<table>
<tr>
<td><center><strong>Model</strong></center></td>
<td><center><strong>F1</strong></center></td>
</tr>
<tr>
<td><center>UNIMO-base</center></td>
<td><center>80.2</center></td>
</tr>
</table>
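
The CoQA configuration in this commit sets `eval_mertrics="f1"`. The downloaded evaluation script is the authority on how that score is computed; the sketch below only illustrates the general idea of token-overlap F1 between a predicted and a reference answer. The function name `token_f1` and the whitespace tokenization are assumptions, and any answer normalization the real script applies is omitted.

```python
from collections import Counter

def token_f1(prediction, reference):
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    # Multiset intersection counts each shared token at most as often
    # as it occurs in both answers.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

score = token_f1("the white cat", "a white cat")
```
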
## Multi-Modal Understanding Tasks
### (1) Image-Text Retrieval
#### Download Flickr30k dataset:
##### Note: Visual features are extracted by [bottom-up-attention](https://github.com/peteanderson80/bottom-up-attention)
```
cd /path/to/data
wget --no-check-certificate -q https://unimo.bj.bcebos.com/data/Flickr30k.tar.gz # occupies about 37G disk space
tar -zxf Flickr30k.tar.gz
```
#### Run the following command to train and evaluate on the Flickr30k dataset:
For base model:
```
bash ./script/retrieval/Flickr30k/run.sh
```
#### Evaluation Results:
Results of Image Retrieval task on Flickr30k dataset
<table>
<tr>
<td><center><strong>Model</strong></center></td>
<td><center><strong>R@1</strong></center></td>
<td><center><strong>R@5</strong></center></td>
<td><center><strong>R@10</strong></center></td>
</tr>
<tr>
<td><center>UNIMO-base</center></td>
<td><center>74.66</center></td>
<td><center>93.40</center></td>
<td><center>96.08</center></td>
</tr>
</table>
Results of Text Retrieval task on Flickr30k dataset
<table>
<tr>
<td><center><strong>Model</strong></center></td>
<td><center><strong>R@1</strong></center></td>
<td><center><strong>R@5</strong></center></td>
<td><center><strong>R@10</strong></center></td>
</tr>
<tr>
<td><center>UNIMO-base</center></td>
<td><center>89.70</center></td>
<td><center>98.40</center></td>
<td><center>99.10</center></td>
</tr>
</table>
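
R@1, R@5 and R@10 in the tables above are recall-at-k scores. As a hedged illustration (not the repository's evaluation code), recall@k can be sketched as the percentage of queries that have at least one correct item among the top k ranked candidates. The function name, the input layout, and the toy data are assumptions.

```python
def recall_at_k(ranked_lists, gold_sets, k):
    """Percentage of queries with at least one gold item in the top k."""
    hits = 0
    for ranked, gold in zip(ranked_lists, gold_sets):
        if any(item in gold for item in ranked[:k]):
            hits += 1
    return 100.0 * hits / len(ranked_lists)

# Two toy text queries, each with a ranked list of candidate images.
rankings = [["img3", "img1", "img2"], ["img2", "img3", "img1"]]
gold = [{"img1"}, {"img1"}]

r1 = recall_at_k(rankings, gold, 1)
r2 = recall_at_k(rankings, gold, 2)
r3 = recall_at_k(rankings, gold, 3)
```
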
---
Citation
---
If you find our paper and code useful, please cite the following paper:
```
@article{li2020unimo,
title={UNIMO: Towards Unified-Modal Understanding and Generation via Cross-Modal Contrastive Learning},
author={Li, Wei and Gao, Can and Niu, Guocheng and Xiao, Xinyan and Liu, Hao and Liu, Jiachen and Wu, Hua and Wang, Haifeng},
journal={arXiv preprint arXiv:2012.15409},
year={2020}
}
```
Contact information
---
For help or issues using `UNIMO`, please submit a GitHub issue.
For personal communication related to `UNIMO`, please contact Wei Li (liwei85@baidu.com), Guocheng Niu (niuguocheng@baidu.com), or Can Gao (gaocan01@baidu.com).
data_name=SST-2
data_tar=${data_name}.tar.gz
bos_url=https://unimo.bj.bcebos.com/data/SST-2.tar.gz
rm -rf $data_name
wget --no-check-certificate -q $bos_url
if [[ $? -ne 0 ]]; then
echo "url link: $bos_url"
echo "download data failed"
exit 1
fi
tar zxf $data_tar
rm -f $data_tar
exit 0
#!/usr/bin/env bash
set -x
# add CUDA, cuDNN and NCCL to environment variable
# export LD_LIBRARY_PATH=/home/work/cuda-10.0/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
# export LD_LIBRARY_PATH=/home/work/cuda-10.0/extras/CUPTI/lib64:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/home/work/cudnn/cudnn_v7.6/cuda/lib64:$LD_LIBRARY_PATH
# export LD_LIBRARY_PATH=/home/work/nccl/nccl2.4.2_cuda10.1/lib:$LD_LIBRARY_PATH
export FLAGS_sync_nccl_allreduce=1
export FLAGS_fraction_of_gpu_memory_to_use=1
export FLAGS_eager_delete_tensor_gb=1.0
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_memory_fraction_of_eager_deletion=1
export iplist=`hostname -i`
unset http_proxy
unset https_proxy
set +x
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"max_position_embeddings": 514,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 0,
"sent_type_vocab_size": 0,
"task_type_vocab_size": 0,
"vocab_size": 50265,
"max_img_len": 37,
"max_obj_len": 50,
"image_class_size": 1601,
"image_attr_size": 401,
"image_embedding_size": 2048,
"image_predict_feature": true,
"image_predict_class": true,
"image_use_attr": false,
"image_use_soft_label": true,
"use_neg_lm_loss": false,
"fusion_method": "mul",
"similarity_method": "softmax",
"txt_mask_ratio": 0.15,
"vl_mask_ratio": 0.15,
"scenegraph_mask_ratio": 0.3,
"overlap_ratio": 0.4,
"num_labels": 2,
"max_pixel_len": 256,
"max_pixel_position_embeddings": 196
}
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"max_position_embeddings": 514,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 0,
"sent_type_vocab_size": 0,
"task_type_vocab_size": 0,
"vocab_size": 50265,
"max_img_len": 101,
"max_obj_len": 100,
"image_class_size": 1601,
"image_attr_size": 401,
"image_embedding_size": 2048,
"image_predict_feature": true,
"image_predict_class": true,
"image_use_attr": false,
"image_use_soft_label": true,
"use_neg_lm_loss": false,
"fusion_method": "mul",
"txt_mask_ratio": 0.15,
"vl_mask_ratio": 0.15,
"scenegraph_mask_ratio": 0.4,
"overlap_ratio": 0.3,
"num_labels": 2,
"max_pixel_len": 256,
"max_pixel_position_embeddings": 196
}
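
The two configuration files above differ mainly in width and depth. In standard multi-head attention the hidden size is split evenly across heads, so both settings work out to the same per-head dimension. The sketch below checks this using only the three relevant keys copied from the configs; the helper name `head_dim` is hypothetical.

```python
import json

# Only the keys needed here, copied from the two configs above.
BASE = json.loads(
    '{"hidden_size": 768, "num_attention_heads": 12, "num_hidden_layers": 12}')
LARGE = json.loads(
    '{"hidden_size": 1024, "num_attention_heads": 16, "num_hidden_layers": 24}')

def head_dim(config):
    """Per-head dimension when hidden_size is split evenly across heads."""
    hidden, heads = config["hidden_size"], config["num_attention_heads"]
    if hidden % heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden // heads

base_dim = head_dim(BASE)
large_dim = head_dim(LARGE)
```
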
data_name=unimo_base_en
data_tar=${data_name}.tar.gz
bos_url=https://unimo.bj.bcebos.com/model/$data_tar
rm -rf $data_name
wget --no-check-certificate -q $bos_url
if [[ $? -ne 0 ]]; then
echo "url link: $bos_url"
echo "download data failed"
exit 1
fi
tar zxf $data_tar
rm -f $data_tar
exit 0
paddlepaddle-gpu==1.8.4.post107
pyrouge==0.1.3
regex==2020.7.14
output_name="classification"
task=SST-2
## hyper param
use_fp16="False"
do_train="True"
do_val="True"
do_test="False"
do_pred="True"
num_labels=2
weight_decay=0
max_len=512
warmup_ratio=0.06
save_checkpoints="False"
save_steps=2000
validation_steps=2000
skip_steps=10
eval_mertrics=simple_accuracy
EPOCH=("10")
BATCH_SIZE=("16" "32")
LR_RATE=("1e-5" "2e-5" "3e-5")
DD_RAND_SEED=("1" "2" "3" "4" "5")
init_model="./model_files/unimo_base_en"
config_path="./model_files/config/unimo_base_en.json"
vocab_file="./model_files/dict/unimo_en.vocab.txt"
bpe_json="./model_files/dict/unimo_en.encoder.json"
bpe_file="./model_files/dict/unimo_en.vocab.bpe"
#!/usr/bin/env bash
set -eux
R_DIR=`dirname $0`; MYDIR=`cd $R_DIR;pwd`
cd ${MYDIR}/../../../
# config env
source ${MYDIR}/model_conf
source ./env.sh
source ./utils.sh
check_iplist
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
output_dir=./output/${task}
log_dir=${output_dir}/log
save_model_base_dir=$output_dir/save_model
mkdir -p $output_dir $log_dir $save_model_base_dir
if [[ ${do_pred} == "True" ]]; then
pred_save_prefix="${output_dir}/predict"
mkdir -p $pred_save_prefix
fi
for seed in "${DD_RAND_SEED[@]}"; do
echo "seed "$seed
for epoch in "${EPOCH[@]}"; do
echo "epoch "$epoch
for lr in "${LR_RATE[@]}"; do
echo "learning rate "$lr
for bs in "${BATCH_SIZE[@]}"; do
echo "batch_size "$bs
log_prefix=$seed"_"$epoch"_"$lr"_"$bs"."
if [[ ${do_pred} == "True" ]]; then
pred_save="${pred_save_prefix}/test.${seed}.${epoch}.${lr}.${bs}"
fi
if [[ ${save_checkpoints} == "True" ]]; then
save_model_dir="${save_model_base_dir}/params.${seed}.${epoch}.${lr}.${bs}"
mkdir -p $save_model_dir
fi
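# Use a shorter validation interval for the larger batch size, and reset it
# otherwise so the value does not leak into later bs=16 runs
# (2000 matches validation_steps in model_conf).
if [[ ${bs} == "32" ]]; then
validation_steps=1000
else
validation_steps=2000
fi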
python -u ./src/run_classifier.py --use_cuda "True" \
--is_distributed ${is_distributed:-"False"} \
--weight_sharing ${weight_sharing:-"True"} \
--use_fast_executor ${e_executor:-"true"} \
--use_fp16 ${use_fp16:-"false"} \
--nccl_comm_num ${nccl_comm_num:-1} \
--use_hierarchical_allreduce ${use_hierarchical_allreduce:-"False"} \
--in_tokens ${in_tokens:-"false"} \
--use_dynamic_loss_scaling ${use_fp16} \
--init_loss_scaling ${loss_scaling:-12800} \
--beta1 ${beta1:-0.9} \
--beta2 ${beta2:-0.98} \
--epsilon ${epsilon:-1e-06} \
--verbose true \
--do_train ${do_train:-"True"} \
--do_val ${do_val:-"True"} \
--do_test ${do_test:-"True"} \
--do_pred ${do_pred:-"True"} \
--pred_save ${pred_save:-"./output/predict/test"} \
--batch_size ${bs:-16} \
--init_pretraining_params ${init_model:-""} \
--train_set ./data/SST-2/train.tsv \
--dev_set ./data/SST-2/dev.tsv \
--test_set ./data/SST-2/test.tsv \
--checkpoints ${save_model_dir:-""} \
--save_checkpoints ${save_checkpoints:-"True"} \
--save_steps ${save_steps:-1000} \
--weight_decay ${weight_decay:-"0.1"} \
--warmup_proportion ${warmup_ratio:-"0.06"} \
--validation_steps ${validation_steps:-"100"} \
--epoch $epoch \
--max_seq_len ${max_len:-512} \
--learning_rate ${lr:-"5e-5"} \
--lr_scheduler ${lr_scheduler:-"linear_warmup_decay"} \
--skip_steps ${skip_steps:-"10"} \
--num_iteration_per_drop_scope 10 \
--num_labels ${num_labels:-2} \
--unimo_vocab_file ${vocab_file} \
--encoder_json_file ${bpe_json} \
--vocab_bpe_file ${bpe_file} \
--unimo_config_path ${config_path} \
--eval_mertrics ${eval_mertrics:-"simple_accuracy"} \
--random_seed ${seed:-1} >> $log_dir/${log_prefix}lanch.log 2>&1
done
done
done
done
if [[ $? -ne 0 ]]; then
echo "run failed"
exit 1
fi
python ./src/utils/stat_res.py --log_dir=$log_dir
exit 0
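
The four nested loops in the script above sweep a small hyperparameter grid. The sketch below enumerates the same grid in Python to show how many runs it launches and what the log-file prefixes look like. The value lists and the prefix format are copied from the SST-2 `model_conf` and `run.sh`; the function name `grid_runs` is illustrative.

```python
from itertools import product

# Values copied from the SST-2 model_conf.
EPOCH = ["10"]
BATCH_SIZE = ["16", "32"]
LR_RATE = ["1e-5", "2e-5", "3e-5"]
DD_RAND_SEED = ["1", "2", "3", "4", "5"]

def grid_runs():
    """Yield one log prefix per run, in the shell script's loop order."""
    # Nesting order in run.sh: seed, epoch, learning rate, batch size.
    for seed, epoch, lr, bs in product(DD_RAND_SEED, EPOCH, LR_RATE, BATCH_SIZE):
        yield "{}_{}_{}_{}.".format(seed, epoch, lr, bs)

prefixes = list(grid_runs())
```

With these settings the sweep covers 5 seeds x 1 epoch setting x 3 learning rates x 2 batch sizes, that is 30 sequential fine-tuning runs.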
output_name="retrieval"
task=Flickr30k
## hyper param
epoch=40
do_train="True"
do_val="True"
do_test="True"
save_checkpoints="False"
save_steps=10000
validation_steps=10000
samples_num=20
bbox="bbox100"
max_img_len=101
seed=1
batch_size=4
test_batch_size=128
lr=5e-6
learning_rate_scale=0.1
learning_rate_decay_epoch1=24
learning_rate_decay_epoch2=32
init_model="./model_files/unimo_base_en"
config_path="./model_files/config/unimo_base_en.json"
vocab_file="./model_files/dict/unimo_en.vocab.txt"
bpe_json="./model_files/dict/unimo_en.encoder.json"
bpe_file="./model_files/dict/unimo_en.vocab.bpe"
#!/usr/bin/env bash
set -eux
R_DIR=`dirname $0`; MYDIR=`cd $R_DIR;pwd`
cd ${MYDIR}/../../../
# config env
source ${MYDIR}/model_conf
source ./env.sh
source ./utils.sh
check_iplist
export FLAGS_fuse_parameter_memory_size=64
set -eu
output_dir=./output/${task}
log_dir=${output_dir}/log
save_model_base_dir=$output_dir/save_model
mkdir -p $output_dir $log_dir $save_model_base_dir
log_prefix=$seed"_"$epoch"_"$lr"_"$batch_size"."
eval_dir="${output_dir}/tmp/params.${seed}.${epoch}.${lr}.${batch_size}"
mkdir -p $eval_dir
if [[ ${save_checkpoints} == "True" ]]; then
save_model_dir="${save_model_base_dir}/params.${seed}.${epoch}.${lr}.${batch_size}"
mkdir -p $save_model_dir
fi
distributed_args="--node_ips ${PADDLE_TRAINERS} \
--node_id ${PADDLE_TRAINER_ID} \
--current_node_ip ${POD_IP} \
--selected_gpus 0,1,2,3,4,5,6,7 \
--split_log_path $log_dir \
--log_prefix $log_prefix \
--nproc_per_node 8"
lanch_start=" -u ./src/launch.py ${distributed_args} "
python $lanch_start ./src/run_retrieval.py \
--use_cuda "True" \
--is_distributed ${is_distributed:-"True"} \
--weight_sharing ${weight_sharing:-"True"} \
--use_fuse ${use_fuse:-"True"} \
--use_fast_executor ${e_executor:-"true"} \
--use_fp16 ${use_fp16:-"false"} \
--nccl_comm_num ${nccl_comm_num:-2} \
--use_hierarchical_allreduce ${use_hierarchical_allreduce:-"False"} \
--use_dynamic_loss_scaling ${use_fp16:-"False"} \
--use_sigmoid ${use_sigmoid:-"False"} \
--init_loss_scaling ${loss_scaling:-12800} \
--beta1 ${beta1:-0.9} \
--beta2 ${beta2:-0.98} \
--epsilon ${epsilon:-1e-06} \
--scale_circle ${scale_circle:-1.0} \
--margin ${margin:-0.2} \
--verbose true \
--samples_num ${samples_num:-20} \
--run_random ${run_random:-"False"} \
--do_train ${do_train:-"True"} \
--do_val ${do_val:-"True"} \
--do_test ${do_test:-"True"} \
--batch_size ${batch_size:-16} \
--test_batch_size ${test_batch_size:-96} \
--init_pretraining_params ${init_model:-""} \
--train_image_caption ./data/Flickr30k/flickr30k-textids/train.ids \
--train_image_feature_dir ./data/Flickr30k/flickr30k-features/$bbox/train \
--dev_image_caption ./data/Flickr30k/flickr30k-textids/val.all.ids \
--dev_image_feature_dir ./data/Flickr30k/flickr30k-features/$bbox/dev \
--test_image_caption ./data/Flickr30k/flickr30k-textids/test.all.ids \
--test_image_feature_dir ./data/Flickr30k/flickr30k-features/$bbox/test \
--img_id_path ./data/Flickr30k/flickr30k-textids/dataset_flickr30k_name_id.txt \
--checkpoints ${save_model_dir:-""} \
--save_checkpoints ${save_checkpoints:-"True"} \
--save_steps ${save_steps:-1000} \
--weight_decay ${weight_decay:-"0.1"} \
--warmup_step ${warmup_step:-"1"} \
--validation_steps ${validation_steps:-"100"} \
--epoch $epoch \
--max_seq_len ${max_len:-512} \
--max_img_len ${max_img_len:-37} \
--learning_rate ${lr:-"5e-6"} \
--learning_rate_scale ${learning_rate_scale:-0.1} \
--learning_rate_decay_epoch1 ${learning_rate_decay_epoch1:-24} \
--learning_rate_decay_epoch2 ${learning_rate_decay_epoch2:-32} \
--lr_scheduler ${lr_scheduler:-"scale_by_epoch_decay"} \
--skip_steps ${skip_steps:-"50"} \
--num_iteration_per_drop_scope 10 \
--unimo_vocab_file ${vocab_file} \
--encoder_json_file ${bpe_json} \
--vocab_bpe_file ${bpe_file} \
--unimo_config_path ${config_path} \
--eval_mertrics ${eval_mertrics:-"recall@k"} \
--eval_dir $eval_dir \
--random_seed ${seed:-1} \
>> $log_dir/${log_prefix}lanch.log 2>&1
if [[ $? -ne 0 ]]; then
echo "run failed"
exit 1
fi
exit 0
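
The retrieval run selects the `scale_by_epoch_decay` scheduler with `lr=5e-6`, `learning_rate_scale=0.1`, and decay epochs 24 and 32. The scheduler's implementation is not part of the visible diff, so the sketch below is an assumption about its behaviour: the learning rate is multiplied by the scale once at each decay epoch. The function name and the inclusive epoch boundaries are also assumptions.

```python
def step_decay_lr(epoch, base_lr=5e-6, scale=0.1,
                  decay_epoch1=24, decay_epoch2=32):
    """Piecewise-constant learning rate (assumed semantics, see lead-in)."""
    if epoch >= decay_epoch2:
        return base_lr * scale * scale
    if epoch >= decay_epoch1:
        return base_lr * scale
    return base_lr

# Learning rate for each of the 40 configured epochs.
schedule = [step_decay_lr(e) for e in range(40)]
```
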
output_name="seq2seq"
init_model="./model_files/unimo_base_en"
data_path='./data/coqa'
# hyper param
lr_scheduler="linear_warmup_decay"
use_fp16="False"
# Fuse allreduce operations to reduce the number of communication calls
use_fuse="True"
use_hierarchical_allreduce="True"
loss_scaling=12800
skip_steps=100
save_steps=10000
validation_steps=10000
label_smooth=0.1
weight_decay=0.01
max_seq_len=512
random_seed=666
#for multi-turn dialog/qa
task_type="dialog"
role_type_size=3
turn_type_size=16
#decoding params
do_decode="true"
max_src_len=480
max_tgt_len=32
max_out_len=30
min_out_len=0
beam_size=3
length_penalty=0.0
block_trigram="False"
use_multi_gpu_test="True"
#adam optimizer
beta1=0.9
beta2=0.98
epsilon=1e-06
#data
tokenized_input="True"
continuous_position="False"
#dataset
train_set="train.tsv"
dev_set="dev.tsv"
test_set="dev.tsv"
do_train="true"
do_val="true"
do_test="false"
do_pred="false"
#evaluate
eval_script="bash ./src/eval/tasks/coqa/eval.sh"
eval_mertrics="f1"
## tuning params
in_tokens="False"
pred_batch_size=4
epoch=20
BATCH_SIZE=("8")
LR_RATE=("1e-5")
DD_RAND_SEED=("1")
WARMUP_PROP=("0.06")
config_path="./model_files/config/unimo_base_en.json"
vocab_file="./model_files/dict/unimo_en.vocab.txt"
bpe_json="./model_files/dict/unimo_en.encoder.json"
bpe_file="./model_files/dict/unimo_en.vocab.bpe"
#!/usr/bin/env bash
set -eux
R_DIR=`dirname $0`; MYDIR=`cd $R_DIR;pwd`
cd ${MYDIR}/../../../
# config env
source ${MYDIR}/model_conf
source ./env.sh
source ./utils.sh
# check
check_iplist
set -eu
output_dir=../output-coqa
log_dir=../log-coqa
mkdir -p $output_dir $log_dir
e_executor=$(echo ${use_experimental_executor-'True'} | tr '[A-Z]' '[a-z]')
use_fuse=$(echo ${use_fuse-'False'} | tr '[A-Z]' '[a-z]')
if [[ ${use_fuse} == "true" ]]; then
#MB
export FLAGS_fuse_parameter_memory_size=64
fi
export DEV_PREFIX=`echo ${dev_set:-"dev.tsv"} | sed 's/\.tsv$//'`
export TEST_PREFIX=`echo ${test_set:-"test.tsv"} | sed 's/\.tsv$//'`
export PRED_PREFIX=`echo ${pred_set:-"pred.tsv"} | sed 's/\.tsv$//'`
export EVAL_SCRIPT_LOG=${MYDIR}/../../../${output_dir}/eval.log
export TASK_DATA_PATH=${data_path}
distributed_args="--node_ips ${PADDLE_TRAINERS} \
--node_id ${PADDLE_TRAINER_ID} \
--current_node_ip ${POD_IP} \
--selected_gpus 4,5,6,7 \
--split_log_path $log_dir \
--nproc_per_node 4"
for random_seed in "${DD_RAND_SEED[@]}"; do
echo "random_seed "${random_seed}
for batch_size in "${BATCH_SIZE[@]}"; do
echo "batch_size "${batch_size}
for warmup_proportion in "${WARMUP_PROP[@]}"; do
echo "warmup_proportion "${warmup_proportion}
for learning_rate in "${LR_RATE[@]}"; do
echo "learning rate "${learning_rate}
python -u ./src/launch.py ${distributed_args} \
./src/run_seq2seq.py --use_cuda "True" \
--is_distributed "True" \
--use_multi_gpu_test ${use_multi_gpu_test:-"True"} \
--use_fp16 ${use_fp16:-"False"} \
--use_dynamic_loss_scaling ${use_fp16} \
--init_loss_scaling ${loss_scaling:-128} \
--use_fast_executor ${e_executor:-"True"} \
--use_fuse ${use_fuse:-"False"} \
--nccl_comm_num ${nccl_comm_num:-1} \
--use_hierarchical_allreduce ${use_hierarchical_allreduce:-"False"} \
--do_train ${do_train:-"true"} \
--do_val ${do_val:-"false"} \
--do_test ${do_test:-"true"} \
--do_pred ${do_pred:-"false"} \
--do_decode ${do_decode:-"True"} \
--train_set ${data_path}/${train_set:-""} \
--dev_set ${data_path}/${dev_set:-""} \
--test_set ${data_path}/${test_set:-""} \
--pred_set ${data_path}/${pred_set:-""} \
--epoch ${epoch} \
--tokenized_input ${tokenized_input:-"True"} \
--task_type ${task_type:-"dialog"} \
--role_type_size ${role_type_size:-3} \
--turn_type_size ${turn_type_size:-16} \
--max_seq_len ${max_seq_len} \
--max_src_len ${max_src_len} \
--max_tgt_len ${max_tgt_len} \
--max_out_len ${max_out_len} \
--min_out_len ${min_out_len} \
--block_trigram ${block_trigram:-"True"} \
--beam_size ${beam_size:-5} \
--length_penalty ${length_penalty:-0.6} \
--hidden_dropout_prob ${hidden_dropout_prob:-0.1} \
--attention_probs_dropout_prob ${attention_probs_dropout_prob:-0.1} \
--beta1 ${beta1:-0.9} \
--beta2 ${beta2:-0.98} \
--epsilon ${epsilon:-1e-06} \
--continuous_position ${continuous_position:-"false"} \
--tgt_type_id ${tgt_type_id:-1} \
--batch_size ${batch_size} \
--pred_batch_size ${pred_batch_size} \
--in_tokens ${in_tokens:-"True"} \
--learning_rate ${learning_rate} \
--lr_scheduler ${lr_scheduler:-"linear_warmup_decay"} \
--warmup_proportion ${warmup_proportion:-0.02} \
--weight_decay ${weight_decay:-0.01} \
--weight_sharing ${weight_sharing:-"True"} \
--label_smooth ${label_smooth:-0.1} \
--init_pretraining_params ${init_model:-""} \
--unimo_vocab_file ${vocab_file} \
--encoder_json_file ${bpe_json} \
--vocab_bpe_file ${bpe_file} \
--unimo_config_path ${config_path} \
--checkpoints $output_dir \
--save_steps ${save_steps:-10000} \
--validation_steps ${validation_steps:-10000} \
--skip_steps ${skip_steps:-10} \
--save_and_valid_by_epoch ${save_and_valid_by_epoch:-"False"} \
--eval_script ${eval_script:-""} \
--eval_mertrics ${eval_mertrics:-"bleu_1"} \
--random_seed ${random_seed:-"666"} >> $log_dir/lanch.log 2>&1
done
done
done
done
python ./src/utils/extract_eval_res.py --log_dir=$log_dir
exit 0
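
Both the classification and seq2seq scripts use the `linear_warmup_decay` scheduler with a warmup proportion of 0.06. Its implementation is not shown in the visible diff, so the following is one common formulation, offered as an assumption rather than the repository's exact code: the rate rises linearly from zero during warmup and then falls linearly back to zero. The function name, the rounding of the warmup step count, and the example step counts are illustrative.

```python
def linear_warmup_decay(step, total_steps, base_lr=1e-5, warmup_proportion=0.06):
    """Linear ramp-up to base_lr, then linear decay to zero (assumed form)."""
    warmup_steps = int(round(total_steps * warmup_proportion))
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / max(total_steps - warmup_steps, 1)

# With 1000 total steps, the first 60 are warmup.
halfway_warmup = linear_warmup_decay(30, 1000)
peak = linear_warmup_decay(60, 1000)
final = linear_warmup_decay(1000, 1000)
```
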
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""args for classification task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from utils.args import ArgumentGroup
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Pre-trained params to start fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument is ignored.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
model_g.add_arg("save_checkpoints", bool, True, "Whether to save checkpoints")
model_g.add_arg("weight_sharing", bool, True, "If set, share weights between word embedding and masked lm.")
model_g.add_arg("unimo_vocab_file", str, './model_files/dict/unimo_en.vocab.txt', "unimo vocab")
model_g.add_arg("encoder_json_file", str, './model_files/dict/unimo_en.encoder.json', 'bpe map')
model_g.add_arg("vocab_bpe_file", str, './model_files/dict/unimo_en.vocab.bpe', "vocab bpe")
model_g.add_arg("unimo_config_path", str, "./model_files/config/unimo_base_en.json",
                "Path to the unimo configuration file.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("nccl_comm_num", int, 1, "NCCL comm num.")
train_g.add_arg("hierarchical_allreduce_inter_nranks", int, 8, "Hierarchical allreduce inter ranks.")
train_g.add_arg("use_hierarchical_allreduce", bool, False, "Use hierarchical allreduce or not.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("use_dynamic_loss_scaling", bool, False, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive steps with finite gradients.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
"The less-than-one-multiplier to use when decreasing.")
train_g.add_arg("beta1", float, 0.9, "beta1 for adam")
train_g.add_arg("beta2", float, 0.98, "beta2 for adam.")
train_g.add_arg("epsilon", float, 1e-06, "epsilon for adam.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("train_set", str, None, "Path to training data.")
data_g.add_arg("test_set", str, None, "Path to test data.")
data_g.add_arg("test_hard_set", str, None, "Path to test_hard data.")
data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("dev_hard_set", str, None, "Path to validation_hard data.")
data_g.add_arg("diagnostic_set", str, None, "Path to diagnostic data.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
data_g.add_arg("num_labels", int, 2, "label number")
data_g.add_arg("max_query_length", int, 64, "Max query length.")
data_g.add_arg("max_answer_length", int, 100, "Max answer length.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, False, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, False, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_val_hard", bool, False, "Whether to perform evaluation on dev hard data set.")
run_type_g.add_arg("do_test", bool, False, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("do_test_hard", bool, False, "Whether to perform evaluation on test hard data set.")
run_type_g.add_arg("do_pred", bool, False, "Whether to predict on test data set.")
run_type_g.add_arg("do_pred_hard", bool, False, "Whether to predict on test hard data set.")
run_type_g.add_arg("do_diagnostic", bool, False, "Whether to predict on diagnostic data set.")
run_type_g.add_arg("pred_save", str, "./output/predict/test", "Path prefix for saving predictions on the test data set.")
run_type_g.add_arg("use_multi_gpu_test", bool, False, "Whether to perform evaluation using multiple gpu cards")
run_type_g.add_arg("eval_mertrics", str, "simple_accuracy", "eval_mertrics")
# yapf: enable
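
The run scripts pass booleans as strings such as `"True"` and `"false"`. A plain `type=bool` argparse argument cannot parse these, because `bool("False")` is `True` in Python. `ArgumentGroup` from `utils.args` is not shown in this diff, so the sketch below only illustrates how such a helper typically handles the problem; the converter name `str2bool` and the spellings it accepts are assumptions.

```python
import argparse

def str2bool(value):
    """Map 'True'/'true'/'t'/'1' to True and everything else to False."""
    return str(value).lower() in ("true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--do_train", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)

# Mixed capitalisation, as used in the shell scripts.
args = parser.parse_args(["--do_train", "True", "--use_fp16", "false"])
```
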
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""args for image-to-text generation"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import argparse
from utils.args import ArgumentGroup
class CustomAction(argparse.Action):
    """custom action"""
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, " ".join(values))
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Pre-trained params to start fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument is ignored.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
model_g.add_arg("weight_sharing", bool, True, "If set, share weights between word embedding and masked lm.")
model_g.add_arg("unimo_vocab_file", str, './model_files/dict/unimo_en.vocab.txt', "unimo vocab")
model_g.add_arg("encoder_json_file", str, './model_files/dict/unimo_en.encoder.json', 'bpe map')
model_g.add_arg("vocab_bpe_file", str, './model_files/dict/unimo_en.vocab.bpe', "vocab bpe")
model_g.add_arg("unimo_config_path", str, "./model_files/config/unimo_base_en.json",
                "Path to the unimo configuration file.")
model_g.add_arg("object_file", str, "./data/coco_object_0.35_tot.ids", "The object file for image bounding boxes.")
model_g.add_arg("adv_type", str, "villa", "The adversarial learning type: freelb_image, freelb_text, villa")
model_g.add_arg("adv_step", int, 4, "adv_step")
model_g.add_arg("adv_lr", float, 0.05, "adv_lr")
model_g.add_arg("norm_type", str, 'l2', "norm_type")
model_g.add_arg("adv_max_norm", float, 0.4, "adv_max_norm")
model_g.add_arg("adv_init_mag", float, 0.4, "adv_init_mag")
model_g.add_arg("adv_kl_weight", float, 1.5, "adv_kl_weight")
model_g.add_arg("with_pure_model", bool, True, "whether include the pure model during adv learning")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 50, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 4e-5, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.02,
"Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("save_steps", int, 100000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 100000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fuse", bool, False, "Whether to use fuse_allreduce_ops.")
train_g.add_arg("nccl_comm_num", int, 1, "NCCL comm num.")
train_g.add_arg("hierarchical_allreduce_inter_nranks", int, 8, "Hierarchical allreduce inter ranks.")
train_g.add_arg("use_hierarchical_allreduce", bool, False, "Use hierarchical allreduce or not.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("use_dynamic_loss_scaling", bool, False, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 128.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive steps with finite gradients.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
                "The less-than-one-multiplier to use when decreasing the loss scaling.")
train_g.add_arg("beta1", float, 0.9, "beta1 for adam")
train_g.add_arg("beta2", float, 0.98, "beta2 for adam.")
train_g.add_arg("epsilon", float, 1e-06, "epsilon for adam.")
train_g.add_arg("tgt_type_id", int, 1, "for seq2seq task.")
train_g.add_arg("do_decode", bool, False, "for seq2seq task.")
train_g.add_arg("label_smooth", float, 0.1, "label smooth")
train_g.add_arg("hidden_dropout_prob", float, 0.1, "hidden_dropout_prob")
train_g.add_arg("attention_probs_dropout_prob", float, 0.1, "attention_probs_dropout_prob")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 100, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, True, "Whether to output verbose log.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("task_type", str, "normal", "is task type")
data_g.add_arg("train_filelist", str, None, "Path to training data.")
data_g.add_arg("test_filelist", str, None, "Path to test data.")
data_g.add_arg("valid_filelist", str, None, "Path to validation data.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("max_tgt_len", int, 512, "for seq2seq task.")
data_g.add_arg("max_out_len", int, 512, "for seq2seq task.")
data_g.add_arg("min_out_len", int, 20, "for seq2seq task.")
data_g.add_arg("block_trigram", bool, True, "utilize trigram blocking during beam search")
data_g.add_arg("beam_size", int, 5, "for seq2seq task.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.")
data_g.add_arg("pred_batch_size", int, 0, "Total examples' number in batch for prediction.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("length_penalty", float, 0.6, "length_penalty")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("visualdl_log", bool, False, "If set, use visualdl_log on paddlecloud.")
run_type_g.add_arg("is_distributed", bool, True, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, True, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 1, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_test", bool, True, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("do_pred", bool, True, "Whether to perform evaluation on pred data set.")
run_type_g.add_arg("use_multi_gpu_test", bool, True, "Whether to perform evaluation using multiple gpu cards")
run_type_g.add_arg("save_and_valid_by_epoch", bool, False, "save_and_valid_by_epoch")
run_type_g.add_arg("eval_script", action=CustomAction, type=str, nargs='+', help="eval_script", default=None)
run_type_g.add_arg("eval_mertrics", str, "", "eval_mertrics")
run_type_g.add_arg("random_seed", int, 0, "Random seed.")
image_g = ArgumentGroup(parser, "image", "image configuration options")
image_g.add_arg("image_embedding_size", int, 2048, "Image feature size==2048.")
image_g.add_arg("max_img_len", int, 37, "Max length of the image feature sequence.")
image_g.add_arg("max_obj_len", int, 50, "max num of object size.")
# yapf: enable
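The loss-scaling options above (`init_loss_scaling`, `incr_every_n_steps`, `decr_every_n_nan_or_inf`, `incr_ratio`, `decr_ratio`) describe a dynamic loss-scaling schedule for fp16 training. The following is a minimal sketch of how such a schedule is commonly driven, based only on the help strings. The class `DynamicLossScaler` and its `update` method are hypothetical names for illustration, not the PaddlePaddle implementation.

```python
class DynamicLossScaler(object):
    """Hypothetical illustration of a dynamic loss-scaling schedule."""

    def __init__(self, init_loss_scaling=128.0, incr_every_n_steps=100,
                 decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.8):
        self.scale = init_loss_scaling
        self.incr_every_n_steps = incr_every_n_steps
        self.decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self.incr_ratio = incr_ratio
        self.decr_ratio = decr_ratio
        self.good_steps = 0  # consecutive steps with finite gradients
        self.bad_steps = 0   # accumulated steps with nan or inf gradients

    def update(self, found_inf):
        """Adjust the scale after one training step and return it."""
        if found_inf:
            self.good_steps = 0
            self.bad_steps += 1
            if self.bad_steps == self.decr_every_n_nan_or_inf:
                self.scale *= self.decr_ratio
                self.bad_steps = 0
        else:
            # Assumption: a finite step resets the bad-step counter.
            self.bad_steps = 0
            self.good_steps += 1
            if self.good_steps == self.incr_every_n_steps:
                self.scale *= self.incr_ratio
                self.good_steps = 0
        return self.scale


scaler = DynamicLossScaler(incr_every_n_steps=3)
for found_inf in [False, False, False]:
    scaler.update(found_inf)
scale_after_growth = scaler.scale    # grew once by incr_ratio

for found_inf in [True, True]:
    scaler.update(found_inf)
scale_after_backoff = scaler.scale   # shrank once by decr_ratio
```

With the defaults from this file, the scale starts at 128 and only changes when `use_fp16` and `use_dynamic_loss_scaling` are both enabled.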
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""args for regression task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from utils.args import ArgumentGroup
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Init pre-training params to perform fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument is ignored.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
model_g.add_arg("save_checkpoints", bool, True, "Whether to save checkpoints")
model_g.add_arg("weight_sharing", bool, True, "If set, share weights between word embedding and masked lm.")
model_g.add_arg("unimo_vocab_file", str, './model_files/dict/unimo_en.vocab.txt', "unimo vocab")
model_g.add_arg("encoder_json_file", str, './model_files/dict/unimo_en.encoder.json', 'bpe map')
model_g.add_arg("vocab_bpe_file", str, './model_files/dict/unimo_en.vocab.bpe', "vocab bpe")
model_g.add_arg("unimo_config_path", str, "./model_files/config/unimo_base_en.json",
                "Path to the unimo configuration file.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.1,
"Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("nccl_comm_num", int, 1, "NCCL comm num.")
train_g.add_arg("hierarchical_allreduce_inter_nranks", int, 8, "Hierarchical allreduce inter ranks.")
train_g.add_arg("use_hierarchical_allreduce", bool, False, "Use hierarchical allreduce or not.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("use_dynamic_loss_scaling", bool, False, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive steps with finite gradients.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
                "The less-than-one-multiplier to use when decreasing the loss scaling.")
train_g.add_arg("beta1", float, 0.9, "beta1 for adam")
train_g.add_arg("beta2", float, 0.98, "beta2 for adam.")
train_g.add_arg("epsilon", float, 1e-06, "epsilon for adam.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("train_set", str, None, "Path to training data.")
data_g.add_arg("test_set", str, None, "Path to test data.")
data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
data_g.add_arg("num_labels", int, 2, "label number")
data_g.add_arg("max_query_length", int, 64, "Max query length.")
data_g.add_arg("max_answer_length", int, 100, "Max answer length.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("visualdl_log", bool, False, "If set, use visualdl_log on paddlecloud.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, False, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, False, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_test", bool, False, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("do_pred", bool, False, "Whether to predict on test data set.")
run_type_g.add_arg("pred_save", str, "./output/predict/test", "Path prefix for saving prediction results.")
run_type_g.add_arg("use_multi_gpu_test", bool, False, "Whether to perform evaluation using multiple gpu cards")
run_type_g.add_arg("eval_mertrics", str, "simple_accuracy", "eval_mertrics")
# yapf: enable
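The `lr_scheduler` default above is `linear_warmup_decay`, controlled by `learning_rate` and `warmup_proportion`. As a rough sketch of what such a schedule usually looks like: the rate rises linearly during the warmup fraction of training, then falls linearly to zero. The function below is a hypothetical illustration; the exact curve produced by the repository's optimizer code is not shown here and may differ.

```python
def linear_warmup_decay(step, base_lr, num_train_steps, warmup_proportion):
    """Hypothetical schedule: linear warmup to base_lr, then linear decay to 0."""
    warmup_steps = int(num_train_steps * warmup_proportion)
    if step < warmup_steps:
        # Warmup phase: scale the rate by the fraction of warmup completed.
        return base_lr * step / warmup_steps
    # Decay phase: scale by the fraction of post-warmup steps remaining.
    remaining = max(num_train_steps - step, 0)
    return base_lr * remaining / max(num_train_steps - warmup_steps, 1)


# Defaults from this file: learning_rate=5e-5, warmup_proportion=0.1.
# The total of 1000 training steps is an arbitrary value for illustration.
lr_mid_warmup = linear_warmup_decay(50, 5e-5, 1000, 0.1)
lr_peak = linear_warmup_decay(100, 5e-5, 1000, 0.1)
lr_mid_decay = linear_warmup_decay(550, 5e-5, 1000, 0.1)
lr_end = linear_warmup_decay(1000, 5e-5, 1000, 0.1)
```

With `warmup_proportion=0.1`, the first 10% of steps are spent ramping up, so very short fine-tuning runs spend only a handful of steps at the peak rate.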
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""args for retrieval task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from utils.args import ArgumentGroup
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("run_random", bool, False, "run model with random params")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Init pre-training params to perform fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument is ignored.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
model_g.add_arg("save_checkpoints", bool, True, "Whether to save checkpoints")
model_g.add_arg("weight_sharing", bool, True, "If set, share weights between word embedding and masked lm.")
model_g.add_arg("unimo_vocab_file", str, './model_files/dict/unimo_en.vocab.txt', "unimo vocab")
model_g.add_arg("encoder_json_file", str, './model_files/dict/unimo_en.encoder.json', 'bpe map')
model_g.add_arg("vocab_bpe_file", str, './model_files/dict/unimo_en.vocab.bpe', "vocab bpe")
model_g.add_arg("unimo_config_path", str, "./model_files/config/unimo_base_en.json",
                "Path to the unimo configuration file.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train.")
train_g.add_arg("learning_rate_scale", float, 0.1, "Learning rate decay scale.")
train_g.add_arg("lr_scheduler", str, "scale_by_epoch_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay', 'scale_by_epoch_decay'])
train_g.add_arg("learning_rate_decay_epoch1", int, 24, "Learning rate decay epoch1.")
train_g.add_arg("learning_rate_decay_epoch2", int, 32, "Learning rate decay epoch2.")
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_step", int, 1, "warmup_step, 1 for scale_by_epoch_decay, 0 for others")
train_g.add_arg("save_steps", int, 10000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000, "The steps interval to evaluate model performance.")
train_g.add_arg("nccl_comm_num", int, 1, "NCCL comm num.")
train_g.add_arg("hierarchical_allreduce_inter_nranks", int, 8, "Hierarchical allreduce inter ranks.")
train_g.add_arg("use_hierarchical_allreduce", bool, False, "Use hierarchical allreduce or not.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("use_dynamic_loss_scaling", bool, False, "Whether to use dynamic loss scaling.")
train_g.add_arg("use_sigmoid", bool, True, "Whether to use sigmoid before loss")
train_g.add_arg("init_loss_scaling", float, 1.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive steps with finite gradients.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
                "The less-than-one-multiplier to use when decreasing the loss scaling.")
train_g.add_arg("beta1", float, 0.9, "beta1 for adam")
train_g.add_arg("beta2", float, 0.98, "beta2 for adam.")
train_g.add_arg("epsilon", float, 1e-06, "epsilon for adam.")
train_g.add_arg("use_fuse", bool, False, "Whether to use fuse_allreduce_ops.")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log.")
log_g.add_arg("eval_dir", str, "", "eval_dir to save tmp data")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("samples_num", int, 20, "neg sample num.")
data_g.add_arg("train_image_caption", str, None, "Path to training data.")
data_g.add_arg("train_image_feature_dir", str, None, "data dir to training data.")
data_g.add_arg("test_image_caption", str, None, "Path to test data.")
data_g.add_arg("test_image_feature_dir", str, None, "data dir to test data.")
data_g.add_arg("dev_image_caption", str, None, "Path to validation data.")
data_g.add_arg("dev_image_feature_dir", str, None, "data dir to validation data.")
data_g.add_arg("img_id_path", str, None, "img_id_path.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.")
data_g.add_arg("test_batch_size", int, 24, "Total examples' number in batch for testing.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("random_seed", int, 0, "Random seed.")
data_g.add_arg("max_img_len", int, 37, "Max length of the image feature sequence.")
data_g.add_arg("scale_circle", float, 1.0, "The scale factor in circle loss function, only used in circle loss mode")
data_g.add_arg("margin", float, 0.2, "The margin value in loss function")
data_g.add_arg("max_neg_cap_num", int, 0, "max_neg_cap_num")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 10, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_test", bool, True, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("use_multi_gpu_test", bool, False, "Whether to perform evaluation using multiple gpu cards")
run_type_g.add_arg("eval_mertrics", str, "recall@k", "eval_mertrics")
# yapf: enable
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""args for seq2seq generation"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import argparse
from utils.args import ArgumentGroup
class CustomAction(argparse.Action):
    """Join the values of a multi-value option into one space-separated string."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, " ".join(values))
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Init pre-training params to perform fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument is ignored.")
model_g.add_arg("checkpoints", str, "checkpoints", "Path to save checkpoints.")
model_g.add_arg("weight_sharing", bool, True, "If set, share weights between word embedding and masked lm.")
model_g.add_arg("unimo_vocab_file", str, './model_files/dict/unimo_en.vocab.txt', "unimo vocab")
model_g.add_arg("encoder_json_file", str, './model_files/dict/unimo_en.encoder.json', 'bpe map')
model_g.add_arg("vocab_bpe_file", str, './model_files/dict/unimo_en.vocab.bpe', "vocab bpe")
model_g.add_arg("unimo_config_path", str, "./model_files/config/unimo_base_en.json",
                "Path to the unimo configuration file.")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 50, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 4e-5, "Learning rate used to train with warmup.")
train_g.add_arg("lr_scheduler", str, "linear_warmup_decay",
"scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
train_g.add_arg("weight_decay", float, 0.01, "Weight decay rate for L2 regularizer.")
train_g.add_arg("warmup_proportion", float, 0.02,
"Proportion of training steps to perform linear learning rate warmup for.")
train_g.add_arg("save_steps", int, 100000, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 100000, "The steps interval to evaluate model performance.")
train_g.add_arg("use_fuse", bool, False, "Whether to use fuse_allreduce_ops.")
train_g.add_arg("nccl_comm_num", int, 1, "NCCL comm num.")
train_g.add_arg("hierarchical_allreduce_inter_nranks", int, 8, "Hierarchical allreduce inter ranks.")
train_g.add_arg("use_hierarchical_allreduce", bool, False, "Use hierarchical allreduce or not.")
train_g.add_arg("use_fp16", bool, False, "Whether to use fp16 mixed precision training.")
train_g.add_arg("use_dynamic_loss_scaling", bool, False, "Whether to use dynamic loss scaling.")
train_g.add_arg("init_loss_scaling", float, 128.0,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
train_g.add_arg("incr_every_n_steps", int, 100, "Increases loss scaling every n consecutive steps with finite gradients.")
train_g.add_arg("decr_every_n_nan_or_inf", int, 2,
"Decreases loss scaling every n accumulated steps with nan or inf gradients.")
train_g.add_arg("incr_ratio", float, 2.0,
"The multiplier to use when increasing the loss scaling.")
train_g.add_arg("decr_ratio", float, 0.8,
                "The less-than-one-multiplier to use when decreasing the loss scaling.")
train_g.add_arg("beta1", float, 0.9, "beta1 for adam")
train_g.add_arg("beta2", float, 0.98, "beta2 for adam.")
train_g.add_arg("epsilon", float, 1e-06, "epsilon for adam.")
train_g.add_arg("tgt_type_id", int, 1, "for seq2seq task.")
train_g.add_arg("do_decode", bool, False, "for seq2seq task.")
train_g.add_arg("label_smooth", float, 0.1, "label smooth")
train_g.add_arg("hidden_dropout_prob", float, 0.1, "hidden_dropout_prob")
train_g.add_arg("attention_probs_dropout_prob", float, 0.1, "attention_probs_dropout_prob")
log_g = ArgumentGroup(parser, "logging", "logging related.")
log_g.add_arg("skip_steps", int, 100, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, True, "Whether to output verbose log.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("task_type", str, "normal", "is task type")
data_g.add_arg("train_set", str, None, "Path to training data.")
data_g.add_arg("test_set", str, None, "Path to test data.")
data_g.add_arg("dev_set", str, None, "Path to validation data.")
data_g.add_arg("pred_set", str, None, "Path to pred data.")
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest sequence.")
data_g.add_arg("max_tgt_len", int, 512, "for seq2seq task.")
data_g.add_arg("max_src_len", int, 512, "for seq2seq task.")
data_g.add_arg("max_out_len", int, 512, "for seq2seq task.")
data_g.add_arg("min_out_len", int, 20, "for seq2seq task.")
data_g.add_arg("block_trigram", bool, True, "utilize trigram blocking during beam search")
data_g.add_arg("beam_size", int, 5, "for seq2seq task.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training. see also --in_tokens.")
data_g.add_arg("pred_batch_size", int, 0, "Total examples' number in batch for prediction. see also --in_tokens.")
data_g.add_arg("in_tokens", bool, False,
"If set, the batch size will be the maximum number of tokens in one batch. "
"Otherwise, it will be the maximum number of examples in one batch.")
data_g.add_arg("do_lower_case", bool, True,
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
data_g.add_arg("tokenized_input", bool, True, "input is tokenized")
data_g.add_arg("length_penalty", float, 0.6, "length_penalty")
data_g.add_arg("continuous_position", bool, False, "position is continuous")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("visualdl_log", bool, False, "If set, use visualdl_log on paddlecloud.")
run_type_g.add_arg("is_distributed", bool, True, "If set, then start distributed training.")
run_type_g.add_arg("use_fast_executor", bool, True, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 1, "Iteration intervals to drop scope.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_test", bool, True, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("do_pred", bool, True, "Whether to perform evaluation on pred data set.")
run_type_g.add_arg("use_multi_gpu_test", bool, True, "Whether to perform evaluation using multiple gpu cards")
run_type_g.add_arg("save_and_valid_by_epoch", bool, False, "save_and_valid_by_epoch")
run_type_g.add_arg("eval_script", action=CustomAction, type=str, nargs='+', help="eval_script", default=None)
run_type_g.add_arg("eval_mertrics", str, "", "eval_mertrics")
run_type_g.add_arg("random_seed", int, 0, "Random seed.")
dialo_g = ArgumentGroup(parser, "dialogue", "for dialogue task.")
dialo_g.add_arg("role_type_size", int, 2, "role type size")
dialo_g.add_arg("turn_type_size", int, 16, "turn type size")
# yapf: enable
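The `eval_script` option is registered with `nargs='+'` and `CustomAction`, so several command-line tokens are stored as one space-separated string. The sketch below shows that round trip. `utils.args.ArgumentGroup` is not shown in this section, so the `ArgumentGroup` and `str2bool` definitions here are hypothetical stand-ins that only mirror how `add_arg` is called, and `eval.sh` is a made-up script name.

```python
import argparse


def str2bool(value):
    """Parse boolean flags from text; plain bool('false') would be True."""
    return value.lower() in ("true", "t", "1")


class ArgumentGroup(object):
    """Hypothetical stand-in for utils.args.ArgumentGroup."""

    def __init__(self, parser, title, des):
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, type, default, help, **kwargs):
        # Assumption: bool options are parsed from strings such as "true".
        type = str2bool if type == bool else type
        self._group.add_argument("--" + name, type=type, default=default,
                                 help=help, **kwargs)


class CustomAction(argparse.Action):
    """Join the values of a multi-value option into one string."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, " ".join(values))


parser = argparse.ArgumentParser()
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
run_type_g.add_arg("eval_script", action=CustomAction, type=str, nargs="+",
                   help="eval_script", default=None)

args = parser.parse_args(["--do_train", "false", "--eval_script", "bash", "eval.sh"])
command = args.eval_script.split(" ")  # the list form a subprocess call expects
```

When `--eval_script` is omitted, the value stays `None`, so any consumer that calls `.split(" ")` on it needs to check for that case first.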
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils help and eval functions for text generation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import math
import subprocess
from model.tokenization import GptBpeTokenizer, BasicTokenizer
class GenerationEval(object):
    """GenerationEval"""

    def __init__(self, args):
        # eval_script defaults to None, so guard before splitting it.
        self.eval_script = args.eval_script.split(" ") if args.eval_script else []
        self.eval_mertrics = args.eval_mertrics.split(",") if args.eval_mertrics else []
        self.basic_tokenizer = BasicTokenizer(do_lower_case=True)
        self.roberta_tokenizer = GptBpeTokenizer(vocab_file=args.unimo_vocab_file,
                                                 encoder_json_file=args.encoder_json_file,
                                                 vocab_bpe_file=args.vocab_bpe_file,
                                                 do_lower_case=True)

    def eval(self, output_file, phase="", features=None):
        """run eval"""
        eval_res = {}
        if self.eval_script:
            # The external script is expected to print a JSON dict of metric values.
            eval_res = subprocess.check_output(self.eval_script + [output_file, phase])
            eval_res = json.loads(eval_res)
        else:
            preds = []
            with open(output_file) as fin:
                for line in fin:
                    preds.append(self.basic_tokenizer.tokenize(line.strip()))

            refs = []
            for feature_id in sorted(features.keys()):
                ref_str = self.roberta_tokenizer.gptbpe_tokenizer.decode(features[feature_id].tgt.split(" "))
                refs.append([self.basic_tokenizer.tokenize(ref_str)])

            for metric in self.eval_mertrics:
                eval_func = getattr(self, metric, None)
                if eval_func:
                    eval_res[metric] = eval_func(refs, preds)

        ret = []
        for metric in self.eval_mertrics:
            metric_res = eval_res.get(metric, None)
            if metric_res is None:
                raise Exception("Eval metric: %s is not supported" % metric)
            ret.append("%s: %f" % (metric, metric_res))

        return ", ".join(ret)

    def bleu(self, refs, preds):
        """bleu metric"""
        return _compute_bleu(refs, preds, max_order=4)[0]
def _get_ngrams(segment, max_order):
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram = tuple(segment[i: i + order])
            ngram_counts[ngram] += 1
    return ngram_counts
def _compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    reference_length = 0
    translation_length = 0
    for (references, translation) in zip(reference_corpus, translation_corpus):
        reference_length += min(len(r) for r in references)
        translation_length += len(translation)

        merged_ref_ngram_counts = collections.Counter()
        for reference in references:
            merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
        translation_ngram_counts = _get_ngrams(translation, max_order)
        overlap = translation_ngram_counts & merged_ref_ngram_counts
        for ngram in overlap:
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for order in range(1, max_order + 1):
            possible_matches = len(translation) - order + 1
            if possible_matches > 0:
                possible_matches_by_order[order - 1] += possible_matches

    precisions = [0] * max_order
    for i in range(0, max_order):
        if smooth:
            precisions[i] = ((matches_by_order[i] + 1.) /
                             (possible_matches_by_order[i] + 1.))
        else:
            if possible_matches_by_order[i] > 0:
                precisions[i] = (float(matches_by_order[i]) /
                                 possible_matches_by_order[i])
            else:
                precisions[i] = 0.0

    if min(precisions) > 0:
        p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
        geo_mean = math.exp(p_log_sum)
    else:
        geo_mean = 0

    ratio = float(translation_length) / reference_length

    if ratio > 1.0:
        bp = 1.
    else:
        bp = math.exp(1 - 1. / (ratio + 1e-4))

    bleu = geo_mean * bp
    ret = [bleu, precisions, bp, ratio, translation_length, reference_length]
    return ret
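The BLEU code above relies on two `collections.Counter` operators: `|` merges the n-gram counts of several references by taking the per-n-gram maximum, and `&` clips the translation counts to that maximum. Below is a small worked example of that clipping step on its own. The helper `ngram_counts` and the toy sentences are illustrative names and data, not part of the repository.

```python
import collections


def ngram_counts(tokens, max_order):
    """Count every n-gram of order 1..max_order in a token list."""
    counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(len(tokens) - order + 1):
            counts[tuple(tokens[i:i + order])] += 1
    return counts


references = ["the cat sat".split(), "the the cat".split()]
hypothesis = "the the the cat".split()

# Union (|) keeps, per n-gram, the highest count seen in any reference.
merged_ref = collections.Counter()
for ref in references:
    merged_ref |= ngram_counts(ref, 2)

# Intersection (&) clips each hypothesis count to that reference maximum.
overlap = ngram_counts(hypothesis, 2) & merged_ref

unigram_matches = sum(c for gram, c in overlap.items() if len(gram) == 1)
bigram_matches = sum(c for gram, c in overlap.items() if len(gram) == 2)
```

Here the hypothesis repeats "the" three times, but no reference contains it more than twice, so only two of those occurrences count: 3 of 4 unigrams and 2 of 3 bigrams match.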
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils help and eval functions for glue."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from scipy.stats import pearsonr, spearmanr
from six.moves import xrange
import paddle.fluid as fluid
from functools import partial
from collections import OrderedDict
def matthews_corrcoef(preds, labels):
    """matthews_corrcoef"""
    preds = np.array(preds)
    labels = np.array(labels)
    tp = np.sum((labels == 1) & (preds == 1))
    tn = np.sum((labels == 0) & (preds == 0))
    fp = np.sum((labels == 0) & (preds == 1))
    fn = np.sum((labels == 1) & (preds == 0))

    mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    ret = OrderedDict()
    ret['mat_cor'] = mcc
    ret['key_eval'] = "mat_cor"
    return ret


def f1_score(preds, labels):
    """f1_score"""
    preds = np.array(preds)
    labels = np.array(labels)

    tp = np.sum((labels == 1) & (preds == 1))
    tn = np.sum((labels == 0) & (preds == 0))
    fp = np.sum((labels == 0) & (preds == 1))
    fn = np.sum((labels == 1) & (preds == 0))
    p = tp / (tp + fp)
    r = tp / (tp + fn)
    f1 = (2 * p * r) / (p + r + 1e-8)
    ret = OrderedDict()
    ret['f1'] = f1
    ret['key_eval'] = "f1"
    return ret
def pearson_and_spearman(preds, labels):
    """pearson_and_spearman"""
    preds = np.array(preds)
    labels = np.array(labels)
    pearson_corr = pearsonr(preds, labels)[0]
    spearman_corr = spearmanr(preds, labels)[0]
    ret = OrderedDict()
    ret['pearson'] = pearson_corr
    ret['spearmanr'] = spearman_corr
    ret['p_and_sp'] = (pearson_corr + spearman_corr) / 2
    ret['key_eval'] = "p_and_sp"
    return ret


def acc_and_f1(preds, labels):
    """acc_and_f1"""
    preds = np.array(preds)
    labels = np.array(labels)
    acc = simple_accuracy(preds, labels)['acc']
    f1 = f1_score(preds, labels)['f1']
    ret = OrderedDict()
    ret['acc'] = acc
    ret['f1'] = f1
    ret['acc_and_f1'] = (acc + f1) / 2
    ret['key_eval'] = "acc_and_f1"
    return ret


def simple_accuracy(preds, labels):
    """simple_accuracy"""
    preds = np.array(preds)
    labels = np.array(labels)
    acc = (preds == labels).mean()
    ret = OrderedDict()
    ret['acc'] = acc
    ret['key_eval'] = "acc"
    return ret
def evaluate_mrr(preds):
    """Mean reciprocal rank over (qid, score, label) rows grouped by qid."""
    last_qid = None
    total_mrr = 0.0
    qnum = 0.0
    rank = 0.0
    correct = False
    for qid, score, label in preds:
        if qid != last_qid:
            # A new query starts: reset the per-query state.
            rank = 0.0
            qnum += 1
            correct = False
            last_qid = qid
        rank += 1
        if not correct and label != 0:
            # Only the first relevant row of each query contributes.
            total_mrr += 1.0 / rank
            correct = True
    return total_mrr / qnum


def evaluate_map(preds):
    """Mean average precision over (qid, score, label) rows grouped by qid."""
    def single_map(st, en):
        """Average precision for the rows preds[st:en] of one query."""
        total_p = 0.0
        correct_num = 0.0
        for index in xrange(st, en):
            if int(preds[index][2]) != 0:
                correct_num += 1
                total_p += correct_num / (index - st + 1)
        if int(correct_num) == 0:
            return 0.0
        return total_p / correct_num

    last_qid = None
    total_map = 0.0
    qnum = 0.0
    st = 0
    for i in xrange(len(preds)):
        qid = preds[i][0]
        if qid != last_qid:
            qnum += 1
            if last_qid is not None:
                total_map += single_map(st, i)
            st = i
            last_qid = qid
    # Account for the final query, which has no following boundary.
    total_map += single_map(st, len(preds))
    return total_map / qnum
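A small worked example may help to check the two ranking metrics above. The sketch below re-implements them in plain Python under the assumption that rows are `(qid, score, label)` tuples, already grouped by `qid` and sorted by descending score within each query. The function names and the toy data are hypothetical and not part of the repository.

```python
def mean_reciprocal_rank(preds):
    """MRR: average of 1/rank of the first relevant row of each query."""
    total, qnum, rank, found, last_qid = 0.0, 0, 0, False, None
    for qid, _score, label in preds:
        if qid != last_qid:
            rank, found, last_qid = 0, False, qid
            qnum += 1
        rank += 1
        if not found and label != 0:
            total += 1.0 / rank
            found = True
    return total / qnum


def mean_average_precision(preds):
    """MAP: mean over queries of the average precision at each relevant row."""
    groups = {}
    for qid, _score, label in preds:
        groups.setdefault(qid, []).append(label)
    ap_sum = 0.0
    for labels in groups.values():
        hits, precision_sum = 0, 0.0
        for position, label in enumerate(labels, start=1):
            if label != 0:
                hits += 1
                precision_sum += hits / position
        ap_sum += precision_sum / hits if hits else 0.0
    return ap_sum / len(groups)


toy = [
    ("q1", 0.9, 0), ("q1", 0.8, 1),                  # first hit at rank 2
    ("q2", 0.7, 1), ("q2", 0.4, 0), ("q2", 0.1, 1),  # hits at ranks 1 and 3
]
print(mean_reciprocal_rank(toy))    # (1/2 + 1/1) / 2
print(mean_average_precision(toy))  # (1/2 + (1/1 + 2/3) / 2) / 2
```

Unlike the functions above, the MAP sketch groups rows with a dictionary, so it also tolerates rows of one query that are not contiguous.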
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility and evaluation functions for image/text retrieval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from collections import OrderedDict
def recall_at_k(score_matrix, text2img, img2texts):
    """recall@k

    score_matrix: rows of (score, image_id, caption_id), one per image-caption pair.
    """
    assert score_matrix.shape[0] == len(text2img) * len(img2texts)
    cur_img, cur_cap = score_matrix[:, 1], score_matrix[:, 2]
    img_len, cap_len = len(np.unique(cur_img)), len(np.unique(cur_cap))
    cur_img_sort = np.reshape(np.argsort(cur_img), [-1, cap_len])
    cur_cap_sort = np.reshape(np.argsort(cur_cap), [-1, img_len])
    i2c = np.take(score_matrix, cur_img_sort, axis=0)  # img_len x cap_len x 3
    c2i = np.take(score_matrix, cur_cap_sort, axis=0)  # cap_len x img_len x 3

    def get_recall_k(scores, idx, label_dict):
        """
        scores: sample x len x 3
        idx: 1 means text retrieval(i2c), 2 means image retrieval(c2i)
        """
        cand_idx_dict = {1: 2, 2: 1}
        cand_idx = cand_idx_dict[idx]
        tot = scores.shape[0]
        r1, r5, r10, rank_tot = 0, 0, 0, 0
        for i in range(tot):
            score_mat = scores[i]
            cur_ids = score_mat[0][idx]
            ans_ids = label_dict[cur_ids]  # when idx is 1, type is list. idx is 2, type is int
            score = score_mat[:, 0]
            score_sort = np.argsort(score)[::-1]
            cand_ans = np.take(score_mat[:, cand_idx], score_sort, axis=0)
            cand_ans = cand_ans.astype(np.int64)
            # Zero-based rank of the best-ranked correct candidate, as a plain int.
            if isinstance(ans_ids, list):
                rank = min(int(np.where(cand_ans == ans)[0][0]) for ans in ans_ids)
            elif isinstance(ans_ids, int):
                rank = int(np.where(cand_ans == ans_ids)[0][0])
            else:
                raise ValueError('type error')
            if rank < 1:
                r1 += 1.0
            if rank < 5:
                r5 += 1.0
            if rank < 10:
                r10 += 1.0
            rank_tot += (rank + 1)
        ret = {
            'recall@1': float(r1) / tot,
            'recall@5': float(r5) / tot,
            'recall@10': float(r10) / tot,
            'avg_rank': float(rank_tot) / tot
        }
        return ret

    cap_retrieval_recall = get_recall_k(i2c, 1, img2texts)
    img_retrieval_recall = get_recall_k(c2i, 2, text2img)
    ret = OrderedDict()
    ret['img_avg_rank'] = img_retrieval_recall['avg_rank']
    ret['cap_avg_rank'] = cap_retrieval_recall['avg_rank']
    ret['img_recall@1'] = img_retrieval_recall['recall@1']
    ret['img_recall@5'] = img_retrieval_recall['recall@5']
    ret['img_recall@10'] = img_retrieval_recall['recall@10']
    ret['cap_recall@1'] = cap_retrieval_recall['recall@1']
    ret['cap_recall@5'] = cap_retrieval_recall['recall@5']
    ret['cap_recall@10'] = cap_retrieval_recall['recall@10']
    ret['avg_img_recall'] = (img_retrieval_recall['recall@1'] +
                             img_retrieval_recall['recall@5'] +
                             img_retrieval_recall['recall@10']) / 3
    ret['avg_cap_recall'] = (cap_retrieval_recall['recall@1'] +
                             cap_retrieval_recall['recall@5'] +
                             cap_retrieval_recall['recall@10']) / 3
    ret['avg_recall@1'] = (img_retrieval_recall['recall@1'] + cap_retrieval_recall['recall@1']) / 2
    ret['avg_recall@5'] = (img_retrieval_recall['recall@5'] + cap_retrieval_recall['recall@5']) / 2
    ret['avg_recall@10'] = (img_retrieval_recall['recall@10'] + cap_retrieval_recall['recall@10']) / 2
    ret['key_eval'] = "avg_recall@1"
    return ret
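The ranking idea behind the recall@k computation above can be illustrated without the flattened `(score, image_id, caption_id)` layout. The sketch below is a simplified stand-in, not the repository's function. It assumes one row of scores per query and a set of correct candidate indices per query; the name `toy_recall_at_k` and the data are hypothetical.

```python
def toy_recall_at_k(score_rows, gold_sets, ks=(1, 5, 10)):
    """score_rows[i][j] is the score of candidate j for query i;
    gold_sets[i] is the set of correct candidate indices for query i."""
    hits = {k: 0 for k in ks}
    for scores, gold in zip(score_rows, gold_sets):
        # Candidate indices ordered from highest to lowest score.
        order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        # Zero-based rank of the best-ranked correct candidate.
        best_rank = min(order.index(g) for g in gold)
        for k in ks:
            if best_rank < k:
                hits[k] += 1
    n = len(score_rows)
    return {"recall@%d" % k: hits[k] / n for k in ks}


scores = [[0.1, 0.9, 0.3],   # query 0: candidate 1 ranks first
          [0.8, 0.2, 0.5]]   # query 1: candidate 2 ranks second
gold = [{1}, {2}]
print(toy_recall_at_k(scores, gold, ks=(1, 2)))
```

As in the code above, a query counts as a hit at k when its best-ranked correct candidate has a zero-based rank below k.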
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model for classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange
import paddle.fluid as fluid
from model.unimo_finetune import UNIMOModel
from eval import glue_eval
from collections import OrderedDict
from utils.utils import print_eval_log
def create_model(args, pyreader_name, config):
"""create_model"""
stype = 'int64'
pyreader = fluid.layers.py_reader(
capacity=50,
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, args.max_seq_len], [-1, 1],
[-1, 1]],
dtypes=[stype, stype, stype, 'float32', stype, stype],
lod_levels=[0, 0, 0, 0, 0, 0],
name=pyreader_name,
use_double_buffer=True)
(src_ids, sent_ids, pos_ids, input_mask, labels,
qids) = fluid.layers.read_file(pyreader)
emb_ids = {"word_embedding": src_ids, "sent_embedding": sent_ids, "pos_embedding": pos_ids}
model = UNIMOModel(
emb_ids=emb_ids,
input_mask=input_mask,
config=config)
cls_feats = model.get_pooled_text_output()
cls_feats = fluid.layers.dropout(
x=cls_feats,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
cls_params_name = ["cls_out_%d_w" % args.num_labels, "cls_out_%d_b" % args.num_labels]
logits = fluid.layers.fc(
input=cls_feats,
size=args.num_labels,
param_attr=fluid.ParamAttr(
name=cls_params_name[0],
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name=cls_params_name[1], initializer=fluid.initializer.Constant(0.)))
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=labels, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
num_seqs = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(input=probs, label=labels, total=num_seqs)
graph_vars = {
"loss": loss,
"probs": probs,
"accuracy": accuracy,
"labels": labels,
"num_seqs": num_seqs,
"qids": qids
}
return pyreader, graph_vars
def predict(exe, test_program, test_pyreader, graph_vars, dev_count=1):
"""predict"""
qids, scores, probs, preds = [], [], [], []
fetch_list = [graph_vars["probs"].name, graph_vars["qids"].name]
test_pyreader.start()
while True:
try:
if dev_count == 1:
np_probs, np_qids = exe.run(program=test_program, fetch_list=fetch_list)
else:
np_probs, np_qids = exe.run(fetch_list=fetch_list)
qids.extend(np_qids.reshape(-1).tolist())
np_preds = np.argmax(np_probs, axis=1).astype(np.float32)
preds.extend(np_preds)
probs.append(np_probs)
except fluid.core.EOFException:
test_pyreader.reset()
break
probs = np.concatenate(probs, axis=0).reshape([len(qids), -1])
return qids, preds, probs
def evaluate(args, exe, test_program, test_pyreader, graph_vars, eval_phase):
"""evaluate"""
total_cost, total_num_seqs = 0.0, 0.0
qids, labels, scores, preds = [], [], [], []
time_begin = time.time()
fetch_list = [
graph_vars["loss"].name,
graph_vars["probs"].name, graph_vars["labels"].name,
graph_vars["num_seqs"].name, graph_vars["qids"].name
]
test_pyreader.start()
while True:
try:
np_loss, np_probs, np_labels, np_num_seqs, np_qids = exe.run(
program=test_program, fetch_list=fetch_list) \
if not args.use_multi_gpu_test else exe.run(fetch_list=fetch_list)
total_cost += np.sum(np_loss * np_num_seqs)
total_num_seqs += np.sum(np_num_seqs)
labels.extend(np_labels.reshape((-1)).tolist())
if np_qids is not None:
qids.extend(np_qids.reshape(-1).tolist())
scores.extend(np_probs[:, 1].reshape(-1).tolist())
np_preds = list(np.argmax(np_probs, axis=1).astype(np.float32))
preds.extend([float(val) for val in np_preds])
except fluid.core.EOFException:
test_pyreader.reset()
break
time_end = time.time()
ret = OrderedDict()
ret['phase'] = eval_phase
ret['loss'] = round(total_cost / total_num_seqs, 4)
ret['data_num'] = total_num_seqs
ret['used_time'] = round(time_end - time_begin, 4)
metrics = OrderedDict()
metrics["acc_and_f1"] = glue_eval.acc_and_f1
metrics["simple_accuracy"] = glue_eval.simple_accuracy
metrics["matthews_corrcoef"] = glue_eval.matthews_corrcoef
if args.eval_mertrics in metrics:
ret_metric = metrics[args.eval_mertrics](preds, labels)
ret.update(ret_metric)
print_eval_log(ret)
else:
raise ValueError('unsupported metric {}'.format(args.eval_mertrics))
return ret
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model for regression."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from six.moves import xrange
import paddle.fluid as fluid
from model.unimo_finetune import UNIMOModel
from eval import glue_eval
from collections import OrderedDict
from utils.utils import print_eval_log
def create_model(args, pyreader_name, config):
"""create_model"""
stype = 'int64'
pyreader = fluid.layers.py_reader(
capacity=50,
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
[-1, args.max_seq_len, 1], [-1, args.max_seq_len, args.max_seq_len], [-1, 1],
[-1, 1]],
dtypes=[stype, stype, stype, 'float32', 'float32', stype],
lod_levels=[0, 0, 0, 0, 0, 0],
name=pyreader_name,
use_double_buffer=True)
(src_ids, sent_ids, pos_ids, input_mask, labels,
qids) = fluid.layers.read_file(pyreader)
emb_ids = {"word_embedding": src_ids, "sent_embedding": sent_ids, "pos_embedding": pos_ids}
model = UNIMOModel(
emb_ids=emb_ids,
input_mask=input_mask,
config=config)
cls_feats = model.get_pooled_text_output()
cls_feats = fluid.layers.dropout(
x=cls_feats,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
cls_params_name = ["cls_out_%d_w" % args.num_labels, "cls_out_%d_b" % args.num_labels]
logits = fluid.layers.fc(
input=cls_feats,
size=args.num_labels,
param_attr=fluid.ParamAttr(
name=cls_params_name[0],
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name=cls_params_name[1], initializer=fluid.initializer.Constant(0.)))
cost = fluid.layers.square_error_cost(input=logits, label=labels)
loss = fluid.layers.mean(x=cost)
num_seqs = fluid.layers.create_tensor(dtype='int64')
graph_vars = {
"loss": loss,
"probs": logits,
"labels": labels,
"num_seqs": num_seqs,
"qids": qids
}
return pyreader, graph_vars
def predict(exe, test_program, test_pyreader, graph_vars, dev_count=1):
"""predict"""
qids, scores, probs, preds = [], [], [], []
fetch_list = [graph_vars["probs"].name, graph_vars["qids"].name]
test_pyreader.start()
while True:
try:
if dev_count == 1:
np_probs, np_qids = exe.run(program=test_program, fetch_list=fetch_list)
else:
np_probs, np_qids = exe.run(fetch_list=fetch_list)
qids.extend(np_qids.reshape(-1).tolist())
np_preds = np.argmax(np_probs, axis=1).astype(np.float32)
preds.extend(np_preds)
probs.append(np_probs)
except fluid.core.EOFException:
test_pyreader.reset()
break
probs = np.concatenate(probs, axis=0).reshape([len(qids), -1])
return qids, preds, probs
def evaluate(args, exe, test_program, test_pyreader, graph_vars, eval_phase):
"""evaluate"""
qids, labels, scores = [], [], []
time_begin = time.time()
fetch_list = [
graph_vars["loss"].name, graph_vars["probs"].name,
graph_vars["labels"].name, graph_vars["qids"].name
]
test_pyreader.start()
while True:
try:
np_loss, np_probs, np_labels, np_qids = exe.run(
program=test_program, fetch_list=fetch_list) \
if not args.use_multi_gpu_test else exe.run(fetch_list=fetch_list)
labels.extend(np_labels.reshape((-1)).tolist())
if np_qids is not None:
qids.extend(np_qids.reshape(-1).tolist())
scores.extend(np_probs.reshape((-1)).tolist())
except fluid.core.EOFException:
test_pyreader.reset()
break
time_end = time.time()
ret = OrderedDict()
ret['phase'] = eval_phase
ret['loss'] = -1 # placeholder
ret['data_num'] = -1 # placeholder
ret['used_time'] = round(time_end - time_begin, 4)
metrics = OrderedDict()
metrics["pearson_and_spearman"] = glue_eval.pearson_and_spearman
if args.eval_mertrics in metrics:
ret_metric = metrics[args.eval_mertrics](scores, labels)
ret.update(ret_metric)
print_eval_log(ret)
else:
raise ValueError('unsupported metric {}'.format(args.eval_mertrics))
return ret
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model for retrieval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import time
import codecs
import numpy as np
import paddle.fluid as fluid
from eval import img_eval
from collections import OrderedDict
from utils.utils import print_eval_log
from model.unimo_finetune import UNIMOModel
def circle_loss(sp, sn, m, scale):
    """
    sp: score list of positive samples, shape [B * L]
    sn: score list of negative samples, shape [B * K]
    m: relaxation factor in circle loss function
    scale: scale factor in circle loss function
    return: circle loss value, shape [1]
    """
    op = 1. + m
    on = 0. - m
    delta_p = 1 - m
    delta_n = m
    # The weighting factors ap / an are treated as constants in back-propagation.
    ap = fluid.layers.relu(op - sp)
    ap.stop_gradient = True
    an = fluid.layers.relu(sn - on)
    an.stop_gradient = True
    logit_p = ap * (sp - delta_p)
    logit_p = -1. * scale * logit_p
    logit_p = fluid.layers.cast(x=logit_p, dtype=np.float64)
    loss_p = fluid.layers.reduce_sum(fluid.layers.exp(logit_p), dim=1, keep_dim=False)
    logit_n = an * (sn - delta_n)
    logit_n = scale * logit_n
    logit_n = fluid.layers.cast(x=logit_n, dtype=np.float64)
    loss_n = fluid.layers.reduce_sum(fluid.layers.exp(logit_n), dim=1, keep_dim=False)
    loss = fluid.layers.log(1 + loss_n * loss_p)
    loss = fluid.layers.cast(x=loss, dtype=np.float32)
    return fluid.layers.mean(loss)
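To make the arithmetic of the circle loss above easier to follow, here is a framework-free sketch for a single sample. It uses only the standard library and mirrors the formula term by term. The name `circle_loss_single` and the example scores are assumptions made for illustration, and there is no gradient handling here.

```python
import math


def circle_loss_single(sp, sn, m, scale):
    """sp / sn: lists of positive / negative similarity scores for one sample."""
    delta_p, delta_n = 1.0 - m, m
    # ap = relu((1 + m) - s) weights positives that are still far from 1.
    sum_p = sum(math.exp(-scale * max(1.0 + m - s, 0.0) * (s - delta_p)) for s in sp)
    # an = relu(s - (-m)) weights negatives that are still far from 0.
    sum_n = sum(math.exp(scale * max(s + m, 0.0) * (s - delta_n)) for s in sn)
    return math.log(1.0 + sum_n * sum_p)


well_separated = circle_loss_single([1.0], [0.0], m=0.25, scale=32)
poorly_separated = circle_loss_single([0.2], [0.8], m=0.25, scale=32)
print(well_separated, poorly_separated)
```

With a positive score of 1 and a negative score of 0 the loss is close to zero, and it grows quickly once the negative outscores the positive.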
def create_model(args, phase, config, samples_num):
"""create_model"""
input_mask_shape = [-1, args.max_img_len + args.max_seq_len, args.max_img_len + args.max_seq_len]
src_ids = fluid.layers.data(name='src_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
pos_ids = fluid.layers.data(name='pos_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
sent_ids = fluid.layers.data(name='sent_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
input_mask = fluid.layers.data(name='input_mask', shape=input_mask_shape, dtype='float32')
image_embedding = fluid.layers.data(
name='image_embedding',
shape=[-1, args.max_img_len, config["image_embedding_size"]],
dtype='float32')
image_loc = fluid.layers.data(name='image_loc', shape=[-1, args.max_img_len, 5], dtype='float32')
labels = fluid.layers.data(name='labels', shape=[-1, 1], dtype='int64')
ids = fluid.layers.data(name='ids', shape=[-1, 2], dtype='int64')
drop_last = True if phase == 'train' else False
feed_list = [src_ids, pos_ids, sent_ids, input_mask, image_embedding, image_loc, labels, ids]
pyreader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=70,
use_double_buffer=True,
iterable=False,
drop_last=drop_last)
emb_ids = {"word_embedding": src_ids, "sent_embedding": sent_ids, "pos_embedding": pos_ids}
image_input = {"image_embedding": image_embedding, "loc_embedding": image_loc}
model = UNIMOModel(
emb_ids=emb_ids,
input_mask=input_mask,
config=config,
image_input=image_input,
weight_sharing=args.weight_sharing
)
text, image = model.get_pooled_output()
score = model.get_match_output(text, image, mode="mul")
score = fluid.layers.fc(
input=score,
size=1,
act=None,
param_attr=fluid.ParamAttr(
name='match_fc.w_0',
initializer=fluid.initializer.Xavier()),
bias_attr=fluid.ParamAttr(name='match_fc.b_0',
initializer=fluid.initializer.UniformInitializer()))
score = fluid.layers.reshape(score, [-1, samples_num])
if phase == 'train':
if args.use_sigmoid:
score = fluid.layers.sigmoid(score)
positive_score = score[:, 0]
image_neg_score = score[:, 1:int((samples_num + 1) / 2)]
caption_neg_score = score[:, int((samples_num + 1) / 2):]
acc = fluid.layers.accuracy(score, labels, k=1)
positive_score = fluid.layers.reshape(x=positive_score, shape=[-1, 1])
loss_c = circle_loss(positive_score, caption_neg_score, args.margin, args.scale_circle)
loss_i = circle_loss(positive_score, image_neg_score, args.margin, args.scale_circle)
total_loss = (loss_c + loss_i) / 2
else:
assert samples_num == 1
total_loss = fluid.layers.cross_entropy(input=score, label=labels)
total_loss = fluid.layers.mean(x=total_loss)
acc = fluid.layers.zeros_like(total_loss)
graph_vars = {"loss": total_loss, "acc": acc, "score": score, "label": labels, "ids": ids}
return pyreader, graph_vars
def evaluate(args, exe, test_pyreader, graph_vars, eval_phase, dev_count=1, gpu_id=0, data_reader=None):
"""evaluate"""
test_pyreader.start()
time_begin = time.time()
all_mat = None
fetch_list = [graph_vars["score"].name, graph_vars["ids"].name]
while True:
try:
score, ids = exe.run(fetch_list=fetch_list)
mat = np.concatenate([score, ids], axis=1)
if all_mat is None:
all_mat = mat
else:
all_mat = np.concatenate([all_mat, mat], axis=0)
except fluid.core.EOFException:
test_pyreader.reset()
break
time_end = time.time()
save_file = "%s/%s.trainers_%d.part_%d.npy" % (args.eval_dir, eval_phase, dev_count, gpu_id)
np.save(save_file, all_mat)
tmp_file = "%s/%s.trainers_%d.part_%d.finish" % (args.eval_dir, eval_phase, dev_count, gpu_id)
tmp_writer = codecs.open(tmp_file, "w", 'utf-8')
tmp_writer.close()
if gpu_id == 0:
while True:
ret = os.popen('find %s -maxdepth 1 -name "%s.trainers_%d.part_*.finish"' %
(args.eval_dir, eval_phase, dev_count)).readlines()
if len(ret) != dev_count:
time.sleep(1)
continue
else:
break
all_mat = None
save_files = glob.glob("%s/%s.trainers_%d.part_*.npy" % (args.eval_dir, eval_phase, dev_count))
for cur_save_file in save_files:
mat = np.load(cur_save_file)
if all_mat is None:
all_mat = mat
else:
all_mat = np.concatenate([all_mat, mat], axis=0)
cur_time = str(int(time.time()))
os.system("mkdir %s/%s" % (args.eval_dir, cur_time))
os.system("mv %s/%s.trainers_%d.* %s/%s" % (args.eval_dir, eval_phase, dev_count, args.eval_dir, cur_time))
assert data_reader is not None
text2img = {text_id: item[-1] for text_id, item in data_reader._caption_ids_dict.items()}
img2texts = data_reader._image_sent_map
ret = OrderedDict()
ret['phase'] = eval_phase
ret['loss'] = -1
ret['data_num'] = all_mat.shape[0]
ret['used_time'] = round(time_end - time_begin, 4)
metrics = OrderedDict()
metrics["recall@k"] = img_eval.recall_at_k
if args.eval_mertrics in metrics:
ret_metric = metrics[args.eval_mertrics](all_mat, text2img, img2texts)
ret.update(ret_metric)
print_eval_log(ret)
else:
raise ValueError('unsupported metric {}'.format(args.eval_mertrics))
return ret
else:
return None
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""trigram_blocking for sequence generation"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
class TrigramBlocking(object):
"""trigram blocking check data holder
"""
def __init__(self, init_token, roberta_tokenizer, beam_size, use_fp16=False):
"""Use the tokenizer to map sub-token ids back to full tokens.
The tokenizer itself cannot be passed into the network, so a per-id lookup
table (id2is_full_token) is built here and stored as a parameter instead.
"""
# => [N, T==0, 1]
self._alive_seq = fluid.layers.fill_constant_batch_size_like(
input=init_token,
shape=[-1, 0, 1],
dtype=init_token.dtype,
value=0)
self._cand_seq = fluid.layers.fill_constant_batch_size_like(
input=init_token,
shape=[-1, 0, beam_size],
dtype=init_token.dtype,
value=0)
self.beam_size = beam_size
self._dtype = "float32" if not use_fp16 else "float16"
_SHAPE_PLACEHOLDER = [10, beam_size]
self._delta_score_out = fluid.layers.create_parameter(shape=_SHAPE_PLACEHOLDER, dtype=self._dtype,
name="duplicated_trigram_blocking_delta_score_out")
self.tokenizer = roberta_tokenizer
id2is_full_token = self._build_id2is_full_token(self.tokenizer, self._dtype)
self._id2is_full_token = fluid.layers.create_parameter(
shape=id2is_full_token.shape,
dtype=self._dtype,
name="duplicated_trigram_blocking_id2is_full_token",
default_initializer=fluid.initializer.NumpyArrayInitializer(id2is_full_token))
def update_seq(self, new_step_id, gather_idx):
"""update alive sequence. need pre-gather the inner seq then concat the new step id"""
# new_step_id = fluid.layers.unsqueeze(new_step_id, axes=[1])
alive_seq = fluid.layers.gather(self._alive_seq, gather_idx)
# => [N, T==1, 1]
alive_seq = fluid.layers.concat([alive_seq, new_step_id], axis=1)
fluid.layers.assign(alive_seq, self._alive_seq)
return self._alive_seq
def expand_cand_seq(self, new_topk_indx):
"""expand the alive seq by concatenating the topk candidates"""
new_topk_indx = fluid.layers.unsqueeze(new_topk_indx, axes=[1]) # (batch_size, 1, beam_size)
cand_seq = fluid.layers.expand(self._alive_seq, expand_times=[1, 1, self.beam_size])
# => [N, T+1, beam_size]
expand_cand_seq = fluid.layers.concat([cand_seq, new_topk_indx], axis=1)
fluid.layers.assign(expand_cand_seq, self._cand_seq)
return self._cand_seq
@property
def alive_seq(self):
"""alive seq"""
return self._alive_seq
@property
def cand_seq(self):
"""candidate seq"""
return self._cand_seq
@property
def delta_score_out(self):
"""delta score out"""
return self._delta_score_out
@property
def id2is_full_token(self):
"""id->isfulltoken"""
return self._id2is_full_token
@staticmethod
def blocking_forward(cand_seq, id2is_full_token):
"""py_func can't be member function
run the trigram-blocking logic. return `delta-score` for every sequence.
for seq which has duplicated trigram, set delta-score = -inf,
else set delta-score = 0
in the outer, should do the `seq-score + delta-score` logic
cand_seq: shape = [N, T, beam_size]
Returns
---------
np.array, shape = [N, beam_size]
"""
_BLOCKING_DELTA = -65000.0 # stays within float16 range (most negative finite value is -65504)
_KEEP_DELTA = 0.0
cand_seq = np.array(cand_seq) # (batch_size, dec_len, beam_size)
cand_seq = np.transpose(cand_seq, axes=(0, 2, 1)) # (batch_size, beam_size, dec_len)
id2is_full_token = np.array(id2is_full_token)
def _sub_token_id2full_tokens(sub_token_ids):
full_tokens = []
for sub_token_id in sub_token_ids:
is_full_token = bool(id2is_full_token[sub_token_id])
if is_full_token or not full_tokens:
full_tokens.append([sub_token_id])
else:
pre_full_token = full_tokens[-1]
pre_full_token.append(sub_token_id)
full_tokens = ["-".join(map(str, full_token)) for full_token in full_tokens]
return full_tokens
_make_trigram_str = lambda trigram_tokens: "_".join(trigram_tokens)
delta_list = []
for beam_cand_ids in cand_seq:
delta_score = []
for one_seq_ids in beam_cand_ids:
sub_token_ids = one_seq_ids.reshape(-1)
tokens = _sub_token_id2full_tokens(sub_token_ids)
if len(tokens) <= 3:
delta_score.append(_KEEP_DELTA)
continue
# don't include the last trigram(checking self)!
trigrams = [_make_trigram_str(tokens[end - 3: end]) for end in range(3, len(tokens))]
trigrams_set = set(trigrams)
last_trigram = _make_trigram_str(tokens[-3:])
if last_trigram in trigrams_set:
# duplicated
delta_score.append(_BLOCKING_DELTA)
else:
delta_score.append(_KEEP_DELTA)
delta_list.append(delta_score)
return np.array(delta_list, dtype=id2is_full_token.dtype).reshape(cand_seq.shape[0], cand_seq.shape[1])
@staticmethod
def blocking_backward(*args):
"""blocking backward"""
raise ValueError("Impossible call backward.")
def _build_id2is_full_token(self, tokenizer, dtype):
vocab_sz = tokenizer.vocab_size()
is_full_token = [0.0] * vocab_sz
for token_id in range(vocab_sz):
token = tokenizer.convert_id_to_token(token_id)
token_str = tokenizer.gptbpe_tokenizer.decode_token(token)
if token_str.startswith(' '):
is_full_token[token_id] = 1.0
return np.array(is_full_token, dtype=dtype)
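The core check performed in `blocking_forward` above is whether the last trigram of a candidate sequence has already appeared earlier in that sequence. A minimal pure-Python sketch of that check follows. The function name `repeats_trigram` is hypothetical, and the sketch works on already-merged full tokens rather than sub-token ids.

```python
def repeats_trigram(tokens):
    """Return True if the final trigram of `tokens` already occurred earlier."""
    if len(tokens) <= 3:
        return False
    # Every trigram except the last one, which is the one being checked.
    earlier = {tuple(tokens[end - 3:end]) for end in range(3, len(tokens))}
    return tuple(tokens[-3:]) in earlier


print(repeats_trigram("the cat sat on the cat sat".split()))  # repeated trigram
print(repeats_trigram("the cat sat on the mat".split()))      # no repetition
```

In the class above, a sequence for which this check is true receives the blocking delta and any other sequence receives zero, and the caller adds that delta to the beam score.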
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
launch for multi process training
"""
import sys
import subprocess
import os
import copy
import argparse
from utils.args import ArgumentGroup, print_arguments
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
multip_g = ArgumentGroup(parser, "multiprocessing",
"start paddle training using multi-processing mode.")
multip_g.add_arg("node_ips", str, None,
"paddle trainer ips")
multip_g.add_arg("node_id", int, None,
"the trainer id of the node for multi-node distributed training.")
multip_g.add_arg("print_config", bool, True,
"print the config of multi-processing mode.")
multip_g.add_arg("current_node_ip", str, None,
"the ip of current node.")
multip_g.add_arg("split_log_path", str, "log",
"log path for each trainer.")
multip_g.add_arg("log_prefix", str, "",
"the prefix name of job log.")
multip_g.add_arg("nproc_per_node", int, 8,
"the number of process to use on each node.")
multip_g.add_arg("selected_gpus", str, "0,1,2,3,4,5,6,7",
"the gpus selected to use.")
multip_g.add_arg("training_script", str, None, "the program/script to be launched "
"in parallel followed by all the arguments", positional_arg=True)
multip_g.add_arg("training_script_args", str, None,
"training script args", positional_arg=True, nargs=argparse.REMAINDER)
# yapf: enable
def start_procs(args):
""" start_procs """
default_env = os.environ.copy()
node_id = args.node_id
print(args.node_ips)
node_ips = [x.strip() for x in args.node_ips.split(',')]
current_ip = args.current_node_ip
num_nodes = len(node_ips)
selected_gpus = [x.strip() for x in args.selected_gpus.split(',')]
selected_gpu_num = len(selected_gpus)
start_port = int(default_env['PADDLE_PORT'])
all_trainer_endpoints = ""
for ip in node_ips:
cur_port = start_port + 1
for i in range(args.nproc_per_node):
cur_port += 1
if all_trainer_endpoints != "":
all_trainer_endpoints += ","
all_trainer_endpoints += "%s:%d" % (ip, cur_port)
nranks = num_nodes * args.nproc_per_node
gpus_per_proc = args.nproc_per_node % selected_gpu_num
if gpus_per_proc == 0:
gpus_per_proc = selected_gpu_num // args.nproc_per_node
else:
gpus_per_proc = selected_gpu_num // args.nproc_per_node + 1
selected_gpus_per_proc = [selected_gpus[i:i + gpus_per_proc]
for i in range(0, len(selected_gpus), gpus_per_proc)]
if args.print_config:
print("all_trainer_endpoints: ", all_trainer_endpoints,
", node_id: ", node_id,
", current_ip: ", current_ip,
", num_nodes: ", num_nodes,
", node_ips: ", node_ips,
", gpus_per_proc: ", gpus_per_proc,
", selected_gpus_per_proc: ", selected_gpus_per_proc,
", nranks: ", nranks)
current_env = copy.copy(default_env)
procs = []
cmds = []
log_fns = []
cur_port = start_port + 1
for i in range(0, args.nproc_per_node):
trainer_id = node_id * args.nproc_per_node + i
cur_port += 1
current_env.update({
"FLAGS_selected_gpus": "%s" % ",".join([str(s) for s in selected_gpus_per_proc[i]]),
"PADDLE_TRAINER_ID": "%d" % trainer_id,
"PADDLE_CURRENT_ENDPOINT": "%s:%d" % (current_ip, cur_port),
"PADDLE_TRAINERS_NUM": "%d" % nranks,
"PADDLE_TRAINER_ENDPOINTS": all_trainer_endpoints,
"PADDLE_NODES_NUM": "%d" % num_nodes
})
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
else:
process = subprocess.Popen(cmd, env=current_env)
procs.append(process)
for i in range(len(procs)):
proc = procs[i]
proc.wait()
if len(log_fns) > 0:
log_fns[i].close()
if proc.returncode != 0:
raise subprocess.CalledProcessError(returncode=procs[i].returncode,
cmd=cmds[i])
else:
print("proc %d finished" % i)
print("run success")
def main(args):
""" main_func """
if args.print_config:
print_arguments(args)
start_procs(args)
if __name__ == "__main__":
launch_args = parser.parse_args()
main(launch_args)
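The way `start_procs` above derives the trainer endpoint list can be sketched on its own. The helper name `build_trainer_endpoints` and the example IPs and port are assumptions for illustration. In the launcher the base port comes from the `PADDLE_PORT` environment variable.

```python
def build_trainer_endpoints(node_ips, start_port, nproc_per_node):
    """Return the comma-separated "ip:port" list, one entry per trainer process."""
    endpoints = []
    for ip in node_ips:
        # Ports restart on every node; the first process gets start_port + 2.
        cur_port = start_port + 1
        for _ in range(nproc_per_node):
            cur_port += 1
            endpoints.append("%s:%d" % (ip, cur_port))
    return ",".join(endpoints)


print(build_trainer_endpoints(["10.0.0.1", "10.0.0.2"], 6170, 2))
```

Every node uses the same port range, so a process with local index `i` listens on `start_port + 2 + i` on whichever node it runs.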
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def pad_batch_data(insts,
                   pretraining_task='seq2seq',
                   pad_idx=1,
                   sent_b_starts=None,
                   return_pos=False,
                   return_input_mask=False,
                   return_max_len=False,
                   return_num_token=False,
                   return_seq_lens=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
    return_list += [inst_data.astype('int64').reshape([-1, max_len, 1])]
    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype('int64').reshape([-1, max_len, 1])]
    if return_input_mask:
        # Compare strings with ==; `is` tests identity and is unreliable here.
        if pretraining_task == 'seq2seq':
            assert sent_b_starts is not None, \
                "[FATAL] For seq2seq language model loss," \
                " sent_b_starts should not be None"
            # This is used to avoid attention on paddings and subsequent words.
            input_mask_data = np.zeros((inst_data.shape[0], max_len, max_len))
            for index, mask_data in enumerate(input_mask_data):
                start = sent_b_starts[index]
                end = len(insts[index])
                mask_data[:end, :start] = 1.0
                # Generate the lower triangular matrix using the slice of matrix
                b = np.tril(np.ones([end - start, end - start]), 0)
                mask_data[start:end, start:end] = b
            input_mask_data = np.array(input_mask_data).reshape([-1, max_len, max_len])
        else:
            # This is used to avoid attention on paddings.
            input_mask_data = np.array([[1] * len(inst) + [0] *
                                        (max_len - len(inst)) for inst in insts])
            input_mask_data = np.expand_dims(input_mask_data, axis=-1)
            # input_mask_data = np.matmul(input_mask_data, np.transpose(input_mask_data, (0, 2, 1)))
        return_list += [input_mask_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        num_token = 0
        for inst in insts:
            num_token += len(inst)
        return_list += [num_token]
    if return_seq_lens:
        seq_lens = np.array([len(inst) for inst in insts])
        return_list += [seq_lens.astype('int64').reshape([-1, 1])]
    return return_list if len(return_list) > 1 else return_list[0]
def pad_feature_data(data, pad_value=0.0, dtype="float32", return_mask=False, batch_image_size=None):
    """for image feature sequence padding"""
    # num box + 1, 1 for global feature
    max_len = max([len(item) for item in data])
    data_width = len(data[0][0])
    out_data = np.ones((len(data), max_len, data_width), dtype=dtype) * pad_value
    out_mask = np.zeros((len(data), max_len, 1), dtype=dtype)
    for i in range(len(data)):
        out_data[i, 0:len(data[i]), :] = data[i]
        # batch_image_size must be provided when return_mask is True.
        if return_mask and batch_image_size[i] > 1:
            out_mask[i, 0:len(data[i]), :] = 1.0
    if return_mask:
        return out_data, out_mask
    else:
        return out_data
def gen_seq2seq_mask(insts, sent_b_starts=None):
    """
    generate input mask for seq2seq
    """
    max_len = max(len(inst) for inst in insts)
    input_mask_data = np.zeros((len(insts), max_len, max_len))
    for index, mask_data in enumerate(input_mask_data):
        start = sent_b_starts[index]
        end = len(insts[index])
        mask_data[:end, :start] = 1.0
        # Generate the lower triangular matrix using the slice of matrix
        b = np.tril(np.ones([end - start, end - start]), 0)
        mask_data[start:end, start:end] = b
    input_mask_data = np.array(input_mask_data, dtype='float32').reshape([-1, max_len, max_len])
    return input_mask_data
if __name__ == "__main__":
    pass
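
The seq2seq mask built above lets every real token attend to the whole source segment, while target tokens (from `sent_b_starts` onward) attend only to earlier target positions and themselves. The sketch below reproduces that logic on a toy batch so the resulting pattern is visible. `seq2seq_mask` is a hypothetical standalone helper that takes sequence lengths directly; it is not a function from this repository.

```python
import numpy as np


def seq2seq_mask(lengths, sent_b_starts):
    """Build a [batch, max_len, max_len] mask; 1.0 means attention is allowed."""
    max_len = max(lengths)
    mask = np.zeros((len(lengths), max_len, max_len), dtype="float32")
    for i, (end, start) in enumerate(zip(lengths, sent_b_starts)):
        # Every real token may attend to the full source segment.
        mask[i, :end, :start] = 1.0
        # Target tokens attend causally within the target segment.
        size = end - start
        mask[i, start:end, start:end] = np.tril(np.ones((size, size)))
    return mask


if __name__ == "__main__":
    # Two sequences: length 4 with target starting at 2, length 3 with target at 1.
    toy = seq2seq_mask([4, 3], [2, 1])
    print(toy[0])
    print(toy[1])
```

Rows and columns beyond a sequence's real length stay zero, so padded positions neither attend nor are attended to.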