.run_ce.sh 775 字节
Newer Older
G
guosheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
#!/bin/bash
#
# WMT16 en->de training run: make sure the dataset is cached locally, run
# train.py with --enable_ce, and pipe the training output into _ce.py.

# Dataset cache directory (read by train() below). The data is prepared only
# when this directory does not exist yet, i.e. on the first run.
DATA_PATH=$HOME/.cache/paddle/dataset/wmt16
if [ ! -d "$DATA_PATH" ]; then
    # Build the wmt16 train reader (10000-word en/de vocabularies) and pull a
    # single sample from it. The whole program must stay ONE quoted word:
    # splitting it over a backslash-continued, indented line turns the tail
    # into a separate argument (sys.argv[1]) instead of code, so the reader is
    # never advanced. next() works on Python 2.6+ and 3, unlike gen.next().
    # NOTE(review): presumably advancing the reader is what fetches
    # wmt16.tar.gz and builds the *_10000.dict files used by train() —
    # confirm against paddle.dataset.wmt16.
    python -c 'import paddle; next(paddle.dataset.wmt16.train(10000, 10000, "en")())' ||
        { printf 'failed to prepare wmt16 data in %s\n' "$DATA_PATH" >&2; exit 1; }
    tar -zxf "$DATA_PATH/wmt16.tar.gz" -C "$DATA_PATH" ||
        { printf 'failed to extract %s\n' "$DATA_PATH/wmt16.tar.gz" >&2; exit 1; }
fi

#######################################
# Launch train.py on the WMT16 en->de data with fixed CE settings.
# Globals:   DATA_PATH (read) - dataset cache directory set at the top of
#            this script.
# Arguments: none
# Outputs:   whatever train.py prints; the caller pipes stdout to _ce.py.
# Returns:   train.py's exit status.
#######################################
train() {
    # -u: unbuffered Python output, so the log streams through the pipe as it
    # is produced. Every "$DATA_PATH/..." is quoted so a $HOME containing
    # spaces or glob characters still yields exactly one argument per path.
    # The special tokens are single-quoted so '<'/'>' are not redirections.
    # NOTE(review): weight_sharing / pass_num / dropout_seed are passed
    # WITHOUT a leading '--'. They are presumably trailing KEY VALUE config
    # overrides collected by train.py (e.g. an argparse REMAINDER argument),
    # not typos — confirm in train.py before turning them into flags.
    python -u train.py \
        --src_vocab_fpath "$DATA_PATH/en_10000.dict" \
        --trg_vocab_fpath "$DATA_PATH/de_10000.dict" \
        --special_token '<s>' '<e>' '<unk>' \
        --train_file_pattern "$DATA_PATH/wmt16/train" \
        --val_file_pattern "$DATA_PATH/wmt16/val" \
        --use_token_batch True \
        --batch_size 2048 \
        --sort_type pool \
        --pool_size 10000 \
        --enable_ce True \
        weight_sharing False \
        pass_num 20 \
        dropout_seed 10
}

# Stream the training output into _ce.py, which presumably extracts the
# continuous-evaluation (CE) metrics from the log (inferred from its name and
# --enable_ce; confirm against _ce.py).
# pipefail: without it the pipeline's (and thus the script's) exit status is
# only _ce.py's, so a crashed training run would still be reported as success.
set -o pipefail
train | python _ce.py