Commit f6d820ed, authored by Xinghai Sun

Refactor data utils into a class and add feature normalization.

Parent: f33f7420
audio_data_utils.py (the former module-level helpers spectrogram_from_file,
extract_spectrogram, vocabulary_from_file, get_vocabulary_size, get_vocabulary,
parse_transcript, reader_creator and padding_batch_reader are folded into the
new DataGenerator class below):

"""
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
import paddle.v2 as paddle
import logging
@@ -9,143 +10,201 @@ import soundfile
import numpy as np
import os

RANDOM_SEED = 0

logger = logging.getLogger(__name__)

class DataGenerator(object):
    """
    DataGenerator provides a basic audio data preprocessing pipeline, and
    offers both instance-level and batch-level data reader interfaces.
    Normalized FFT features are used as audio features here.

    :param vocab_filepath: Vocabulary file path for indexing tokenized
                           transcriptions.
    :type vocab_filepath: basestring
    :param normalizer_manifest_path: Manifest filepath for collecting feature
                                     normalization statistics, e.g. mean, std.
    :type normalizer_manifest_path: basestring
    :param normalizer_num_samples: Number of instances sampled for collecting
                                   feature normalization statistics.
                                   Default is 100.
    :type normalizer_num_samples: int
    :param max_duration: Audio clips with duration (in seconds) greater than
                         this will be discarded. Default is 20.0.
    :type max_duration: float
    :param min_duration: Audio clips with duration (in seconds) smaller than
                         this will be discarded. Default is 0.0.
    :type min_duration: float
    :param stride_ms: Stride size (in milliseconds) for generating frames.
                      Default is 10.0.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for frames.
                      Default is 20.0.
    :type window_ms: float
    :param max_frequency: Maximum frequency for FFT features. FFT features of
                          frequencies larger than this will be discarded.
                          If set to None, all features will be kept.
                          Default is None.
    :type max_frequency: float
    """

    def __init__(self,
                 vocab_filepath,
                 normalizer_manifest_path,
                 normalizer_num_samples=100,
                 max_duration=20.0,
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_frequency=None):
        self.__max_duration__ = max_duration
        self.__min_duration__ = min_duration
        self.__stride_ms__ = stride_ms
        self.__window_ms__ = window_ms
        self.__max_frequency__ = max_frequency
        self.__random__ = random.Random(RANDOM_SEED)
        # load vocabulary (dictionary)
        self.__vocab_dict__, self.__vocab_list__ = \
            self.__load_vocabulary_from_file__(vocab_filepath)
        # collect normalizer statistics
        self.__mean__, self.__std__ = self.__collect_normalizer_statistics__(
            manifest_path=normalizer_manifest_path,
            num_samples=normalizer_num_samples)

    def __audio_featurize__(self, audio_filename):
        """
        Preprocess audio data, including feature extraction, normalization etc..
        """
        features = self.__audio_basic_featurize__(audio_filename)
        return self.__normalize__(features)

    def __text_featurize__(self, text):
        """
        Preprocess text data, including tokenizing and token indexing etc..
        """
        return self.__convert_text_to_char_index__(
            text=text, vocabulary=self.__vocab_dict__)

    def __audio_basic_featurize__(self, audio_filename):
        """
        Compute basic (without normalization etc.) features for audio data.
        """
        return self.__spectrogram_from_file__(
            filename=audio_filename,
            stride_ms=self.__stride_ms__,
            window_ms=self.__window_ms__,
            max_freq=self.__max_frequency__)

    def __collect_normalizer_statistics__(self, manifest_path, num_samples=100):
        """
        Compute feature normalization statistics, i.e. mean and stddev.
        """
        # read manifest
        manifest = self.__read_manifest__(
            manifest_path=manifest_path,
            max_duration=self.__max_duration__,
            min_duration=self.__min_duration__)
        # sample for statistics
        sampled_manifest = self.__random__.sample(manifest, num_samples)
        # extract spectrogram feature
        features = []
        for instance in sampled_manifest:
            spectrogram = self.__audio_basic_featurize__(
                instance["audio_filepath"])
            features.append(spectrogram)
        features = np.hstack(features)
        mean = np.mean(features, axis=1).reshape([-1, 1])
        std = np.std(features, axis=1).reshape([-1, 1])
        return mean, std

    def __normalize__(self, features, eps=1e-14):
        """
        Normalize features to be of zero mean and unit stddev.
        """
        return (features - self.__mean__) / (self.__std__ + eps)
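
    # Note: the statistics are computed per frequency bin (axis=1), so
    # __mean__ and __std__ have shape [num_bins, 1] and broadcast over the
    # time axis. A toy z-score example (illustrative values only):
    #   features = np.array([[1., 3.], [10., 30.]])       # 2 bins x 2 frames
    #   mean = np.mean(features, axis=1).reshape([-1, 1])  # [[2.], [20.]]
    #   std = np.std(features, axis=1).reshape([-1, 1])    # [[1.], [10.]]
    #   (features - mean) / (std + 1e-14)  # -> [[-1., 1.], [-1., 1.]]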

    def __spectrogram_from_file__(self,
                                  filename,
                                  stride_ms=10.0,
                                  window_ms=20.0,
                                  max_freq=None,
                                  eps=1e-14):
        """
        Load audio data and calculate the log of spectrogram by FFT.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        audio, sample_rate = soundfile.read(filename)
        if audio.ndim >= 2:
            audio = np.mean(audio, 1)
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        spectrogram, freqs = self.__extract_spectrogram__(
            audio,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(spectrogram[:ind, :] + eps)

    def __extract_spectrogram__(self, samples, window_size, stride_size,
                                sample_rate):
        """
        Compute the spectrogram by FFT for a discrete real signal.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs
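
    # Note: the feature height used by the trainer below follows from these
    # window settings. Assuming 16 kHz audio (LibriSpeech's sample rate):
    #   window_size = int(0.001 * 16000 * 20.0)  # 20 ms -> 320 samples
    #   num_bins = window_size // 2 + 1          # rfft output -> 161 bins
    # which matches height=161 in the trainer's audio_spectrogram layer.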

    def __load_vocabulary_from_file__(self, vocabulary_path):
        """
        Load vocabulary from file.
        """
        if not os.path.exists(vocabulary_path):
            raise ValueError("Vocabulary file %s not found." % vocabulary_path)
        vocab_lines = []
        with open(vocabulary_path, 'r') as file:
            vocab_lines.extend(file.readlines())
        vocab_list = [line[:-1] for line in vocab_lines]
        vocab_dict = dict(
            [(token, id) for (id, token) in enumerate(vocab_list)])
        return vocab_dict, vocab_list

    def __convert_text_to_char_index__(self, text, vocabulary):
        """
        Convert text string to a list of character index integers.
        """
        return [vocabulary[w] for w in text]

    def __read_manifest__(self, manifest_path, max_duration, min_duration):
        """
        Load and parse manifest file.
        """
        manifest = []
        for json_line in open(manifest_path):
            try:
                json_data = json.loads(json_line)
@@ -153,63 +212,172 @@
            except Exception as e:
                raise ValueError("Error reading manifest: %s" % str(e))
            if (json_data["duration"] <= max_duration and
                    json_data["duration"] >= min_duration):
                manifest.append(json_data)
        return manifest

    def __padding_batch__(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio part of features (only in the time axis -- column axis)
        with zeros, to make each instance in the batch share the same
        audio feature shape.

        If `padding_to` is set to -1, the maximum column number in the batch
        will be used as the target size. Otherwise, `padding_to` will be the
        target size. Default is -1.

        If `flatten` is set to True, audio data will be flattened into a
        1-dim ndarray. Default is False.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be greater"
                                 " or equal to the original instance length.")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            new_batch.append((padded_audio, text))
        return new_batch
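
    # Note: a toy illustration of the padding semantics (shapes here are
    # hypothetical). For batch = [(np.ones([161, 3]), [1, 2]),
    #                             (np.ones([161, 5]), [3])]:
    #   padding_to=-1   -> both audios padded to the batch max of 5 columns;
    #   padding_to=2000 -> both padded to 2000 columns (as the trainer does);
    #   padding_to=4    -> raises ValueError, since 4 < 5.
    # With flatten=True, a padded [161, 5] array becomes a 1-dim array of 805.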

    def instance_reader_creator(self,
                                manifest_path,
                                sort_by_duration=True,
                                shuffle=False):
        """
        Instance reader creator for audio data. Create a callable function to
        produce instances of data.

        Instance: a tuple of a numpy ndarray of audio spectrogram and a list
        of tokenized and indexed transcription text.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param sort_by_duration: Sort the audio clips by duration if set True
                                 (for SortaGrad).
        :type sort_by_duration: bool
        :param shuffle: Shuffle the audio clips if set True.
        :type shuffle: bool
        :return: Data reader function.
        :rtype: callable
        """
        if sort_by_duration and shuffle:
            sort_by_duration = False
            logger.warn("When shuffle is set to True, "
                        "sort_by_duration is forced to False.")

        def reader():
            # read manifest
            manifest = self.__read_manifest__(
                manifest_path=manifest_path,
                max_duration=self.__max_duration__,
                min_duration=self.__min_duration__)
            # sort (by duration) or shuffle manifest
            if sort_by_duration:
                manifest.sort(key=lambda x: x["duration"])
            if shuffle:
                self.__random__.shuffle(manifest)
            # extract spectrogram feature
            for instance in manifest:
                spectrogram = self.__audio_featurize__(
                    instance["audio_filepath"])
                transcript = self.__text_featurize__(instance["text"])
                yield (spectrogram, transcript)

        return reader

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sort_by_duration=True,
                             shuffle=False):
        """
        Batch data reader creator for audio data. Create a callable function
        to produce batches of data.

        Audio features will be padded with zeros to make each instance in the
        batch share the same audio feature shape.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param batch_size: Instance number in a batch.
        :type batch_size: int
        :param padding_to: If set to -1, the maximum column number in the
                           batch will be used as the target size for padding.
                           Otherwise, `padding_to` will be the target size.
                           Default is -1.
        :type padding_to: int
        :param flatten: If set to True, audio data will be flattened into a
                        1-dim ndarray. Otherwise, a 2-dim ndarray.
                        Default is False.
        :type flatten: bool
        :param sort_by_duration: Sort the audio clips by duration if set True
                                 (for SortaGrad).
        :type sort_by_duration: bool
        :param shuffle: Shuffle the audio clips if set True.
        :type shuffle: bool
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            instance_reader = self.instance_reader_creator(
                manifest_path=manifest_path,
                sort_by_duration=sort_by_duration,
                shuffle=shuffle)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self.__padding_batch__(batch, padding_to, flatten)
                    batch = []
            if len(batch) > 0:
                yield self.__padding_batch__(batch, padding_to, flatten)

        return batch_reader

    def vocabulary_size(self):
        """
        Get vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self.__vocab_list__)

    def vocabulary_dict(self):
        """
        Get vocabulary as a dict.

        :return: Vocabulary as a dict.
        :rtype: dict
        """
        return self.__vocab_dict__

    def vocabulary_list(self):
        """
        Get vocabulary as a list.

        :return: Vocabulary as a list.
        :rtype: list
        """
        return self.__vocab_list__

    def data_name_feeding(self):
        """
        Get feedings (data field name and corresponding field id).

        :return: Feeding dict.
        :rtype: dict
        """
        feeding = {
            "audio_spectrogram": 0,
            "transcript_text": 1,
        }
        return feeding
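
A minimal usage sketch of the new interface (the vocabulary and manifest paths
are the ones used by the trainer below, and are assumed to exist locally):

    from audio_data_utils import DataGenerator

    data_generator = DataGenerator(
        vocab_filepath='eng_vocab.txt',
        normalizer_manifest_path='./libri.manifest.train')
    batch_reader = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.train',
        batch_size=32,
        sort_by_duration=True)
    for batch in batch_reader():
        # each batch is a list of (padded_spectrogram, token_id_list) tuples
        pass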
Trainer script (simplified DeepSpeech2 trainer; reader setup now goes through
DataGenerator):

@@ -5,16 +5,18 @@
import paddle.v2 as paddle
import argparse
import gzip
import time
import sys
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np

#TODO: add WER metric

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 trainer.')
parser.add_argument(
    "--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument("--trainer", default=1, type=int, help="Trainer number.")
parser.add_argument(
    "--num_passes", default=20, type=int, help="Training pass number.")
@@ -23,7 +25,7 @@
parser.add_argument(
    "--num_rnn_layers", default=5, type=int, help="RNN layer number.")
parser.add_argument(
    "--rnn_layer_size", default=512, type=int, help="RNN layer cell number.")
parser.add_argument(
    "--use_gpu", default=True, type=bool, help="Use gpu or not.")
parser.add_argument(
@@ -37,13 +39,45 @@ def train():
    """
    DeepSpeech2 training.
    """
    # create data readers
    data_generator = DataGenerator(
        vocab_filepath='eng_vocab.txt',
        normalizer_manifest_path='./libri.manifest.train',
        normalizer_num_samples=200,
        max_duration=20.0,
        min_duration=0.0,
        stride_ms=10,
        window_ms=20)
    train_batch_reader_sortagrad = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.dev.small',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=True,
        shuffle=False)
    train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.dev.small',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=False,
        shuffle=True)
    test_batch_reader = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.test',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=False,
        shuffle=False)
    feeding = data_generator.data_name_feeding()

    # create network config
    dict_size = data_generator.vocabulary_size()
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=2000,
        type=paddle.data_type.dense_vector(322000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
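
    # Note: 322000 is simply the flattened padded feature size:
    # 161 frequency bins x 2000 padded frames (padding_to=2000) = 322000,
    # matching the dense_vector dimension above.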
@@ -58,47 +92,26 @@
    # create parameters and optimizer
    parameters = paddle.parameters.create(cost)
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-5, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # create event handler
    def event_handler(event):
        global start_time
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
                print "\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id, event.cost)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            start_time = time.time()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_batch_reader, feeding=feeding)
            print "\n------- Time: %d, Pass: %d, TestCost: %s" % (
                time.time() - start_time, event.pass_id, result.cost)
            with gzip.open("params.tar.gz", 'w') as f:
                parameters.to_tar(f)
@@ -106,14 +119,14 @@
    # first pass with sortagrad
    if args.use_sortagrad:
        trainer.train(
            reader=train_batch_reader_sortagrad,
            event_handler=event_handler,
            num_passes=1,
            feeding=feeding)
        args.num_passes -= 1
    # other passes without sortagrad
    trainer.train(
        reader=train_batch_reader_nosortagrad,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=feeding)