.run_ce.sh 989 字节
Newer Older
G
guosheng 已提交
1 2 3
#!/bin/bash

DATA_PATH=$HOME/.cache/paddle/dataset/wmt16
G
guosheng 已提交
4 5
if [ ! -d $DATA_PATH/en_10000.dict ] ; then
    python -c 'import paddle;paddle.dataset.wmt16.train(10000, 10000, "en")().next()'
G
guosheng 已提交
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
    tar -zxf $DATA_PATH/wmt16.tar.gz -C $DATA_PATH
fi

train(){
    python -u train.py \
        --src_vocab_fpath $DATA_PATH/en_10000.dict \
        --trg_vocab_fpath $DATA_PATH/de_10000.dict \
        --special_token '<s>' '<e>' '<unk>' \
        --train_file_pattern $DATA_PATH/wmt16/train \
        --val_file_pattern $DATA_PATH/wmt16/val \
        --use_token_batch True \
        --batch_size 2048 \
        --sort_type pool \
        --pool_size 10000 \
        --enable_ce True \
        weight_sharing False \
        pass_num 20 \
        dropout_seed 10
}

26 27 28 29 30 31 32 33
cudaid=${transformer:=0} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid

train | python _ce.py

cudaid=${transformer_m:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid

G
guosheng 已提交
34
train | python _ce.py