# Positional arguments:
#   $1 - name of the fine-tuning task; also the directory name used under ./data
#   $2 - directory of the initial model; the config, vocabulary and parameter
#        paths below are all built beneath it
finetuning_task="$1"
init_model_path="$2"

# Paths derived from the initial model directory.
init_model="${init_model_path}/params"
vocab_path="${init_model_path}/vocab.txt"
CONFIG_PATH="${init_model_path}/ernie_config.json"

# Data layout: train/dev/test TSV files directly inside ./data/<task>.
finetuning_data_path="./data/${finetuning_task}"
test_set="${finetuning_data_path}/test.tsv"
dev_set="${finetuning_data_path}/dev.tsv"
train_set="${finetuning_data_path}/train.tsv"


# Task-specific configuration.
#
# Dispatches on ${finetuning_task} and assigns the settings for that task.
# Several values (epoch, lr, batch_size) may hold comma-separated lists; how
# they are consumed is up to the caller (presumably a hyper-parameter sweep;
# confirm against the launcher script).
# Tasks whose data files do not follow the plain train/dev/test .tsv layout
# assign train_set/dev_set/test_set under ${finetuning_data_path} themselves.
# A task name that matches no entry leaves every variable untouched.
case "${finetuning_task}" in
MNLI)
    epoch=3
    lr='8e-5,1e-4'
    batch_size=16
    warmup=0.1
    weight_decay=0.1
    num_labels=3
    max_seq_len=256
    gpu_card=4
    # One training file, but comma-joined dev/test files from the m/ and mm/
    # sub-directories.
    train_set="${finetuning_data_path}/train.tsv"
    dev_set="${finetuning_data_path}/m/dev.tsv,${finetuning_data_path}/mm/dev.tsv"
    test_set="${finetuning_data_path}/m/test.tsv,${finetuning_data_path}/mm/test.tsv"
    ;;

QNLI)
    epoch=12
    lr='6e-5,8e-5,1e-4'
    batch_size=16
    warmup=0.1
    weight_decay=0.01
    gpu_card=4
    ;;

QQP)
    epoch=10
    lr='1e-4,1.25e-4,1.5e-4'
    batch_size=16
    warmup=0.1
    weight_decay=0.00
    gpu_card=4
    ;;

SST-2)
    epoch=12
    lr='6e-5,8e-5,1e-4'
    batch_size=32
    warmup=0.1
    weight_decay=0.0
    gpu_card=2
    ;;

CoLA)
    epoch='10,12,15'
    lr='3e-5,5e-5,8e-5'
    batch_size='16,32'
    warmup=0.1
    weight_decay=0.01
    num_labels=2
    metric='matthews_corrcoef'
    gpu_card=1
    ;;

RTE)
    epoch='10,15'
    lr='1e-4,1.25e-4,1.5e-4'
    batch_size='16,32'
    warmup=0.1
    weight_decay=0.1
    gpu_card=1
    ;;

MRPC)
    epoch='10,12,15'
    lr='1e-4,1.25e-4,1.5e-4'
    batch_size='16,32'
    warmup=0.1
    weight_decay=0.01
    has_fc='false'
    metric='acc_and_f1'
    gpu_card=1
    ;;

STS-B)
    epoch='10,12,15'
    lr='1e-4,1.25e-4,1.5e-4'
    batch_size='16,32'
    warmup=0.1
    weight_decay=0.1
    num_labels=1
    metric='pearson_and_spearman'
    is_regression='true'
    is_classify='false'
    gpu_card=1
    ;;

RACE)
    # The values below are the "high" settings; the trailing notes list the
    # alternatives for the other levels.
    level='high'           # {all, middle, high}
    epoch=5                # {all:4, middle:6, high:5}
    lr='8e-5,1e-4'         # {all:8e-5,1e-4, middle:1e-4,1.25e-4,1.5e-4, high:8e-5,1e-4}
    batch_size=4           # {all:4, middle:8, high:4}
    weight_decay=0.01      # {all:0.01,middle:0.1,high:0.01}
    warmup=0.1
    num_labels=4
    max_seq_len=512
    for_race='true'
    do_test='true'
    gpu_card=4
    # Data files are suffixed with the chosen level.
    train_set="${finetuning_data_path}/train-${level}.tsv"
    dev_set="${finetuning_data_path}/dev-${level}.tsv"
    test_set="${finetuning_data_path}/test-${level}.tsv"
    ;;

IMDB)
    epoch=3
    lr='8e-5,1e-4,1.25e-4'
    batch_size=8
    warmup=0.1
    weight_decay=0.1
    num_labels=2
    max_seq_len=512
    eval_span='true'
    gpu_card=4
    # CSV data; the test file is used for both dev and test.
    train_set="${finetuning_data_path}/train.csv"
    dev_set="${finetuning_data_path}/test.csv"
    test_set="${finetuning_data_path}/test.csv"
    ;;

AG)
    epoch=3
    lr='8e-5,1e-4,1.25e-4,1.5e-4'
    batch_size=8
    warmup=0.1
    weight_decay=0.0
    num_labels=4
    max_seq_len=512
    eval_span='true'
    gpu_card=4
    # CSV data; the test file is used for both dev and test.
    train_set="${finetuning_data_path}/train.csv"
    dev_set="${finetuning_data_path}/test.csv"
    test_set="${finetuning_data_path}/test.csv"
    ;;

SQuADv1)
    epoch=2
    lr='2.25e-4,2.5e-4,2.75e-4'
    batch_size=12
    warmup=0.1
    weight_decay=0.0
    max_seq_len=384
    scripts='run_mrc.py'
    gpu_card=4
    # JSON data; the dev file is used for both dev and test.
    train_set="${finetuning_data_path}/train.json"
    dev_set="${finetuning_data_path}/dev.json"
    test_set="${finetuning_data_path}/dev.json"
    ;;

SQuADv2)
    epoch=4
    lr='1.25e-4,1.5e-4'
    batch_size=12
    warmup=0.1
    weight_decay=0.0
    max_seq_len=384
    scripts='run_mrc.py'
    version_2='true'
    gpu_card=4
    # JSON data (v2.0 files); the dev file is used for both dev and test.
    train_set="${finetuning_data_path}/train-v2.0.json"
    dev_set="${finetuning_data_path}/dev-v2.0.json"
    test_set="${finetuning_data_path}/dev-v2.0.json"
    ;;
esac


