提交 1635e000 编写于 作者: H Hui Zhang

fix cmvn

上级 2aed2752
...@@ -22,9 +22,12 @@ from paddle.io import Dataset ...@@ -22,9 +22,12 @@ from paddle.io import Dataset
from deepspeech.frontend.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.utility import load_cmvn from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog()
# https://github.com/PaddlePaddle/Paddle/pull/31481 # https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object): class CollateFunc(object):
...@@ -176,8 +179,8 @@ class FeatureNormalizer(object): ...@@ -176,8 +179,8 @@ class FeatureNormalizer(object):
wav_number += batch_size wav_number += batch_size
if wav_number % 1000 == 0: if wav_number % 1000 == 0:
print('process {} wavs,{} frames'.format(wav_number, logger.info('process {} wavs,{} frames'.format(wav_number,
all_number)) all_number))
self.cmvn_info = { self.cmvn_info = {
'mean_stat': list(all_mean_stat.tolist()), 'mean_stat': list(all_mean_stat.tolist()),
......
...@@ -17,16 +17,22 @@ import os ...@@ -17,16 +17,22 @@ import os
import socket import socket
import sys import sys
FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
logging.basicConfig(
level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)
def find_log_dir(log_dir=None): def find_log_dir(log_dir=None):
"""Returns the most suitable directory to put log files into. """Returns the most suitable directory to put log files into.
Args: Args:
log_dir: str|None, if specified, the logfile(s) will be created in that log_dir: str|None, if specified, the logfile(s) will be created in that
directory. Otherwise if the --log_dir command-line flag is provided, directory. Otherwise if the --log_dir command-line flag is provided,
the logfile will be created in that directory. Otherwise the logfile the logfile will be created in that directory. Otherwise the logfile
will be created in a standard location. will be created in a standard location.
Raises: Raises:
FileNotFoundError: raised when it cannot find a log directory. FileNotFoundError: raised when it cannot find a log directory.
""" """
# Get a list of possible log dirs (will try to use them in order). # Get a list of possible log dirs (will try to use them in order).
if log_dir: if log_dir:
...@@ -45,22 +51,22 @@ def find_log_dir(log_dir=None): ...@@ -45,22 +51,22 @@ def find_log_dir(log_dir=None):
def find_log_dir_and_names(program_name=None, log_dir=None): def find_log_dir_and_names(program_name=None, log_dir=None):
"""Computes the directory and filename prefix for log file. """Computes the directory and filename prefix for log file.
Args: Args:
program_name: str|None, the filename part of the path to the program that program_name: str|None, the filename part of the path to the program that
is running without its extension. e.g: if your program is called is running without its extension. e.g: if your program is called
'usr/bin/foobar.py' this method should probably be called with 'usr/bin/foobar.py' this method should probably be called with
program_name='foobar' However, this is just a convention, you can program_name='foobar' However, this is just a convention, you can
pass in any string you want, and it will be used as part of the pass in any string you want, and it will be used as part of the
log filename. If you don't pass in anything, the default behavior log filename. If you don't pass in anything, the default behavior
is as described in the example. In python standard logging mode, is as described in the example. In python standard logging mode,
the program_name will be prepended with py_ if it is the program_name the program_name will be prepended with py_ if it is the program_name
argument is omitted. argument is omitted.
log_dir: str|None, the desired log directory. log_dir: str|None, the desired log directory.
Returns: Returns:
(log_dir, file_prefix, symlink_prefix) (log_dir, file_prefix, symlink_prefix)
Raises: Raises:
FileNotFoundError: raised in Python 3 when it cannot find a log directory. FileNotFoundError: raised in Python 3 when it cannot find a log directory.
OSError: raised in Python 2 when it cannot find a log directory. OSError: raised in Python 2 when it cannot find a log directory.
""" """
if not program_name: if not program_name:
# Strip the extension (foobar.par becomes foobar, and # Strip the extension (foobar.par becomes foobar, and
...@@ -123,12 +129,10 @@ class Log(): ...@@ -123,12 +129,10 @@ class Log():
pass pass
if not self.logger.hasHandlers(): if not self.logger.hasHandlers():
format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
formatter = logging.Formatter(
fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
fh = logging.FileHandler(Log.log_name) fh = logging.FileHandler(Log.log_name)
fh.setFormatter(formatter)
fh.setLevel(logging.DEBUG) fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
self.logger.addHandler(fh) self.logger.addHandler(fh)
ch = logging.StreamHandler() ch = logging.StreamHandler()
...@@ -136,9 +140,6 @@ class Log(): ...@@ -136,9 +140,6 @@ class Log():
ch.setFormatter(formatter) ch.setFormatter(formatter)
self.logger.addHandler(ch) self.logger.addHandler(ch)
#fh.close()
#ch.close()
# stop propagate for propagating may print # stop propagate for propagating may print
# log multiple times # log multiple times
self.logger.propagate = False self.logger.propagate = False
......
...@@ -51,6 +51,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -51,6 +51,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--stride_ms=10.0 \ --stride_ms=10.0 \
--window_ms=25.0 \ --window_ms=25.0 \
--sample_rate=16000 \ --sample_rate=16000 \
--use_dB_normalization=False \
--num_samples=-1 \ --num_samples=-1 \
--num_workers=16 \ --num_workers=16 \
--output_path="data/mean_std.json" --output_path="data/mean_std.json"
......
...@@ -73,6 +73,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -73,6 +73,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--sample_rate=16000 \ --sample_rate=16000 \
--stride_ms=10.0 \ --stride_ms=10.0 \
--window_ms=25.0 \ --window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \ --num_workers=${num_workers} \
--output_path="data/mean_std.json" --output_path="data/mean_std.json"
......
...@@ -57,6 +57,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then ...@@ -57,6 +57,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--sample_rate=16000 \ --sample_rate=16000 \
--stride_ms=10.0 \ --stride_ms=10.0 \
--window_ms=25.0 \ --window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=2 \ --num_workers=2 \
--output_path="data/mean_std.json" --output_path="data/mean_std.json"
......
...@@ -21,6 +21,8 @@ import paddle ...@@ -21,6 +21,8 @@ import paddle
def main(args): def main(args):
paddle.set_device('cpu')
val_scores = [] val_scores = []
beat_val_scores = [] beat_val_scores = []
selected_epochs = [] selected_epochs = []
......
...@@ -25,17 +25,19 @@ parser = argparse.ArgumentParser(description=__doc__) ...@@ -25,17 +25,19 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('num_samples', int, -1, "# of samples to for statistics.") add_arg('num_samples', int, -1, "# of samples to for statistics.")
add_arg('specgram_type', str, add_arg('specgram_type', str,
'linear', 'linear',
"Audio feature type. Options: linear, mfcc, fbank.", "Audio feature type. Options: linear, mfcc, fbank.",
choices=['linear', 'mfcc', 'fbank']) choices=['linear', 'mfcc', 'fbank'])
add_arg('feat_dim', int, 13, "Audio feature dim.") add_arg('feat_dim', int, 13, "Audio feature dim.")
add_arg('delta_delta', bool, add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
False, add_arg('stride_ms', float, 10.0, "stride length in ms.")
"Audio feature with delta delta.") add_arg('window_ms', float, 20.0, "stride length in ms.")
add_arg('stride_ms', float, 10.0, "stride length in ms.") add_arg('sample_rate', int, 16000, "target sample rate.")
add_arg('window_ms', float, 20.0, "stride length in ms.") add_arg('use_dB_normalization', bool, False, "do dB normalization.")
add_arg('sample_rate', int, 16000, "target sample rate.") add_arg('target_dB', int, -20, "target dB.")
add_arg('manifest_path', str, add_arg('manifest_path', str,
'data/librispeech/manifest.train', 'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.") "Filepath of manifest to compute normalizer's mean and stddev.")
...@@ -63,8 +65,8 @@ def main(): ...@@ -63,8 +65,8 @@ def main():
n_fft=None, n_fft=None,
max_freq=None, max_freq=None,
target_sample_rate=args.sample_rate, target_sample_rate=args.sample_rate,
use_dB_normalization=True, use_dB_normalization=args.use_dB_normalization,
target_dB=-20, target_dB=args.target_dB,
dither=0.0) dither=0.0)
def augment_and_featurize(audio_segment): def augment_and_featurize(audio_segment):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册