提交 5de8e435 编写于 作者: X Xinghai Sun

1. Fix incorrect decoder result printing.

2. Fix incorrect batch-norm usage in RNN.
3. Fix overlapping train/dev/test manifests.
4. Update README.md and requirements.txt.
5. Expose more arguments to users in argparser.
6. Update all other details.
上级 f6d820ed
# Deep Speech 2 on PaddlePaddle # Deep Speech 2 on PaddlePaddle
## Quick Start
### Installation
Please replace `$PADDLE_INSTALL_DIR` with your paddle installation directory.
```
pip install -r requirements.txt
export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH
```
For some machines, we also need to install libsndfile1. Details to be added.
### Preparing Dataset(s)
``` ```
sh requirements.sh
python librispeech.py python librispeech.py
python train.py
``` ```
Please add warp-ctc library path (usually $PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib) to LD_LIBRARY_PATH. More help for arguments:
```
python librispeech.py --help
```
### Training
For GPU Training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
```
For CPU Training:
```
python train.py --trainer_count 8 --use_gpu False
```
More help for arguments:
```
python train.py --help
```
### Inferencing
```
python infer.py
```
More help for arguments:
```
python infer.py --help
```
...@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) ...@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
class DataGenerator(object): class DataGenerator(object):
""" """
DataGenerator provides basic audio data preprocessing pipeline, and offer DataGenerator provides basic audio data preprocessing pipeline, and offers
both instance-level and batch-level data reader interfaces. both instance-level and batch-level data reader interfaces.
Normalized FFT are used as audio features here. Normalized FFT are used as audio features here.
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
import paddle.v2 as paddle import paddle.v2 as paddle
from itertools import groupby from itertools import groupby
import distutils.util
import argparse import argparse
import gzip import gzip
import audio_data_utils from audio_data_utils import DataGenerator
from model import deep_speech2 from model import deep_speech2
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -15,15 +16,42 @@ parser.add_argument( ...@@ -15,15 +16,42 @@ parser.add_argument(
"--num_samples", "--num_samples",
default=10, default=10,
type=int, type=int,
help="Number of samples for inference.") help="Number of samples for inference. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--num_conv_layers", default=2, type=int, help="Convolution layer number.") "--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
"--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument(
"--rnn_layer_size",
default=512,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
"--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--num_rnn_layers", default=3, type=int, help="RNN layer number.") "--normalizer_manifest_path",
default='./manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--rnn_layer_size", default=512, type=int, help="RNN layer cell number.") "--decode_manifest_path",
default='./manifest.libri.test-clean',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--use_gpu", default=True, type=bool, help="Use gpu or not.") "--model_filepath",
default='./params.tar.gz',
type=str,
help="Model filepath. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
...@@ -39,18 +67,27 @@ def remove_duplicate_and_blank(id_list, blank_id): ...@@ -39,18 +67,27 @@ def remove_duplicate_and_blank(id_list, blank_id):
return [id for id in id_list if id != blank_id] return [id for id in id_list if id != blank_id]
def max_infer(): def best_path_decode():
""" """
Max-ctc-decoding for DeepSpeech2. Max-ctc-decoding for DeepSpeech2.
""" """
# initialize data generator
data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt',
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
# create network config # create network config
_, vocab_list = audio_data_utils.get_vocabulary() dict_size = data_generator.vocabulary_size()
dict_size = len(vocab_list) vocab_list = data_generator.vocabulary_list()
audio_data = paddle.layer.data( audio_data = paddle.layer.data(
name="audio_spectrogram", name="audio_spectrogram",
height=161, height=161,
width=1000, width=2000,
type=paddle.data_type.dense_vector(161000)) type=paddle.data_type.dense_vector(322000))
text_data = paddle.layer.data( text_data = paddle.layer.data(
name="transcript_text", name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size)) type=paddle.data_type.integer_value_sequence(dict_size))
...@@ -64,19 +101,17 @@ def max_infer(): ...@@ -64,19 +101,17 @@ def max_infer():
# load parameters # load parameters
parameters = paddle.parameters.Parameters.from_tar( parameters = paddle.parameters.Parameters.from_tar(
gzip.open("params.tar.gz")) gzip.open(args.model_filepath))
# prepare infer data # prepare infer data
feeding = { feeding = data_generator.data_name_feeding()
"audio_spectrogram": 0, test_batch_reader = data_generator.batch_reader_creator(
"transcript_text": 1, manifest_path=args.decode_manifest_path,
} batch_size=args.num_samples,
test_batch_reader = audio_data_utils.padding_batch_reader( padding_to=2000,
paddle.batch( flatten=True,
audio_data_utils.reader_creator( sort_by_duration=False,
manifest_path="./libri.manifest.test", sort_by_duration=False), shuffle=False)
batch_size=args.num_samples),
padding=[-1, 1000])
infer_data = test_batch_reader().next() infer_data = test_batch_reader().next()
# run max-ctc-decoding # run max-ctc-decoding
...@@ -89,7 +124,7 @@ def max_infer(): ...@@ -89,7 +124,7 @@ def max_infer():
# postprocess # postprocess
instance_length = len(max_id_results) / args.num_samples instance_length = len(max_id_results) / args.num_samples
instance_list = [ instance_list = [
max_id_results[i:i + instance_length] max_id_results[i * instance_length:(i + 1) * instance_length]
for i in xrange(0, args.num_samples) for i in xrange(0, args.num_samples)
] ]
for i, instance in enumerate(instance_list): for i, instance in enumerate(instance_list):
...@@ -102,7 +137,7 @@ def max_infer(): ...@@ -102,7 +137,7 @@ def max_infer():
def main(): def main():
paddle.init(use_gpu=args.use_gpu, trainer_count=1) paddle.init(use_gpu=args.use_gpu, trainer_count=1)
max_infer() best_path_decode()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
from paddle.v2.dataset.common import md5file
import os import os
import wget import wget
import tarfile import tarfile
...@@ -14,11 +15,22 @@ import argparse ...@@ -14,11 +15,22 @@ import argparse
import soundfile import soundfile
import json import json
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache2/paddle/dataset/speech')
URL_TEST = "http://www.openslr.org/resources/12/test-clean.tar.gz" URL_ROOT = "http://www.openslr.org/resources/12"
URL_DEV = "http://www.openslr.org/resources/12/dev-clean.tar.gz" URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TRAIN = "http://www.openslr.org/resources/12/train-clean-100.tar.gz" URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_CLEAN_500 = "d1a0fd59409feb2c614ce4d30c387708"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Downloads and prepare LibriSpeech dataset.') description='Downloads and prepare LibriSpeech dataset.')
...@@ -26,27 +38,33 @@ parser.add_argument( ...@@ -26,27 +38,33 @@ parser.add_argument(
"--target_dir", "--target_dir",
default=DATA_HOME + "/Libri", default=DATA_HOME + "/Libri",
type=str, type=str,
help="Directory to save the dataset.") help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--manifest", "--manifest_prefix",
default="./libri.manifest", default="manifest.libri",
type=str, type=str,
help="Filepath prefix for output manifests.") help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
def download(url, target_dir): def download(url, md5sum, target_dir):
if not os.path.exists(target_dir): """
os.makedirs(target_dir) Download file from url to target_dir, and check md5sum.
"""
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1]) filepath = os.path.join(target_dir, url.split("/")[-1])
if not os.path.exists(filepath): if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url) print("Downloading %s ..." % url)
wget.download(url, target_dir) wget.download(url, target_dir)
print("") print("\nMD5 Chesksum %s ..." % filepath)
assert md5file(filepath) == md5sum, "MD5 checksum failed."
return filepath return filepath
def unpack(filepath, target_dir): def unpack(filepath, target_dir):
"""
Unpack the file to the target_dir.
"""
print("Unpacking %s ..." % filepath) print("Unpacking %s ..." % filepath)
tar = tarfile.open(filepath) tar = tarfile.open(filepath)
tar.extractall(target_dir) tar.extractall(target_dir)
...@@ -55,6 +73,14 @@ def unpack(filepath, target_dir): ...@@ -55,6 +73,14 @@ def unpack(filepath, target_dir):
def create_manifest(data_dir, manifest_path): def create_manifest(data_dir, manifest_path):
"""
Create a manifest file summarizing the dataset (list of filepath and meta
data).
Each line of the manifest contains one audio clip filepath, its
transcription text string, and its duration. Manifest file serves as a
unified interface to organize data sets.
"""
print("Creating manifest %s ..." % manifest_path) print("Creating manifest %s ..." % manifest_path)
json_lines = [] json_lines = []
for subfolder, _, filelist in os.walk(data_dir): for subfolder, _, filelist in os.walk(data_dir):
...@@ -81,25 +107,31 @@ def create_manifest(data_dir, manifest_path): ...@@ -81,25 +107,31 @@ def create_manifest(data_dir, manifest_path):
out_file.write(line + '\n') out_file.write(line + '\n')
def prepare_dataset(url, target_dir, manifest_path): def prepare_dataset(url, md5sum, target_dir, manifest_path):
filepath = download(url, target_dir) """
Download, unpack and create summary manifest file.
"""
filepath = download(url, md5sum, target_dir)
unpacked_dir = unpack(filepath, target_dir) unpacked_dir = unpack(filepath, target_dir)
create_manifest(unpacked_dir, manifest_path) create_manifest(unpacked_dir, manifest_path)
def main(): def main():
prepare_dataset( prepare_dataset(
url=URL_TEST, url=URL_TEST_CLEAN,
target_dir=os.path.join(args.target_dir), md5sum=MD5_TEST_CLEAN,
manifest_path=args.manifest + ".test") target_dir=os.path.join(args.target_dir, "test-clean"),
manifest_path=args.manifest_prefix + ".test-clean")
prepare_dataset( prepare_dataset(
url=URL_DEV, url=URL_DEV_CLEAN,
target_dir=os.path.join(args.target_dir), md5sum=MD5_DEV_CLEAN,
manifest_path=args.manifest + ".dev") target_dir=os.path.join(args.target_dir, "dev-clean"),
manifest_path=args.manifest_prefix + ".dev-clean")
prepare_dataset( prepare_dataset(
url=URL_TRAIN, url=URL_TRAIN_CLEAN_100,
target_dir=os.path.join(args.target_dir), md5sum=MD5_TRAIN_CLEAN_100,
manifest_path=args.manifest + ".train") target_dir=os.path.join(args.target_dir, "train-clean-100"),
manifest_path=args.manifest_prefix + ".train-clean-100")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -24,45 +24,23 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, ...@@ -24,45 +24,23 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
return paddle.layer.batch_norm(input=conv_layer, act=act) return paddle.layer.batch_norm(input=conv_layer, act=act)
def bidirectonal_simple_rnn_bn_layer(name, input, size, act): def bidirectional_simple_rnn_bn_layer(name, input, size, act):
""" """
Bidirectonal simple rnn layer with batch normalization. Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state projection The batch normalization is only performed on input-state weights.
(sequence-wise normalization).
Question: does mean and variance statistics computed over the whole sequence
or just on each individual time steps?
""" """
# input-hidden weights shared across bi-directional rnn.
def __simple_rnn_step__(input): input_proj = paddle.layer.fc(
last_state = paddle.layer.memory(name=name + "_state", size=size) input=input, size=size, act=paddle.activation.Linear(), bias_attr=False)
input_fc = paddle.layer.fc( # batch norm is only performed on input-state projection
input=input, input_proj_bn = paddle.layer.batch_norm(
size=size, input=input_proj, act=paddle.activation.Linear())
act=paddle.activation.Linear(), # forward and backward in time
bias_attr=False) forward_simple_rnn = paddle.layer.recurrent(
# batch norm is only performed on input-state projection input=input_proj_bn, act=act, reverse=False)
input_fc_bn = paddle.layer.batch_norm( backward_simple_rnn = paddle.layer.recurrent(
input=input_fc, act=paddle.activation.Linear()) input=input_proj_bn, act=act, reverse=True)
state_fc = paddle.layer.fc( return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn])
input=last_state,
size=size,
act=paddle.activation.Linear(),
bias_attr=False)
return paddle.layer.addto(
name=name + "_state", input=[input_fc_bn, state_fc], act=act)
forward = paddle.layer.recurrent_group(
step=__simple_rnn_step__, input=input)
return forward
# argument reverse is not exposed in V2 recurrent_group
#backward = paddle.layer.recurrent_group(
#step=__simple_rnn_step__,
#input=input,
#reverse=True)
#return paddle.layer.concat(input=[forward, backward])
def conv_group(input, num_stacks): def conv_group(input, num_stacks):
...@@ -86,7 +64,9 @@ def conv_group(input, num_stacks): ...@@ -86,7 +64,9 @@ def conv_group(input, num_stacks):
stride=(1, 2), stride=(1, 2),
padding=(5, 10), padding=(5, 10),
act=paddle.activation.BRelu()) act=paddle.activation.BRelu())
return conv output_num_channels = 32
output_height = 160 // pow(2, num_stacks) + 1
return conv, output_num_channels, output_height
def rnn_group(input, size, num_stacks): def rnn_group(input, size, num_stacks):
...@@ -95,7 +75,7 @@ def rnn_group(input, size, num_stacks): ...@@ -95,7 +75,7 @@ def rnn_group(input, size, num_stacks):
""" """
output = input output = input
for i in xrange(num_stacks): for i in xrange(num_stacks):
output = bidirectonal_simple_rnn_bn_layer( output = bidirectional_simple_rnn_bn_layer(
name=str(i), input=output, size=size, act=paddle.activation.BRelu()) name=str(i), input=output, size=size, act=paddle.activation.BRelu())
return output return output
...@@ -125,15 +105,16 @@ def deep_speech2(audio_data, ...@@ -125,15 +105,16 @@ def deep_speech2(audio_data,
:rtype: tuple of LayerOutput :rtype: tuple of LayerOutput
""" """
# convolution group # convolution group
conv_group_output = conv_group(input=audio_data, num_stacks=num_conv_layers) conv_group_output, conv_group_num_channels, conv_group_height = conv_group(
input=audio_data, num_stacks=num_conv_layers)
# convert data form convolution feature map to sequence of vectors # convert data form convolution feature map to sequence of vectors
conv2seq = paddle.layer.block_expand( conv2seq = paddle.layer.block_expand(
input=conv_group_output, input=conv_group_output,
num_channels=32, num_channels=conv_group_num_channels,
stride_x=1, stride_x=1,
stride_y=1, stride_y=1,
block_x=1, block_x=1,
block_y=21) block_y=conv_group_height)
# rnn group # rnn group
rnn_group_output = rnn_group( rnn_group_output = rnn_group(
input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers)
......
pip install wget
pip install soundfile
# For Ubuntu only
apt-get install libsndfile1
SoundFile==0.9.0.post1
wget==3.2
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
import distutils.util
import argparse import argparse
import gzip import gzip
import time import time
...@@ -17,21 +18,61 @@ parser = argparse.ArgumentParser( ...@@ -17,21 +18,61 @@ parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 trainer.') description='Simplified version of DeepSpeech2 trainer.')
parser.add_argument( parser.add_argument(
"--batch_size", default=32, type=int, help="Minibatch size.") "--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument("--trainer", default=1, type=int, help="Trainer number.")
parser.add_argument( parser.add_argument(
"--num_passes", default=20, type=int, help="Training pass number.") "--num_passes",
default=20,
type=int,
help="Training pass number. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--num_conv_layers", default=3, type=int, help="Convolution layer number.") "--num_conv_layers",
default=2,
type=int,
help="Convolution layer number. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--num_rnn_layers", default=5, type=int, help="RNN layer number.") "--num_rnn_layers",
default=3,
type=int,
help="RNN layer number. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--rnn_layer_size", default=512, type=int, help="RNN layer cell number.") "--rnn_layer_size",
default=512,
type=int,
help="RNN layer cell number. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--use_gpu", default=True, type=bool, help="Use gpu or not.") "--adam_learning_rate",
default=5e-4,
type=float,
help="Learning rate for ADAM Optimizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--use_sortagrad", default=False, type=bool, help="Use sortagrad or not.") "--use_gpu",
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--trainer_count", default=8, type=int, help="Trainer number.") "--use_sortagrad",
default=False,
type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--trainer_count",
default=4,
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='./manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--train_manifest_path",
default='./manifest.libri.train-clean-100',
type=str,
help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
"--dev_manifest_path",
default='./manifest.libri.dev-clean',
type=str,
help="Manifest path for validation. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
...@@ -39,37 +80,15 @@ def train(): ...@@ -39,37 +80,15 @@ def train():
""" """
DeepSpeech2 training. DeepSpeech2 training.
""" """
# create data readers # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt', vocab_filepath='eng_vocab.txt',
normalizer_manifest_path='./libri.manifest.train', normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200, normalizer_num_samples=200,
max_duration=20.0, max_duration=20.0,
min_duration=0.0, min_duration=0.0,
stride_ms=10, stride_ms=10,
window_ms=20) window_ms=20)
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.dev.small',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.dev.small',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.test',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
# create network config # create network config
dict_size = data_generator.vocabulary_size() dict_size = data_generator.vocabulary_size()
...@@ -92,28 +111,58 @@ def train(): ...@@ -92,28 +111,58 @@ def train():
# create parameters and optimizer # create parameters and optimizer
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=5e-5, gradient_clipping_threshold=400) learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
trainer = paddle.trainer.SGD( trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer) cost=cost, parameters=parameters, update_equation=optimizer)
# prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size // args.trainer_count,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size // args.trainer_count,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
# create event handler # create event handler
def event_handler(event): def event_handler(event):
global start_time global start_time
global cost_sum
global cost_counter
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0: cost_sum += event.cost
cost_counter += 1
if event.batch_id % 50 == 0:
print "\nPass: %d, Batch: %d, TrainCost: %f" % ( print "\nPass: %d, Batch: %d, TrainCost: %f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, cost_sum / cost_counter)
cost_sum, cost_counter = 0.0, 0
with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f)
else: else:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.BeginPass): if isinstance(event, paddle.event.BeginPass):
start_time = time.time() start_time = time.time()
cost_sum, cost_counter = 0.0, 0
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_batch_reader, feeding=feeding) result = trainer.test(reader=test_batch_reader, feeding=feeding)
print "\n------- Time: %d, Pass: %d, TestCost: %s" % ( print "\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % (
time.time() - start_time, event.pass_id, result.cost) time.time() - start_time, event.pass_id, result.cost)
with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f)
# run train # run train
# first pass with sortagrad # first pass with sortagrad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册