未验证 提交 b175b5e4 编写于 作者: 0 0YuanZhang0 提交者: GitHub

upgrade_dgu_api (#4413) (#4430)

上级 7ef02b05
...@@ -13,10 +13,6 @@ ...@@ -13,10 +13,6 @@
[submodule "PaddleSpeech/DeepSpeech"] [submodule "PaddleSpeech/DeepSpeech"]
path = PaddleSpeech/DeepSpeech path = PaddleSpeech/DeepSpeech
url = https://github.com/PaddlePaddle/DeepSpeech.git url = https://github.com/PaddlePaddle/DeepSpeech.git
[submodule "PaddleCV/PaddleDetection"]
path = PaddleCV/PaddleDetection
url = https://github.com/PaddlePaddle/PaddleDetection.git
branch = release/0.2
[submodule "PaddleSpeech/Parakeet"] [submodule "PaddleSpeech/Parakeet"]
path = PaddleSpeech/Parakeet path = PaddleSpeech/Parakeet
url = https://github.com/PaddlePaddle/Parakeet url = https://github.com/PaddlePaddle/Parakeet
Subproject commit f24275a46f225e6111e8650d70baece90a37f324
...@@ -145,7 +145,7 @@ batch_size: 一个batch内输入的样本个数 ...@@ -145,7 +145,7 @@ batch_size: 一个batch内输入的样本个数
do_lower_case: 是否进行大小写转换 do_lower_case: 是否进行大小写转换
random_seed: 随机种子设置 random_seed: 随机种子设置
use_cuda: 是否使用cuda, 如果是gpu训练时,设置成true use_cuda: 是否使用cuda, 如果是gpu训练时,设置成true
in_tokens: 是否采用in_tokens模式来计算batch_siz数量, 如果in_tokens为false, 则batch_size等于真实设置的batch_size大小, 如果in_tokens为true, 则batch_size=batch_size*max_seq_len,即按照token计数 in_tokens: false
do_save_inference_model: 是否保存inference model do_save_inference_model: 是否保存inference model
encable_ce: 是否开启ce encable_ce: 是否开启ce
``` ```
...@@ -213,9 +213,8 @@ python -u main.py \ ...@@ -213,9 +213,8 @@ python -u main.py \
--task_name=${TASK_NAME} \ --task_name=${TASK_NAME} \
--use_cuda=${use_cuda} \ --use_cuda=${use_cuda} \
--do_train=true \ --do_train=true \
--in_tokens=true \
--epoch=20 \ --epoch=20 \
--batch_size=4096 \ --batch_size=32 \
--do_lower_case=true \ --do_lower_case=true \
--data_dir="./data/input/data/atis/${TASK_NAME}" \ --data_dir="./data/input/data/atis/${TASK_NAME}" \
--bert_config_path="${BERT_BASE_PATH}/bert_config.json" \ --bert_config_path="${BERT_BASE_PATH}/bert_config.json" \
...@@ -236,7 +235,7 @@ python -u main.py \ ...@@ -236,7 +235,7 @@ python -u main.py \
#### windows环境下 #### windows环境下
``` ```
python -u main.py --task_name=atis_intent --use_cuda=false --do_train=true --in_tokens=true --epoch=20 --batch_size=4096 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --init_from_pretrain_model=data\pretrain_model\uncased_L-12_H-768_A-12\params --save_model_path=data\saved_models\atis_intent --save_param=params --save_steps=100 --learning_rate=2e-5 --weight_decay=0.01 --max_seq_len=128 --print_steps=10 python -u main.py --task_name=atis_intent --use_cuda=false --do_train=true --epoch=20 --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --init_from_pretrain_model=data\pretrain_model\uncased_L-12_H-768_A-12\params --save_model_path=data\saved_models\atis_intent --save_param=params --save_steps=100 --learning_rate=2e-5 --weight_decay=0.01 --max_seq_len=128 --print_steps=10
``` ```
### 模型预测 ### 模型预测
...@@ -292,8 +291,7 @@ python -u main.py \ ...@@ -292,8 +291,7 @@ python -u main.py \
--task_name=${TASK_NAME} \ --task_name=${TASK_NAME} \
--use_cuda=${use_cuda} \ --use_cuda=${use_cuda} \
--do_predict=true \ --do_predict=true \
--in_tokens=true \ --batch_size=32 \
--batch_size=4096 \
--do_lower_case=true \ --do_lower_case=true \
--data_dir="./data/input/data/atis/${TASK_NAME}" \ --data_dir="./data/input/data/atis/${TASK_NAME}" \
--init_from_params="./data/saved_models/trained_models/${TASK_NAME}/params" \ --init_from_params="./data/saved_models/trained_models/${TASK_NAME}/params" \
...@@ -307,7 +305,7 @@ python -u main.py \ ...@@ -307,7 +305,7 @@ python -u main.py \
#### windows环境下 #### windows环境下
``` ```
python -u main.py --task_name=atis_intent --use_cuda=false --do_predict=true --in_tokens=true --batch_size=4096 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --init_from_params=data\saved_models\trained_models\atis_intent\params --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --output_prediction_file=data\output\pred_atis_intent --max_seq_len=128 python -u main.py --task_name=atis_intent --use_cuda=false --do_predict=true --batch_size=32 --do_lower_case=true --data_dir=data\input\data\atis\atis_intent --init_from_params=data\saved_models\trained_models\atis_intent\params --bert_config_path=data\pretrain_model\uncased_L-12_H-768_A-12\bert_config.json --vocab_path=data\pretrain_model\uncased_L-12_H-768_A-12\vocab.txt --output_prediction_file=data\output\pred_atis_intent --max_seq_len=128
``` ```
### 模型评估 ### 模型评估
......
...@@ -71,7 +71,7 @@ def do_predict(args): ...@@ -71,7 +71,7 @@ def do_predict(args):
name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
input_mask = fluid.data( input_mask = fluid.data(
name='input_mask', name='input_mask',
shape=[-1, args.max_seq_len], shape=[-1, args.max_seq_len, 1],
dtype='float32') dtype='float32')
if args.task_name == 'atis_slot': if args.task_name == 'atis_slot':
labels = fluid.data( labels = fluid.data(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
export FLAGS_sync_nccl_allreduce=0 export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1 export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=1
if [ ! "$CUDA_VISIBLE_DEVICES" ] if [ ! "$CUDA_VISIBLE_DEVICES" ]
then then
export CPU_NUM=1 export CPU_NUM=1
...@@ -21,7 +21,7 @@ SAVE_MODEL_PATH="./data/saved_models/${TASK_NAME}" ...@@ -21,7 +21,7 @@ SAVE_MODEL_PATH="./data/saved_models/${TASK_NAME}"
TRAIN_MODEL_PATH="./data/saved_models/trained_models" TRAIN_MODEL_PATH="./data/saved_models/trained_models"
OUTPUT_PATH="./data/output" OUTPUT_PATH="./data/output"
INFERENCE_MODEL="data/inference_models" INFERENCE_MODEL="data/inference_models"
PYTHON_PATH="python" PYTHON_PATH="python3"
if [ -f ${SAVE_MODEL_PATH} ]; then if [ -f ${SAVE_MODEL_PATH} ]; then
rm ${SAVE_MODEL_PATH} rm ${SAVE_MODEL_PATH}
...@@ -37,8 +37,7 @@ then ...@@ -37,8 +37,7 @@ then
save_steps=1000 save_steps=1000
max_seq_len=210 max_seq_len=210
print_steps=1000 print_steps=1000
batch_size=6720 batch_size=32
in_tokens=true
epoch=2 epoch=2
learning_rate=2e-5 learning_rate=2e-5
elif [ "${TASK_NAME}" = "swda" ] elif [ "${TASK_NAME}" = "swda" ]
...@@ -46,8 +45,7 @@ then ...@@ -46,8 +45,7 @@ then
save_steps=500 save_steps=500
max_seq_len=128 max_seq_len=128
print_steps=200 print_steps=200
batch_size=6720 batch_size=32
in_tokens=true
epoch=3 epoch=3
learning_rate=2e-5 learning_rate=2e-5
elif [ "${TASK_NAME}" = "mrda" ] elif [ "${TASK_NAME}" = "mrda" ]
...@@ -55,8 +53,7 @@ then ...@@ -55,8 +53,7 @@ then
save_steps=500 save_steps=500
max_seq_len=128 max_seq_len=128
print_steps=200 print_steps=200
batch_size=4096 batch_size=32
in_tokens=true
epoch=7 epoch=7
learning_rate=2e-5 learning_rate=2e-5
elif [ "${TASK_NAME}" = "atis_intent" ] elif [ "${TASK_NAME}" = "atis_intent" ]
...@@ -64,8 +61,7 @@ then ...@@ -64,8 +61,7 @@ then
save_steps=100 save_steps=100
max_seq_len=128 max_seq_len=128
print_steps=10 print_steps=10
batch_size=4096 batch_size=32
in_tokens=true
epoch=20 epoch=20
learning_rate=2e-5 learning_rate=2e-5
INPUT_PATH="./data/input/data/atis/${TASK_NAME}" INPUT_PATH="./data/input/data/atis/${TASK_NAME}"
...@@ -75,7 +71,6 @@ then ...@@ -75,7 +71,6 @@ then
max_seq_len=128 max_seq_len=128
print_steps=10 print_steps=10
batch_size=32 batch_size=32
in_tokens=False
epoch=50 epoch=50
learning_rate=2e-5 learning_rate=2e-5
INPUT_PATH="./data/input/data/atis/${TASK_NAME}" INPUT_PATH="./data/input/data/atis/${TASK_NAME}"
...@@ -83,22 +78,23 @@ elif [ "${TASK_NAME}" = "dstc2" ] ...@@ -83,22 +78,23 @@ elif [ "${TASK_NAME}" = "dstc2" ]
then then
save_steps=400 save_steps=400
print_steps=20 print_steps=20
batch_size=8192
in_tokens=true
epoch=40 epoch=40
learning_rate=5e-5 learning_rate=5e-5
INPUT_PATH="./data/input/data/dstc2/${TASK_NAME}" INPUT_PATH="./data/input/data/dstc2/${TASK_NAME}"
if [ "${TASK_TYPE}" = "train" ] if [ "${TASK_TYPE}" = "train" ]
then then
max_seq_len=256 max_seq_len=256
batch_size=32
else else
max_seq_len=512 max_seq_len=512
batch_size=16
fi fi
else else
echo "not support ${TASK_NAME} dataset.." echo "not support ${TASK_NAME} dataset.."
exit 255 exit 255
fi fi
#training #training
function train() function train()
{ {
...@@ -106,7 +102,6 @@ function train() ...@@ -106,7 +102,6 @@ function train()
--task_name=${TASK_NAME} \ --task_name=${TASK_NAME} \
--use_cuda=$1 \ --use_cuda=$1 \
--do_train=true \ --do_train=true \
--in_tokens=${in_tokens} \
--epoch=${epoch} \ --epoch=${epoch} \
--batch_size=${batch_size} \ --batch_size=${batch_size} \
--do_lower_case=true \ --do_lower_case=true \
...@@ -130,7 +125,6 @@ function predict() ...@@ -130,7 +125,6 @@ function predict()
--task_name=${TASK_NAME} \ --task_name=${TASK_NAME} \
--use_cuda=$1 \ --use_cuda=$1 \
--do_predict=true \ --do_predict=true \
--in_tokens=${in_tokens} \
--batch_size=${batch_size} \ --batch_size=${batch_size} \
--data_dir=${INPUT_PATH} \ --data_dir=${INPUT_PATH} \
--do_lower_case=true \ --do_lower_case=true \
......
...@@ -67,7 +67,7 @@ def do_train(args): ...@@ -67,7 +67,7 @@ def do_train(args):
name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
input_mask = fluid.data( input_mask = fluid.data(
name='input_mask', name='input_mask',
shape=[-1, args.max_seq_len], shape=[-1, args.max_seq_len, 1],
dtype='float32') dtype='float32')
if args.task_name == 'atis_slot': if args.task_name == 'atis_slot':
labels = fluid.data( labels = fluid.data(
...@@ -80,8 +80,9 @@ def do_train(args): ...@@ -80,8 +80,9 @@ def do_train(args):
input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels] input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
input_field = InputField(input_inst) input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(
feed_list=input_inst, capacity=4, iterable=False) data_reader = fluid.io.DataLoader.from_generator(feed_list=input_inst, capacity=4, iterable=False)
processor = processors[task_name](data_dir=args.data_dir, processor = processors[task_name](data_dir=args.data_dir,
vocab_path=args.vocab_path, vocab_path=args.vocab_path,
max_seq_len=args.max_seq_len, max_seq_len=args.max_seq_len,
...@@ -108,10 +109,8 @@ def do_train(args): ...@@ -108,10 +109,8 @@ def do_train(args):
accuracy.persistable = True accuracy.persistable = True
num_seqs.persistable = True num_seqs.persistable = True
if args.use_cuda: places = fluid.cuda_places() if args.use_cuda else fluid.cpu_places()
dev_count = fluid.core.get_cuda_device_count() dev_count = len(places)
else:
dev_count = int(os.environ.get('CPU_NUM', 1))
batch_generator = processor.data_generator( batch_generator = processor.data_generator(
batch_size=args.batch_size, phase='train', shuffle=True) batch_size=args.batch_size, phase='train', shuffle=True)
...@@ -140,7 +139,7 @@ def do_train(args): ...@@ -140,7 +139,7 @@ def do_train(args):
use_fp16=False, use_fp16=False,
loss_scaling=args.loss_scaling) loss_scaling=args.loss_scaling)
data_reader.decorate_batch_generator(batch_generator) data_reader.set_batch_generator(batch_generator, places=places)
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册