Commit cfecaa8a authored by Yibing Liu

Merge branch 'develop' of upstream into ctc_decoder_deploy

@@ -15,6 +15,8 @@ python ./cloud/split_data.py \
 --in_manifest_path=${DEV_MANIFEST} \
 --out_manifest_path='/local.manifest.dev'
 
+mkdir ./logs
+
 python -u train.py \
 --batch_size=${BATCH_SIZE} \
 --trainer_count=${NUM_GPU} \
@@ -35,10 +37,10 @@ python -u train.py \
 --train_manifest='/local.manifest.train' \
 --dev_manifest='/local.manifest.dev' \
 --mean_std_path='data/librispeech/mean_std.npz' \
---vocab_path='data/librispeech/eng_vocab.txt' \
+--vocab_path='data/librispeech/vocab.txt' \
 --output_model_dir='./checkpoints' \
 --output_model_dir=${MODEL_PATH} \
 --augment_conf_path='conf/augmentation.config' \
 --specgram_type='linear' \
 --shuffle_method='batch_shuffle_clipped' \
-2>&1 | tee ./log/train.log
+2>&1 | tee ./logs/train.log
@@ -17,6 +17,7 @@ python -u train.py \
 --learning_rate=5e-4 \
 --max_duration=27.0 \
 --min_duration=0.0 \
+--test_off=False \
 --use_sortagrad=True \
 --use_gru=False \
 --use_gpu=True \
......
@@ -17,6 +17,7 @@ python -u train.py \
 --learning_rate=1e-5 \
 --max_duration=27.0 \
 --min_duration=0.0 \
+--test_off=False \
 --use_sortagrad=True \
 --use_gru=False \
 --use_gpu=True \
......
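
In both run scripts the new option arrives on the command line as the literal string `--test_off=False`. The `add_arg` helper that parses it is defined elsewhere in the repository and is not part of this diff; as a rough sketch only, assuming an argparse-based helper, the string needs an explicit conversion such as `distutils.util.strtobool`, because a plain `type=bool` would turn the non-empty string "False" into True:

# Hypothetical sketch, not the repository's actual add_arg helper: parse a
# string-valued boolean flag such as --test_off=False correctly.
import argparse
from distutils.util import strtobool

parser = argparse.ArgumentParser()
# type=bool would map the string "False" to True; strtobool maps it to 0/1.
parser.add_argument('--test_off', type=lambda s: bool(strtobool(s)),
                    default=False, help="Turn off testing.")

args = parser.parse_args(['--test_off=False'])
print(args.test_off)  # False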
@@ -60,7 +60,8 @@ class DeepSpeech2Model(object):
               num_passes,
               output_model_dir,
               is_local=True,
-              num_iterations_print=100):
+              num_iterations_print=100,
+              test_off=False):
         """Train the model.
 
         :param train_batch_reader: Train data reader.
@@ -83,6 +84,8 @@ class DeepSpeech2Model(object):
         :type is_local: bool
         :param output_model_dir: Directory for saving the model (every pass).
         :type output_model_dir: basestring
+        :param test_off: Turn off testing.
+        :type test_off: bool
         """
         # prepare model output directory
         if not os.path.exists(output_model_dir):
@@ -120,14 +123,19 @@ class DeepSpeech2Model(object):
                     start_time = time.time()
                     cost_sum, cost_counter = 0.0, 0
             if isinstance(event, paddle.event.EndPass):
-                result = trainer.test(
-                    reader=dev_batch_reader, feeding=feeding_dict)
+                if test_off:
+                    print("\n------- Time: %d sec, Pass: %d" %
+                          (time.time() - start_time, event.pass_id))
+                else:
+                    result = trainer.test(
+                        reader=dev_batch_reader, feeding=feeding_dict)
+                    print("\n------- Time: %d sec, Pass: %d, "
+                          "ValidationCost: %s" %
+                          (time.time() - start_time, event.pass_id,
+                           result.cost))
                 output_model_path = os.path.join(
                     output_model_dir, "params.pass-%d.tar.gz" % event.pass_id)
                 with gzip.open(output_model_path, 'w') as f:
                     self._parameters.to_tar(f)
-                print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
-                      (time.time() - start_time, event.pass_id, result.cost))
 
         # run train
         trainer.train(
......
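
The change in `DeepSpeech2Model.train()` above only affects the end-of-pass handling: validation is skipped when `test_off` is true, while a checkpoint is still written every pass. A stripped-down, self-contained sketch of that control flow (the validation call and checkpoint step here are stand-ins, not PaddlePaddle objects):

import time

def run_validation():
    # Stand-in for trainer.test(reader=dev_batch_reader, ...); returns a cost.
    return 12.34

def end_of_pass(pass_id, start_time, test_off=False):
    if test_off:
        print("\n------- Time: %d sec, Pass: %d" %
              (time.time() - start_time, pass_id))
    else:
        cost = run_validation()
        print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
              (time.time() - start_time, pass_id, cost))
    # A checkpoint (params.pass-<id>.tar.gz) is saved in both branches.

end_of_pass(pass_id=0, start_time=time.time(), test_off=True)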
#! /usr/bin/env bash

# Multi-GPU training speed benchmark: run a single pass of train.py on
# 16/8/4/2/1 GPUs with validation disabled (--test_off=True) and report the
# time per pass for each GPU count.

BATCH_SIZE_PER_GPU=64
MIN_DURATION=6.0
MAX_DURATION=7.0

# join_by , 0 1 2  ->  "0,1,2"
function join_by { local IFS="$1"; shift; echo "$*"; }

for NUM_GPUS in 16 8 4 2 1
do
  DEVICES=$(join_by , $(seq 0 $(($NUM_GPUS-1))))
  # Keep the per-GPU batch size fixed as the GPU count grows.
  BATCH_SIZE=$(($BATCH_SIZE_PER_GPU * $NUM_GPUS))

  CUDA_VISIBLE_DEVICES=$DEVICES \
  python train.py \
  --batch_size=$BATCH_SIZE \
  --num_passes=1 \
  --test_off=True \
  --trainer_count=$NUM_GPUS \
  --min_duration=$MIN_DURATION \
  --max_duration=$MAX_DURATION > tmp.log 2>&1

  if [ $? -ne 0 ]; then
    exit 1
  fi

  # Extract the per-pass time printed by the EndPass handler.
  grep "Time" tmp.log | awk '{print "GPU Num: " "'"$NUM_GPUS"'" " Time: "$3}'
  rm tmp.log
done
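
Each loop iteration prints one line of the form `GPU Num: <n> Time: <seconds>`. As a small companion sketch (not part of this commit), those lines can be piped into a short script to summarize scaling against the single-GPU run:

# Sketch only: read "GPU Num: N Time: T" lines from stdin and report the
# per-pass time and speedup relative to the 1-GPU baseline.
import re
import sys

results = {}  # number of GPUs -> seconds per pass
for line in sys.stdin:
    m = re.match(r"GPU Num: (\d+) Time: (\d+)", line)
    if m:
        results[int(m.group(1))] = float(m.group(2))

if 1 in results:
    base = results[1]
    for n in sorted(results):
        print("%2d GPU(s): %6.1f sec/pass, speedup x%.2f" %
              (n, results[n], base / results[n]))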
@@ -25,6 +25,7 @@ add_arg('num_iter_print', int, 100, "Every # iterations for printing "
 add_arg('learning_rate', float, 5e-4, "Learning rate.")
 add_arg('max_duration', float, 27.0, "Longest audio duration allowed.")
 add_arg('min_duration', float, 0.0, "Shortest audio duration allowed.")
+add_arg('test_off', bool, False, "Turn off testing.")
 add_arg('use_sortagrad', bool, True, "Use SortaGrad or not.")
 add_arg('use_gpu', bool, True, "Use GPU or not.")
 add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
@@ -111,7 +112,8 @@ def train():
         num_passes=args.num_passes,
         num_iterations_print=args.num_iter_print,
         output_model_dir=args.output_model_dir,
-        is_local=args.is_local)
+        is_local=args.is_local,
+        test_off=args.test_off)
 
 
 def main():
......