Commit fa70a024 authored by: H Hui Zhang

model; new updater; benchmark; chains; can run libri/s1

Parent f3338265
@@ -21,7 +21,7 @@ from paddle.inference import create_predictor
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.socket_server import AsrRequestHandler
 from deepspeech.utils.socket_server import AsrTCPServer
...
@@ -19,7 +19,7 @@ import paddle
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils.socket_server import AsrRequestHandler
 from deepspeech.utils.socket_server import AsrTCPServer
...
@@ -30,11 +30,17 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    # save jit model to
+    parser.add_argument(
+        "--export_path", type=str, help="path of the jit model to save")
+    parser.add_argument(
+        "--model_type", type=str, default='offline', help="offline/online")
     args = parser.parse_args()
+    print("model_type:{}".format(args.model_type))
     print_arguments(args)
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
...
@@ -30,11 +30,17 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument(
+        "--model_type", type=str, default='offline', help='offline/online')
+    # save asr result to
+    parser.add_argument(
+        "--result_file", type=str, help="path to save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
+    print("model_type:{}".format(args.model_type))
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
def main_sp(config, args):
    exp = ExportTester(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    # save asr result to
    parser.add_argument(
        "--result_file", type=str, help="path to save the asr result")
    # load jit model from
    parser.add_argument(
        "--export_path", type=str, help="path of the jit model to load")
    parser.add_argument(
        "--model_type", type=str, default='offline', help='offline/online')
    args = parser.parse_args()
    print_arguments(args, globals())
    print("model_type:{}".format(args.model_type))

    # https://yaml.org/type/float.html
    config = get_cfg_defaults(args.model_type)
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
import os
import sys
from pathlib import Path
import paddle
import soundfile
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import print_arguments
from deepspeech.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class DeepSpeech2Tester_hub():
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.audio_file = args.audio_file
        self.collate_fn_test = SpeechCollator.from_config(config)
        self._text_featurizer = TextFeaturizer(
            unit_type=config.collator.unit_type, vocab_filepath=None)

    def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
        result_transcripts = self.model.decode(
            audio,
            audio_len,
            vocab_list,
            decoding_method=cfg.decoding_method,
            lang_model_path=cfg.lang_model_path,
            beam_alpha=cfg.alpha,
            beam_beta=cfg.beta,
            beam_size=cfg.beam_size,
            cutoff_prob=cfg.cutoff_prob,
            cutoff_top_n=cfg.cutoff_top_n,
            num_processes=cfg.num_proc_bsearch)

        # replace the '<space>' tokens with ' '
        result_transcripts = [
            self._text_featurizer.detokenize(sentence)
            for sentence in result_transcripts
        ]
        return result_transcripts
    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        self.model.eval()
        cfg = self.config
        audio_file = self.audio_file
        collate_fn_test = self.collate_fn_test
        audio, _ = collate_fn_test.process_utterance(
            audio_file=audio_file, transcript=" ")
        audio_len = audio.shape[0]
        audio = paddle.to_tensor(audio, dtype='float32')
        audio_len = paddle.to_tensor(audio_len)
        audio = paddle.unsqueeze(audio, axis=0)
        vocab_list = collate_fn_test.vocab_list
        result_transcripts = self.compute_result_transcripts(
            audio, audio_len, vocab_list, cfg.decoding)
        logger.info("result_transcripts: " + result_transcripts[0])

    def run_test(self):
        self.resume()
        try:
            self.test()
        except KeyboardInterrupt:
            exit(-1)
    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
        self.setup_output_dir()
        self.setup_checkpointer()
        self.setup_model()

    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        if self.args.output:
            output_dir = Path(self.args.output).expanduser()
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(
                self.args.checkpoint_path).expanduser().parent.parent
            output_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = output_dir

    def setup_model(self):
        config = self.config.clone()
        with UpdateConfig(config):
            config.model.feat_size = self.collate_fn_test.feature_size
            config.model.dict_size = self.collate_fn_test.vocab_size

        if self.args.model_type == 'offline':
            model = DeepSpeech2Model.from_config(config.model)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline.from_config(config.model)
        else:
            raise Exception("wrong model type")

        self.model = model

    def setup_checkpointer(self):
        """Create a directory used to save checkpoints into.
        It is "checkpoints" inside the output directory.
        """
        # checkpoint dir
        checkpoint_dir = self.output_dir / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True)
        self.checkpoint_dir = checkpoint_dir

        self.checkpoint = Checkpoint(
            kbest_n=self.config.training.checkpoint.kbest_n,
            latest_n=self.config.training.checkpoint.latest_n)

    def resume(self):
        """Resume from the checkpoint at checkpoints in the output
        directory or load a specified checkpoint.
        """
        params_path = self.args.checkpoint_path + ".pdparams"
        model_dict = paddle.load(params_path)
        self.model.set_state_dict(model_dict)
def check(audio_file):
    logger.info("checking the audio file format......")
    try:
        sig, sample_rate = soundfile.read(audio_file)
    except Exception as e:
        logger.error(str(e))
        logger.error(
            "can not open the wav file, please check the audio file format")
        sys.exit(-1)
    logger.info("The sample rate is %d" % sample_rate)
    assert (sample_rate == 16000)
    logger.info("The audio file format is right")
def main_sp(config, args):
    exp = DeepSpeech2Tester_hub(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument(
        "--model_type", type=str, default='offline', help='offline/online')
    parser.add_argument("--audio_file", type=str, help='audio file path')
    # save asr result to
    parser.add_argument(
        "--result_file", type=str, help="path to save the asr result")
    args = parser.parse_args()
    print_arguments(args, globals())

    if not os.path.isfile(args.audio_file):
        print("Please input a valid audio file path")
        sys.exit(-1)
    check(args.audio_file)
    print("model_type:{}".format(args.model_type))

    # https://yaml.org/type/float.html
    config = get_cfg_defaults(args.model_type)
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
@@ -27,7 +27,7 @@ def main_sp(config, args):
 def main(config, args):
-    if args.device == "gpu" and args.nprocs > 1:
+    if args.nprocs > 0:
         dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     else:
         main_sp(config, args)
@@ -35,11 +35,14 @@ def main(config, args):
 if __name__ == "__main__":
     parser = default_argument_parser()
+    parser.add_argument(
+        "--model_type", type=str, default='offline', help='offline/online')
     args = parser.parse_args()
+    print("model_type:{}".format(args.model_type))
     print_arguments(args, globals())
     # https://yaml.org/type/float.html
-    config = get_cfg_defaults()
+    config = get_cfg_defaults(args.model_type)
     if args.config:
         config.merge_from_file(args.config)
     if args.opts:
...
@@ -21,7 +21,7 @@ from paddle.io import DataLoader
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
 from deepspeech.utils import error_rate
 from deepspeech.utils.utility import add_arguments
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 from yacs.config import CfgNode as CN

-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2Model

 _C = CN()

 _C.data = CN(
@@ -32,7 +32,7 @@ _C.data = CN(
         window_ms=20.0,  # ms
         n_fft=None,  # fft points
         max_freq=None,  # None for samplerate/2
-        specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+        spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
         feat_dim=0,  # 'mfcc', 'fbank'
         delat_delta=False,  # 'mfcc', 'fbank'
         target_sample_rate=16000,  # target sample rate
@@ -46,16 +46,7 @@ _C.data = CN(
         shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
     ))

-_C.model = CN(
-    dict(
-        num_conv_layers=2,  #Number of stacking convolution layers.
-        num_rnn_layers=3,  #Number of stacking RNN layers.
-        rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
-        use_gru=True,  #Use gru if set True. Use simple rnn if set False.
-        share_rnn_weights=True  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
-    ))
-
-DeepSpeech2Model.params(_C.model)
+_C.model = DeepSpeech2Model.params()

 _C.training = CN(
     dict(
@@ -81,7 +72,7 @@ _C.decoding = CN(
     ))

-def get_cfg_defaults():
+def get_cfg_defaults(model_type):
     """Get a yacs CfgNode object with default values for my_project."""
     # Return a clone so that the defaults will not be altered
     # This is for the "local variable" use pattern
...
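The new `get_cfg_defaults(model_type)` signature implies the default config tree is now chosen per model type. A minimal sketch of the dispatch, assuming a second `_C_online` tree for the online model (the function body is collapsed in this diff):

def get_cfg_defaults(model_type='offline'):
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    if model_type == 'offline':
        return _C.clone()
    return _C_online.clone()  # hypothetical online counterpart of _C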
@@ -25,14 +25,15 @@ from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.sampler import SortagradDistributedBatchSampler
-from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
-from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.models.ds2 import DeepSpeech2InferModel
+from deepspeech.models.ds2 import DeepSpeech2Model
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
+from deepspeech.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()
...
@@ -98,15 +99,27 @@ class DeepSpeech2Trainer(Trainer):
         return total_loss, num_seen_utts

     def setup_model(self):
-        config = self.config
-        model = DeepSpeech2Model(
-            feat_size=self.train_loader.dataset.feature_size,
-            dict_size=self.train_loader.dataset.vocab_size,
-            num_conv_layers=config.model.num_conv_layers,
-            num_rnn_layers=config.model.num_rnn_layers,
-            rnn_size=config.model.rnn_layer_size,
-            use_gru=config.model.use_gru,
-            share_rnn_weights=config.model.share_rnn_weights)
+        #config = self.config
+        #model = DeepSpeech2Model(
+        #    feat_size=self.train_loader.dataset.feature_size,
+        #    dict_size=self.train_loader.dataset.vocab_size,
+        #    num_conv_layers=config.model.num_conv_layers,
+        #    num_rnn_layers=config.model.num_rnn_layers,
+        #    rnn_size=config.model.rnn_layer_size,
+        #    use_gru=config.model.use_gru,
+        #    share_rnn_weights=config.model.share_rnn_weights)
+        config = self.config.clone()
+        with UpdateConfig(config):
+            config.model.feat_size = self.train_loader.dataset.feature_size
+            config.model.dict_size = self.train_loader.dataset.vocab_size
+
+        if self.args.model_type == 'offline':
+            model = DeepSpeech2Model.from_config(config.model)
+        elif self.args.model_type == 'online':
+            model = DeepSpeech2ModelOnline.from_config(config.model)
+        else:
+            raise Exception("wrong model type")

         if self.parallel:
             model = paddle.DataParallel(model)
...
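The `UpdateConfig` helper imported above is used as a context manager that lets the frozen yacs config be written before it is frozen again. A minimal sketch consistent with that usage (the helper's real body is not shown in this diff):

from contextlib import contextmanager

@contextmanager
def UpdateConfig(config):
    """Temporarily defrost a yacs CfgNode so keys can be (re)written."""
    config.defrost()
    try:
        yield config
    finally:
        config.freeze()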
@@ -30,7 +30,7 @@ def main_sp(config, args):
 def main(config, args):
-    if args.device == "gpu" and args.nprocs > 1:
+    if args.nprocs > 0:
         dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
     else:
         main_sp(config, args)
...
@@ -24,15 +24,15 @@ class AudioFeaturizer(object):
     Currently, it supports feature types of linear spectrogram and mfcc.

-    :param specgram_type: Specgram feature type. Options: 'linear'.
-    :type specgram_type: str
+    :param spectrum_type: Specgram feature type. Options: 'linear'.
+    :type spectrum_type: str
     :param stride_ms: Striding size (in milliseconds) for generating frames.
     :type stride_ms: float
     :param window_ms: Window size (in milliseconds) for generating frames.
     :type window_ms: float
-    :param max_freq: When specgram_type is 'linear', only FFT bins
+    :param max_freq: When spectrum_type is 'linear', only FFT bins
         corresponding to frequencies between [0, max_freq] are
-        returned; when specgram_type is 'mfcc', max_feq is the
+        returned; when spectrum_type is 'mfcc', max_freq is the
         highest band edge of mel filters.
     :types max_freq: None|float
     :param target_sample_rate: Audio are resampled (if upsampling or
@@ -47,7 +47,7 @@ class AudioFeaturizer(object):
     """

     def __init__(self,
-                 specgram_type: str='linear',
+                 spectrum_type: str='linear',
                  feat_dim: int=None,
                  delta_delta: bool=False,
                  stride_ms=10.0,
@@ -58,7 +58,7 @@ class AudioFeaturizer(object):
                  use_dB_normalization=True,
                  target_dB=-20,
                  dither=1.0):
-        self._specgram_type = specgram_type
+        self._spectrum_type = spectrum_type
         # mfcc and fbank using `feat_dim`
         self._feat_dim = feat_dim
         # mfcc and fbank using `delta-delta`
...
@@ -113,27 +113,27 @@ class AudioFeaturizer(object):
     def feature_size(self):
         """audio feature size"""
         feat_dim = 0
-        if self._specgram_type == 'linear':
+        if self._spectrum_type == 'linear':
             fft_point = self._window_ms if self._fft_point is None else self._fft_point
             feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 +
                            1)
-        elif self._specgram_type == 'mfcc':
+        elif self._spectrum_type == 'mfcc':
             # mfcc, delta, delta-delta
             feat_dim = int(self._feat_dim *
                            3) if self._delta_delta else int(self._feat_dim)
-        elif self._specgram_type == 'fbank':
+        elif self._spectrum_type == 'fbank':
             # fbank, delta, delta-delta
             feat_dim = int(self._feat_dim *
                            3) if self._delta_delta else int(self._feat_dim)
         else:
-            raise ValueError("Unknown specgram_type %s. "
-                             "Supported values: linear." % self._specgram_type)
+            raise ValueError("Unknown spectrum_type %s. "
+                             "Supported values: linear." % self._spectrum_type)
         return feat_dim

     def _compute_specgram(self, audio_segment):
         """Extract various audio features."""
         sample_rate = audio_segment.sample_rate
-        if self._specgram_type == 'linear':
+        if self._spectrum_type == 'linear':
             samples = audio_segment.samples
             return self._compute_linear_specgram(
                 samples,
@@ -141,7 +141,7 @@ class AudioFeaturizer(object):
                 stride_ms=self._stride_ms,
                 window_ms=self._window_ms,
                 max_freq=self._max_freq)
-        elif self._specgram_type == 'mfcc':
+        elif self._spectrum_type == 'mfcc':
             samples = audio_segment.to('int16')
             return self._compute_mfcc(
                 samples,
@@ -152,7 +152,7 @@ class AudioFeaturizer(object):
                 max_freq=self._max_freq,
                 dither=self._dither,
                 delta_delta=self._delta_delta)
-        elif self._specgram_type == 'fbank':
+        elif self._spectrum_type == 'fbank':
             samples = audio_segment.to('int16')
             return self._compute_fbank(
                 samples,
@@ -164,8 +164,8 @@ class AudioFeaturizer(object):
                 dither=self._dither,
                 delta_delta=self._delta_delta)
         else:
-            raise ValueError("Unknown specgram_type %s. "
-                             "Supported values: linear." % self._specgram_type)
+            raise ValueError("Unknown spectrum_type %s. "
+                             "Supported values: linear." % self._spectrum_type)

     def _compute_linear_specgram(self,
                                  samples,
...
@@ -27,16 +27,16 @@ class SpeechFeaturizer(object):

     :param vocab_filepath: Filepath to load vocabulary for token indices
                            conversion.
-    :type specgram_type: str
-    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
-    :type specgram_type: str
+    :type spectrum_type: str
+    :param spectrum_type: Specgram feature type. Options: 'linear', 'mfcc'.
+    :type spectrum_type: str
     :param stride_ms: Striding size (in milliseconds) for generating frames.
     :type stride_ms: float
     :param window_ms: Window size (in milliseconds) for generating frames.
     :type window_ms: float
-    :param max_freq: When specgram_type is 'linear', only FFT bins
+    :param max_freq: When spectrum_type is 'linear', only FFT bins
         corresponding to frequencies between [0, max_freq] are
-        returned; when specgram_type is 'mfcc', max_freq is the
+        returned; when spectrum_type is 'mfcc', max_freq is the
         highest band edge of mel filters.
     :types max_freq: None|float
     :param target_sample_rate: Speech are resampled (if upsampling or
@@ -54,7 +54,7 @@ class SpeechFeaturizer(object):
                  unit_type,
                  vocab_filepath,
                  spm_model_prefix=None,
-                 specgram_type='linear',
+                 spectrum_type='linear',
                  feat_dim=None,
                  delta_delta=False,
                  stride_ms=10.0,
@@ -66,7 +66,7 @@ class SpeechFeaturizer(object):
                  target_dB=-20,
                  dither=1.0):
         self._audio_featurizer = AudioFeaturizer(
-            specgram_type=specgram_type,
+            spectrum_type=spectrum_type,
             feat_dim=feat_dim,
             delta_delta=delta_delta,
             stride_ms=stride_ms,
...
@@ -45,7 +45,7 @@ class TextFeaturizer(object):
             self.sp = spm.SentencePieceProcessor()
             self.sp.Load(spm_model)

-    def tokenize(self, text):
+    def tokenize(self, text, replace_space=True):
         if self.unit_type == 'char':
             tokens = self.char_tokenize(text)
         elif self.unit_type == 'word':
...
@@ -68,7 +68,7 @@ class TextFeaturizer(object):
         Args:
             text (str): Text to process.

         Returns:
             List[int]: List of token indices.
         """
...
@@ -81,7 +81,7 @@ class TextFeaturizer(object):
     def defeaturize(self, idxs):
         """Convert a list of token indices to text string,
         ignore index after eos_id.

         Args:
             idxs (List[int]): List of token indices.
...
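The new `replace_space` flag suggests spaces are mapped to a dedicated token during char tokenization; the hub tester above detokenizes '<space>' back to ' ', and a SPACE constant is added in the next hunk. A hedged sketch of what char tokenization with this flag could look like; the '<space>' token text is an assumption:

def char_tokenize(text, replace_space=True):
    """Character tokenizer; optionally map ' ' to a space token."""
    text = text.strip()
    if replace_space:
        # '<space>' is assumed here; the diff only shows it in detokenize
        return ["<space>" if ch == " " else ch for ch in text]
    return list(text)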
@@ -32,6 +32,7 @@ IGNORE_ID = -1
 SOS = "<sos/eos>"
 EOS = SOS
 UNK = "<unk>"
+SPACE = " "
 BLANK = "<blank>"
...
@@ -101,7 +102,7 @@ def rms_to_dbfs(rms: float):
     """Root Mean Square to dBFS.
     https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
     Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB.

     dB = dBFS + 3.0103
     dBFS = dB - 3.0103
     e.g. 0 dB = -3.0103 dBFS
...
@@ -116,26 +117,26 @@ def rms_to_dbfs(rms: float):
 def max_dbfs(sample_data: np.ndarray):
     """Peak dBFS based on the maximum energy sample.

     Args:
         sample_data ([np.ndarray]): float array, [-1, 1].

     Returns:
         float: dBFS
     """
     # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
     return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))


 def mean_dbfs(sample_data):
     """dBFS based on the RMS energy.

     Args:
         sample_data ([np.ndarray]): float array, [-1, 1].

     Returns:
         float: dBFS
     """
     return rms_to_dbfs(
         math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))
...
@@ -155,7 +156,7 @@ def gain_db_to_ratio(gain_db: float):
 def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
     """Normalize audio to dBFS.

     Args:
         sample_data (np.ndarray): input wave samples, [-1, 1].
         dbfs (float, optional): target dBFS. Defaults to -3.0103.
...
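A quick numeric check of the dBFS convention documented above; the body of rms_to_dbfs is not shown in this diff, so 20 * log10(rms) is an assumption:

import math
import numpy as np

t = np.linspace(0, 1, 16000, endpoint=False)
sine = np.sin(2 * np.pi * 440.0 * t)       # full-scale 1-amp sine wave
rms = math.sqrt(np.mean(np.square(sine)))  # ~0.70711
print(20 * math.log10(rms))                # ~ -3.0103, as the docstring says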
@@ -35,7 +35,7 @@ def create_dataloader(manifest_path,
                       stride_ms=10.0,
                       window_ms=20.0,
                       max_freq=None,
-                      specgram_type='linear',
+                      spectrum_type='linear',
                       feat_dim=None,
                       delta_delta=False,
                       use_dB_normalization=True,
...
@@ -64,7 +64,7 @@ def create_dataloader(manifest_path,
         stride_ms=stride_ms,
         window_ms=window_ms,
         max_freq=max_freq,
-        specgram_type=specgram_type,
+        spectrum_type=spectrum_type,
         feat_dim=feat_dim,
         delta_delta=delta_delta,
         use_dB_normalization=use_dB_normalization,
...
@@ -63,7 +63,7 @@ class ManifestDataset(Dataset):
             n_fft=None,  # fft points
             max_freq=None,  # None for samplerate/2
             raw_wav=True,  # use raw_wav or kaldi feature
-            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            spectrum_type='linear',  # 'linear', 'mfcc', 'fbank'
             feat_dim=0,  # 'mfcc', 'fbank'
             delta_delta=False,  # 'mfcc', 'fbank'
             dither=1.0,  # feature dither
...
@@ -124,7 +124,7 @@ class ManifestDataset(Dataset):
             n_fft=config.data.n_fft,
             max_freq=config.data.max_freq,
             target_sample_rate=config.data.target_sample_rate,
-            specgram_type=config.data.specgram_type,
+            spectrum_type=config.data.spectrum_type,
             feat_dim=config.data.feat_dim,
             delta_delta=config.data.delta_delta,
             dither=config.data.dither,
...
@@ -152,7 +152,7 @@ class ManifestDataset(Dataset):
             n_fft=None,
             max_freq=None,
             target_sample_rate=16000,
-            specgram_type='linear',
+            spectrum_type='linear',
             feat_dim=None,
             delta_delta=False,
             dither=1.0,
...
@@ -180,7 +180,7 @@ class ManifestDataset(Dataset):
             n_fft (int, optional): fft points for rfft. Defaults to None.
             max_freq (int, optional): max cut freq. Defaults to None.
             target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            spectrum_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
             feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
             delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
             use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
...
@@ -200,7 +200,7 @@ class ManifestDataset(Dataset):
             unit_type=unit_type,
             vocab_filepath=vocab_filepath,
             spm_model_prefix=spm_model_prefix,
-            specgram_type=specgram_type,
+            spectrum_type=spectrum_type,
             feat_dim=feat_dim,
             delta_delta=delta_delta,
             stride_ms=stride_ms,
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
from paddle.nn import functional as F
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['ConvStack', "conv_output_size"]
def conv_output_size(I, F, P, S):
    # https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
    # Output size after Conv:
    #   By noting I the length of the input volume size,
    #   F the length of the filter,
    #   P the amount of zero padding,
    #   S the stride,
    # the output size O of the feature map along that dimension is given by:
    #   O = (I - F + Pstart + Pend) // S + 1
    # When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
    # When Pstart == Pend == 0:
    #   O = (I - F) // S + 1
    # https://iq.opengenus.org/output-size-of-convolution/
    # Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
    # Output width = (Input width + padding width right + padding width left - kernel width) / (stride width) + 1
    return (I - F + 2 * P) // S + 1
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
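# Worked example of the formula above, for the first ConvBn layer of the
# DeepSpeech2 conv stack (kernel (41, 11), stride (2, 3), padding (20, 5)):
# a 161-bin linear spectrogram maps to (161 - 41 + 2 * 20) // 2 + 1 = 81
# frequency bins, matching ConvStack's output_height recurrence below.
assert conv_output_size(161, 41, 20, 2) == 81
assert (161 - 1) // 2 + 1 == 81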
class ConvBn(nn.Layer):
    """Convolution layer with batch normalization.

    :param kernel_size: The x dimension of a filter kernel. Or input a tuple for
                        two image dimension.
    :type kernel_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: The x dimension of the stride. Or input a tuple for two
                   image dimension.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimension.
    :type padding: int|tuple|list
    :param act: Activation type, relu|brelu
    :type act: string
    :return: Batch norm layer after convolution layer.
    :rtype: Variable
    """

    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
                 padding, act):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        assert len(padding) == 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = nn.Conv2D(
            num_channels_in,
            num_channels_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            weight_attr=None,
            bias_attr=False,
            data_format='NCHW')

        self.bn = nn.BatchNorm2D(
            num_channels_out,
            weight_attr=None,
            bias_attr=None,
            data_format='NCHW')
        self.act = F.relu if act == 'relu' else brelu

    def forward(self, x, x_len):
        """
        x(Tensor): audio, shape [B, C, D, T]
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        # output length along the time axis: (I - F + 2P) // S + 1
        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
                 ) // self.stride[1] + 1

        # reset padding part to 0
        masks = make_non_pad_mask(x_len)  #[B, T]
        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
        # TODO(Hui Zhang): not support bool multiply
        masks = masks.type_as(x)
        x = x.multiply(masks)

        return x, x_len
class ConvStack(nn.Layer):
    """Convolution group with stacked convolution layers.

    :param feat_size: audio feature dim.
    :type feat_size: int
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    """

    def __init__(self, feat_size, num_stacks):
        super().__init__()
        self.feat_size = feat_size  # D
        self.num_stacks = num_stacks

        self.conv_in = ConvBn(
            num_channels_in=1,
            num_channels_out=32,
            kernel_size=(41, 11),  #[D, T]
            stride=(2, 3),
            padding=(20, 5),
            act='brelu')

        out_channel = 32
        convs = [
            ConvBn(
                num_channels_in=32,
                num_channels_out=out_channel,
                kernel_size=(21, 11),
                stride=(2, 1),
                padding=(10, 5),
                act='brelu') for i in range(num_stacks - 1)
        ]
        self.conv_stack = nn.LayerList(convs)

        # conv output feat_dim
        output_height = (feat_size - 1) // 2 + 1
        for i in range(self.num_stacks - 1):
            output_height = (output_height - 1) // 2 + 1
        self.output_height = out_channel * output_height

    def forward(self, x, x_len):
        """
        x: shape [B, C, D, T]
        x_len: shape [B]
        """
        x, x_len = self.conv_in(x, x_len)
        for i, conv in enumerate(self.conv_stack):
            x, x_len = conv(x, x_len)
        return x, x_len
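# Usage sketch (assumed shapes), illustrating the stack above:
#   conv = ConvStack(feat_size=161, num_stacks=2)
#   x = paddle.randn([4, 1, 161, 200])            # [B, C=1, D, T]
#   x_len = paddle.to_tensor([200, 180, 160, 120])
#   y, y_len = conv(x, x_len)
# After layer 1: D 161 -> 81, T 200 -> (200 - 11 + 10) // 3 + 1 = 67
# After layer 2: D 81 -> 41, T stays 67 (time stride is 1)
# so y has shape [4, 32, 41, 67] and output_height = 32 * 41 = 1312.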
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
from typing import Optional
import paddle
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2.conv import ConvStack
from deepspeech.models.ds2.rnn import RNNStack
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
class CRNNEncoder(nn.Layer):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True):
        super().__init__()
        self.rnn_size = rnn_size
        self.feat_size = feat_size  # 161 for linear
        self.dict_size = dict_size

        self.conv = ConvStack(feat_size, num_conv_layers)

        i_size = self.conv.output_height  # H after conv stack
        self.rnn = RNNStack(
            i_size=i_size,
            h_size=rnn_size,
            num_stacks=num_rnn_layers,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)

    @property
    def output_size(self):
        return self.rnn_size * 2

    def forward(self, audio, audio_len):
        """Compute Encoder outputs

        Args:
            audio (Tensor): [B, Tmax, D]
            audio_len (Tensor): [B]
        Returns:
            x (Tensor): encoder outputs, [B, T, D]
            x_lens (Tensor): encoder length, [B]
        """
        # [B, T, D] -> [B, D, T]
        audio = audio.transpose([0, 2, 1])
        # [B, D, T] -> [B, C=1, D, T]
        x = audio.unsqueeze(1)
        x_lens = audio_len

        # convolution group
        x, x_lens = self.conv(x, x_lens)

        # convert data from convolution feature map to sequence of vectors
        #B, C, D, T = paddle.shape(x)  # not work under jit
        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
        #x = x.reshape([B, T, C * D])  #[B, T, C*D]  # not work under jit
        x = x.reshape([0, 0, -1])  #[B, T, C*D]

        # remove padding part
        x, x_lens = self.rnn(x, x_lens)  #[B, T, D]
        return x, x_lens
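# Shape walk-through for CRNNEncoder.forward with 161-dim linear features
# (values follow from the ConvStack arithmetic in ds2/conv.py):
#   audio [B, T, 161] -> transpose -> [B, 161, T] -> unsqueeze -> [B, 1, 161, T]
#   ConvStack(num_conv_layers=2) -> [B, 32, 41, T']
#   transpose + reshape          -> [B, T', 32 * 41 = 1312]
#   RNNStack                     -> [B, T', 2 * rnn_size]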
class DeepSpeech2Model(nn.Layer):
    """The DeepSpeech2 network structure.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param masks: Masks data layer to reset padding.
    :type masks: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_weights: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                num_conv_layers=2,  #Number of stacking convolution layers.
                num_rnn_layers=3,  #Number of stacking RNN layers.
                rnn_layer_size=1024,  #RNN layer size (number of RNN cells).
                use_gru=True,  #Use gru if set True. Use simple rnn if set False.
                share_rnn_weights=True,  #Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
                ctc_grad_norm_type='instance', ))
        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0,
                 ctc_grad_norm_type='instance'):
        super().__init__()
        self.encoder = CRNNEncoder(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)
        assert (self.encoder.output_size == rnn_size * 2)

        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=blank_id,
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=ctc_grad_norm_type)

    def forward(self, audio, audio_len, text, text_len):
        """Compute Model loss

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]

        Returns:
            loss (Tensor): [1]
        """
        eouts, eouts_len = self.encoder(audio, audio_len)
        loss = self.decoder(eouts, eouts_len, text, text_len)
        return loss

    @paddle.no_grad()
    def decode(self, audio, audio_len, vocab_list, decoding_method,
               lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
               cutoff_top_n, num_processes):
        # init once
        # decoders only accept string encoded in utf-8
        self.decoder.init_decode(
            beam_alpha=beam_alpha,
            beam_beta=beam_beta,
            lang_model_path=lang_model_path,
            vocab_list=vocab_list,
            decoding_method=decoding_method)

        eouts, eouts_len = self.encoder(audio, audio_len)
        probs = self.decoder.softmax(eouts)
        return self.decoder.decode_probs(
            probs.numpy(), eouts_len, vocab_list, decoding_method,
            lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
            cutoff_top_n, num_processes)

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2Model model from a pretrained model.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader

        config: yacs.config.CfgNode
            model configs

        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name

        Returns
        -------
        DeepSpeech2Model
            The model built from pretrained result.
        """
        model = cls(
            feat_size=dataloader.collate_fn.feature_size,
            dict_size=dataloader.collate_fn.vocab_size,
            num_conv_layers=config.model.num_conv_layers,
            num_rnn_layers=config.model.num_rnn_layers,
            rnn_size=config.model.rnn_layer_size,
            use_gru=config.model.use_gru,
            share_rnn_weights=config.model.share_rnn_weights,
            blank_id=config.model.blank_id,
            ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model

    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2Model from config.

        Parameters
        ----------
        config: yacs.config.CfgNode
            config.model

        Returns
        -------
        DeepSpeech2Model
            The model built from config.
        """
        model = cls(
            feat_size=config.feat_size,
            dict_size=config.dict_size,
            num_conv_layers=config.num_conv_layers,
            num_rnn_layers=config.num_rnn_layers,
            rnn_size=config.rnn_layer_size,
            use_gru=config.use_gru,
            share_rnn_weights=config.share_rnn_weights,
            blank_id=config.blank_id,
            ctc_grad_norm_type=config.ctc_grad_norm_type, )
        return model
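# Sketch of the from_config path used by the new trainer setup_model
# (assumed values; UpdateConfig defrosts the CfgNode so the data-dependent
# keys can be written):
#   cfg = DeepSpeech2Model.params()   # CfgNode with the defaults above
#   cfg.defrost()
#   cfg.feat_size, cfg.dict_size, cfg.blank_id = 161, 29, 0
#   cfg.freeze()
#   model = DeepSpeech2Model.from_config(cfg)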
class DeepSpeech2InferModel(DeepSpeech2Model):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0):
        super().__init__(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights,
            blank_id=blank_id)

    def forward(self, audio, audio_len):
        """export model function

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]

        Returns:
            probs: probs after softmax
        """
        eouts, eouts_len = self.encoder(audio, audio_len)
        probs = self.decoder.softmax(eouts)
        return probs, eouts_len

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, None, self.encoder.feat_size],
                    dtype='float32'),  # audio, [B,T,D]
                paddle.static.InputSpec(shape=[None],
                                        dtype='int64'),  # audio_length, [B]
            ])
        return static_model
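# Export sketch (assumed checkpoint and output paths):
#   infer_model = DeepSpeech2InferModel.from_pretrained(loader, config, ckpt)
#   static_model = infer_model.export()
#   paddle.jit.save(static_model, args.export_path)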
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.activation import brelu
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['RNNStack']
class RNNCell(nn.RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.
    The formula used is as follows:

    .. math::
        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})

        y_{t} & = h_{t}

    where :math:`act` is for :attr:`activation`.
    """

    def __init__(self,
                 hidden_size: int,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        if activation not in ["tanh", "relu", "brelu"]:
            raise ValueError(
                "activation for SimpleRNNCell should be tanh, relu or brelu, "
                "but got {}".format(activation))
        self.activation = activation
        self._activation_fn = paddle.tanh \
            if activation == "tanh" \
            else F.relu
        if activation == 'brelu':
            self._activation_fn = brelu

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_h = states
        i2h = inputs
        if self.bias_ih is not None:
            i2h += self.bias_ih
        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self._activation_fn(i2h + h2h)
        return h, h

    @property
    def state_shape(self):
        return (self.hidden_size, )
class GRUCell(nn.RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.
    The formula for GRU used is as follows:

    .. math::
        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})

        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})

        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))

        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}

        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and * is the elementwise
    multiplication operator.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (3 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (3 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        self._activation = paddle.tanh

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)

        pre_hidden = states
        x_gates = inputs
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h_gates = h_gates + self.bias_hh

        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)

        r = self._gate_activation(x_r + h_r)
        z = self._gate_activation(x_z + h_z)
        c = self._activation(x_c + r * h_c)  # apply reset gate after mm
        h = (pre_hidden - c) * z + c
        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru

        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
        size would be automatically inserted into shape). The shape corresponds
        to the shape of :math:`h_{t-1}`.
        """
        return (self.hidden_size, )
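# Note: the update in GRUCell.forward is the standard GRU state mix written
# in factored form: (pre_hidden - c) * z + c == z * pre_hidden + (1 - z) * c,
# i.e. z gates how much of the previous hidden state is kept. A scalar check:
assert abs((0.8 - 0.3) * 0.6 + 0.3 - (0.6 * 0.8 + (1 - 0.6) * 0.3)) < 1e-12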
class BiRNNWithBN(nn.Layer):
    """Bidirectional simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param size: Dimension of RNN cells.
    :type size: int
    :param share_weights: Whether to share input-hidden weights between
                          forward and backward directional RNNs.
    :type share_weights: bool
    :return: Bidirectional simple rnn layer.
    :rtype: Variable
    """

    def __init__(self, i_size: int, h_size: int, share_weights: bool):
        super().__init__()
        self.share_weights = share_weights
        if self.share_weights:
            #input-hidden weights shared between bi-directional rnn.
            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            # batch norm is only performed on input-state projection
            self.fw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')
            self.bw_fc = self.fw_fc
            self.bw_bn = self.fw_bn
        else:
            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            self.fw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')
            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            self.bw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')
        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
        # use the backward cell for the reverse-direction RNN
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len
class BiGRUWithBN(nn.Layer):
    """Bidirectional gru layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: Variable
    :param size: Dimension of GRU cells.
    :type size: int
    :param act: Activation type.
    :type act: string
    :return: Bidirectional GRU layer.
    :rtype: Variable
    """

    def __init__(self, i_size: int, h_size: int):
        super().__init__()
        hidden_size = h_size * 3

        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
        self.fw_bn = nn.BatchNorm1D(
            hidden_size, bias_attr=None, data_format='NLC')
        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
        self.bw_bn = nn.BatchNorm1D(
            hidden_size, bias_attr=None, data_format='NLC')

        self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
        self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
        # use the backward cell for the reverse-direction RNN
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]

    def forward(self, x, x_len):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len
class RNNStack(nn.Layer):
"""RNN group with stacked bidirectional simple RNN or GRU layers.
:param input: Input layer.
:type input: Variable
:param size: Dimension of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:param share_rnn_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
It is only available when use_gru=False.
:type share_weights: bool
:return: Output layer of the RNN group.
:rtype: Variable
"""
def __init__(self,
i_size: int,
h_size: int,
num_stacks: int,
use_gru: bool,
share_rnn_weights: bool):
super().__init__()
rnn_stacks = []
for i in range(num_stacks):
if use_gru:
#default:GRU using tanh
rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
else:
rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
        self.rnn_stacks = nn.LayerList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
x: shape [B, T, D]
        x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# TODO(Hui Zhang): not support bool multiply
            masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
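# Usage sketch (illustrative only): two stacked bidirectional layers. Each
# stack outputs 2 * h_size features, which becomes the next stack's input
# size (see `i_size = h_size * 2` above):
#
#   stack = RNNStack(i_size=161, h_size=256, num_stacks=2,
#                    use_gru=True, share_rnn_weights=False)
#   x = paddle.randn([4, 50, 161])
#   x_len = paddle.to_tensor([50, 42, 30, 50], dtype='int64')
#   out, out_len = stack(x, x_len)  # out: [4, 50, 512]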
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModelOnline
from .deepspeech2 import DeepSpeech2ModelOnline
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import paddle

from deepspeech.modules.subsampling import Conv2dSubsampling4
class Conv2dSubsampling4Online(Conv2dSubsampling4):
def __init__(self, idim: int, odim: int, dropout_rate: float):
super().__init__(idim, odim, dropout_rate, None)
self.output_dim = ((idim - 1) // 2 - 1) // 2 * odim
        self.receptive_field_length = 2 * (
            3 - 1) + 3  # stride_1 * (kernel_size_2 - 1) + kernel_size_1
    def forward(self, x: paddle.Tensor,
                x_len: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
        # b, c, t, f = paddle.shape(x)  # does not work under jit
x = x.transpose([0, 2, 1, 3]).reshape([0, 0, -1])
x_len = ((x_len - 1) // 2 - 1) // 2
return x, x_len
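# Shape sketch (illustrative only): with idim=161 and odim=32, the flattened
# channel*freq output size is ((161 - 1) // 2 - 1) // 2 * 32 = 39 * 32 = 1248,
# and sequence lengths shrink by roughly a factor of 4:
#
#   conv = Conv2dSubsampling4Online(idim=161, odim=32, dropout_rate=0.0)
#   x = paddle.randn([2, 100, 161])
#   x_len = paddle.to_tensor([100, 80], dtype='int64')
#   y, y_len = conv(x, x_len)  # y: [2, 24, 1248], y_len: [24, 19]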
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from yacs.config import CfgNode
from deepspeech.models.ds2_online.conv import Conv2dSubsampling4Online
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.utils import layer_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False):
super().__init__()
self.rnn_size = rnn_size
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.num_rnn_layers = num_rnn_layers
self.num_fc_layers = num_fc_layers
self.rnn_direction = rnn_direction
self.fc_layers_size_list = fc_layers_size_list
self.use_gru = use_gru
self.conv = Conv2dSubsampling4Online(feat_size, 32, dropout_rate=0.0)
self.output_dim = self.conv.output_dim
i_size = self.conv.output_dim
self.rnn = nn.LayerList()
self.layernorm_list = nn.LayerList()
self.fc_layers_list = nn.LayerList()
        if rnn_direction in ('bidirect', 'bidirectional'):
            layernorm_size = 2 * rnn_size
        elif rnn_direction == 'forward':
            layernorm_size = rnn_size
        else:
            raise ValueError(f"Unsupported rnn direction: {rnn_direction}")
for i in range(0, num_rnn_layers):
if i == 0:
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
if use_gru is True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
else:
self.rnn.append(
nn.LSTM(
input_size=rnn_input_size,
hidden_size=rnn_size,
num_layers=1,
direction=rnn_direction))
self.layernorm_list.append(nn.LayerNorm(layernorm_size))
self.output_dim = layernorm_size
fc_input_size = layernorm_size
for i in range(self.num_fc_layers):
self.fc_layers_list.append(
nn.Linear(fc_input_size, fc_layers_size_list[i]))
fc_input_size = fc_layers_size_list[i]
self.output_dim = fc_layers_size_list[i]
@property
def output_size(self):
return self.output_dim
def forward(self, x, x_lens, init_state_h_box=None, init_state_c_box=None):
"""Compute Encoder outputs
Args:
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
init_state_h_box(Tensor): init_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
init_state_c_box(Tensor): init_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
Return:
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
if init_state_h_box is not None:
init_state_list = None
if self.use_gru is True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
else:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_c_list = paddle.split(
init_state_c_box, self.num_rnn_layers, axis=0)
init_state_list = [(init_state_h_list[i], init_state_c_list[i])
for i in range(self.num_rnn_layers)]
else:
init_state_list = [None] * self.num_rnn_layers
x, x_lens = self.conv(x, x_lens)
final_chunk_state_list = []
for i in range(0, self.num_rnn_layers):
x, final_state = self.rnn[i](x, init_state_list[i],
x_lens) #[B, T, D]
final_chunk_state_list.append(final_state)
x = self.layernorm_list[i](x)
for i in range(self.num_fc_layers):
x = self.fc_layers_list[i](x)
x = F.relu(x)
if self.use_gru is True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box
else:
final_chunk_state_h_list = [
final_chunk_state_list[i][0] for i in range(self.num_rnn_layers)
]
final_chunk_state_c_list = [
final_chunk_state_list[i][1] for i in range(self.num_rnn_layers)
]
final_chunk_state_h_box = paddle.concat(
final_chunk_state_h_list, axis=0)
final_chunk_state_c_box = paddle.concat(
final_chunk_state_c_list, axis=0)
return x, x_lens, final_chunk_state_h_box, final_chunk_state_c_box
def forward_chunk_by_chunk(self, x, x_lens, decoder_chunk_size=8):
"""Compute Encoder outputs
Args:
x (Tensor): [B, T, D]
x_lens (Tensor): [B]
decoder_chunk_size: The chunk size of decoder
Returns:
eouts_list (List of Tensor): The list of encoder outputs in chunk_size: [B, chunk_size, D] * num_chunks
eouts_lens_list (List of Tensor): The list of encoder length in chunk_size: [B] * num_chunks
final_state_h_box(Tensor): final_states h for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
final_state_c_box(Tensor): final_states c for RNN layers: [num_rnn_layers * num_directions, batch_size, hidden_size]
"""
subsampling_rate = self.conv.subsampling_rate
receptive_field_length = self.conv.receptive_field_length
chunk_size = (decoder_chunk_size - 1
) * subsampling_rate + receptive_field_length
chunk_stride = subsampling_rate * decoder_chunk_size
max_len = x.shape[1]
assert (chunk_size <= max_len)
eouts_chunk_list = []
eouts_chunk_lens_list = []
if (max_len - chunk_size) % chunk_stride != 0:
padding_len = chunk_stride - (max_len - chunk_size) % chunk_stride
else:
padding_len = 0
padding = paddle.zeros((x.shape[0], padding_len, x.shape[2]))
padded_x = paddle.concat([x, padding], axis=1)
        num_chunk = (max_len + padding_len - chunk_size) // chunk_stride + 1
chunk_state_h_box = None
chunk_state_c_box = None
final_state_h_box = None
final_state_c_box = None
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
x_chunk = padded_x[:, start:end, :]
x_len_left = paddle.where(x_lens - i * chunk_stride < 0,
paddle.zeros_like(x_lens),
x_lens - i * chunk_stride)
x_chunk_len_tmp = paddle.ones_like(x_lens) * chunk_size
x_chunk_lens = paddle.where(x_len_left < x_chunk_len_tmp,
x_len_left, x_chunk_len_tmp)
eouts_chunk, eouts_chunk_lens, chunk_state_h_box, chunk_state_c_box = self.forward(
x_chunk, x_chunk_lens, chunk_state_h_box, chunk_state_c_box)
eouts_chunk_list.append(eouts_chunk)
eouts_chunk_lens_list.append(eouts_chunk_lens)
final_state_h_box = chunk_state_h_box
final_state_c_box = chunk_state_c_box
return eouts_chunk_list, eouts_chunk_lens_list, final_state_h_box, final_state_c_box
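# Chunk arithmetic sketch (assuming subsampling_rate=4, with
# receptive_field_length=7 as computed in Conv2dSubsampling4Online): with the
# default decoder_chunk_size=8, each chunk reads
# chunk_size = (8 - 1) * 4 + 7 = 35 input frames and advances by
# chunk_stride = 4 * 8 = 32 frames, so consecutive chunks overlap by 3 frames
# to cover the convolution front end's receptive field.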
class DeepSpeech2ModelOnline(nn.Layer):
"""The DeepSpeech2 network structure for online.
:param audio: Audio spectrogram data layer.
:type audio: Variable
:param text: Transcription text data layer.
:type text: Variable
:param audio_len: Valid sequence length data layer.
:type audio_len: Variable
:param feat_size: feature size for audio.
:type feat_size: int
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_size: RNN layer size (dimension of RNN cells).
:type rnn_size: int
:param num_fc_layers: Number of stacking FC layers.
:type num_fc_layers: int
:param fc_layers_size_list: The list of FC layer sizes.
:type fc_layers_size_list: [int,]
:param use_gru: Use gru if set True. Use simple rnn if set False.
:type use_gru: bool
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=4, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
                blank_id=0,  # index of blank in vocab.txt
))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
rnn_size=rnn_size,
use_gru=use_gru)
self.decoder = CTCDecoder(
odim=dict_size, # <blank> is in vocab
enc_n_units=self.encoder.output_size,
blank_id=blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type='instance')
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Args:
            audio (Tensor): [B, T, D]
audio_len (Tensor): [B]
text (Tensor): [B, U]
text_len (Tensor): [B]
Returns:
            loss (Tensor): [1]
"""
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
loss = self.decoder(eouts, eouts_len, text, text_len)
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
DeepSpeech2ModelOnline
The model built from pretrained result.
"""
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
rnn_direction=config.model.rnn_direction,
num_fc_layers=config.model.num_fc_layers,
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id)
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
@classmethod
def from_config(cls, config):
"""Build a DeepSpeec2ModelOnline from config
Parameters
config: yacs.config.CfgNode
config.model
Returns
-------
DeepSpeech2ModelOnline
The model built from config.
"""
model = cls(feat_size=config.feat_size,
dict_size=config.dict_size,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
rnn_direction=config.rnn_direction,
num_fc_layers=config.num_fc_layers,
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id)
return model
class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline):
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=4,
rnn_size=1024,
rnn_direction='forward',
num_fc_layers=2,
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0):
super().__init__(
feat_size=feat_size,
dict_size=dict_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,
rnn_size=rnn_size,
rnn_direction=rnn_direction,
num_fc_layers=num_fc_layers,
fc_layers_size_list=fc_layers_size_list,
use_gru=use_gru,
blank_id=blank_id)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box):
eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box
def export(self):
static_model = paddle.jit.to_static(
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None,
self.encoder.feat_size], #[B, chunk_size, feat_dim]
dtype='float32'),
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
                paddle.static.InputSpec(
                    shape=[None, None, None],
                    dtype='float32'),  # chunk_state_h_box: [num_rnn_layers * num_directions, B, rnn_size]
                paddle.static.InputSpec(
                    shape=[None, None, None],
                    dtype='float32')  # chunk_state_c_box: [num_rnn_layers * num_directions, B, rnn_size]
])
return static_model
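# Export sketch (illustrative only; the output path is hypothetical and a
# trained checkpoint would normally be loaded first): the static graph
# returned by export() can be serialized with paddle.jit.save.
#
#   model = DeepSpeech2InferModelOnline(feat_size=161, dict_size=29)
#   static_model = model.export()
#   paddle.jit.save(static_model, "exp/deepspeech2_online/avg_1.jit")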
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2 import U2InferModel
from .u2 import U2Model
from .updater import U2Evaluator
from .updater import U2Updater
__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import paddle
from paddle import distributed as dist
from deepspeech.training.extensions.evaluator import StandardEvaluator
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.updaters.standard_updater import StandardUpdater
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class U2Evaluator(StandardEvaluator):
def __init__(self, model, dataloader):
super().__init__(model, dataloader)
self.msg = ""
self.num_seen_utts = 0
self.total_loss = 0.0
def evaluate_core(self, batch):
self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
losses_dict = {}
loss, attention_loss, ctc_loss = self.model(*batch[1:])
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
self.num_seen_utts += num_utts
self.total_loss += float(loss) * num_utts
losses_dict['loss'] = float(loss)
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
for k, v in losses_dict.items():
report("eval/" + k, v)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
logger.info(self.msg)
return self.total_loss, self.num_seen_utts
class U2Updater(StandardUpdater):
def __init__(self,
model,
optimizer,
scheduler,
dataloader,
init_state=None,
accum_grad=1,
**kwargs):
super().__init__(
model, optimizer, scheduler, dataloader, init_state=init_state)
self.accum_grad = accum_grad
self.forward_count = 0
self.msg = ""
def update_core(self, batch):
"""One Step
Args:
batch (List[Object]): utts, xs, xlens, ys, ylens
"""
losses_dict = {}
self.msg = "Rank: {}, ".format(dist.get_rank())
# forward
batch_size = batch[1].shape[0]
loss, attention_loss, ctc_loss = self.model(*batch[1:])
# loss div by `batch_size * accum_grad`
loss /= self.accum_grad
# loss backward
if (self.forward_count + 1) != self.accum_grad:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# loss info
losses_dict['loss'] = float(loss) * self.accum_grad
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
# report loss
for k, v in losses_dict.items():
report("train/" + k, v)
# loss msg
self.msg += "batch size: {}, ".format(batch_size)
self.msg += "accum: {}, ".format(self.accum_grad)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
# Truncate the graph
loss.detach()
# update parameters
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
self.optimizer.step()
self.optimizer.clear_grad()
self.scheduler.step()
def update(self):
# model is default in train mode
# training for a step is implemented here
with Timer("data time cost:{}"):
batch = self.read_batch()
with Timer("step time cost:{}"):
self.update_core(batch)
# #iterations with accum_grad > 1
# Ref.: https://github.com/espnet/espnet/issues/777
if self.forward_count == 0:
self.state.iteration += 1
if self.updates_per_epoch is not None:
if self.state.iteration % self.updates_per_epoch == 0:
self.state.epoch += 1
@@ -16,15 +16,19 @@ from paddle import nn
 from paddle.nn import functional as F
 from typeguard import check_argument_types

-from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
-from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
-from deepspeech.decoders.swig_wrapper import Scorer
 from deepspeech.modules.loss import CTCLoss
 from deepspeech.utils import ctc_utils
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()

+try:
+    from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch  # noqa: F401
+    from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder  # noqa: F401
+    from deepspeech.decoders.swig_wrapper import Scorer  # noqa: F401
+except Exception as e:
+    logger.info("ctcdecoder not installed!")

 __all__ = ['CTCDecoder']
@@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer):
                  blank_id=0,
                  dropout_rate: float=0.0,
                  reduction: bool=True,
-                 batch_average: bool=True):
+                 batch_average: bool=True,
+                 grad_norm_type: str="instance"):
         """CTC decoder
         Args:
@@ -44,19 +49,21 @@ class CTCDecoder(nn.Layer):
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
             reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
             batch_average (bool): do batch dim wise average.
+            grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
         """
         assert check_argument_types()
         super().__init__()
         self.blank_id = blank_id
         self.odim = odim
-        self.dropout_rate = dropout_rate
+        self.dropout = nn.Dropout(dropout_rate)
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
         self.criterion = CTCLoss(
             blank=self.blank_id,
             reduction=reduction_type,
-            batch_average=batch_average)
+            batch_average=batch_average,
+            grad_norm_type=grad_norm_type)

         # CTCDecoder LM Score handle
         self._ext_scorer = None
@@ -72,7 +79,7 @@ class CTCDecoder(nn.Layer):
         Returns:
             loss (Tensor): ctc loss value, scalar.
         """
-        logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        logits = self.ctc_lo(self.dropout(hs_pad))
         loss = self.criterion(logits, ys_pad, hlens, ys_lens)
         return loss
@@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer):
         results = []
         for i, probs in enumerate(probs_split):
             output_transcription = ctc_greedy_decoder(
-                probs_seq=probs, vocabulary=vocab_list)
+                probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
             results.append(output_transcription)
         return results
@@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer):
             num_processes=num_processes,
             ext_scoring_func=self._ext_scorer,
             cutoff_prob=cutoff_prob,
-            cutoff_top_n=cutoff_top_n)
+            cutoff_top_n=cutoff_top_n,
+            blank_id=self.blank_id)
         results = [result[0][1] for result in beam_search_results]
         return results

     def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
                     decoding_method):
         if decoding_method == "ctc_beam_search":
             self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
                                   vocab_list)
@@ -229,7 +238,7 @@ class CTCDecoder(nn.Layer):
         """ctc decoding with probs.
         Args:
             probs (Tensor): activation after softmax
             logits_lens (Tensor): audio output lens
             vocab_list ([type]): [description]
             decoding_method ([type]): [description]
...
@@ -23,7 +23,7 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]

 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum', batch_average=False):
+    def __init__(self, blank=0, reduction='sum', batch_average=False, grad_norm_type=None):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
@@ -89,8 +89,8 @@ class LabelSmoothingLoss(nn.Layer):
         size (int): the number of class
         padding_idx (int): padding class id which will be ignored for loss
         smoothing (float): smoothing rate (0.0 means the conventional CE)
         normalize_length (bool):
             True, normalize loss by sequence length;
             False, normalize loss by batch size.
             Defaults to False.
     """
@@ -107,7 +107,7 @@ class LabelSmoothingLoss(nn.Layer):
         The model outputs and data labels tensors are flattened to
         (batch*seqlen, class) shape and a mask is applied to the
         padding part which should not be calculated for loss.
         Args:
             x (paddle.Tensor): prediction (batch, seqlen, class)
             target (paddle.Tensor):
...
@@ -14,25 +14,39 @@
 import argparse


+class ExtendAction(argparse.Action):
+    """Backport of the stdlib "extend" action.
+    Since Python 3.8, "extend" is available directly in argparse
+    (https://docs.python.org/3.8/library/argparse.html#action), so this
+    class is only needed to support older Python versions.
+    """
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        items = getattr(namespace, self.dest) or []
+        items.extend(values)
+        setattr(namespace, self.dest, items)
+
+
 def default_argument_parser():
     r"""A simple yet general argument parser for experiments with parakeet.

     This is used in examples with parakeet. And it is intended to be used by
     other experiments with parakeet. It requires a minimal set of command line
     arguments to start a training script.

     The ``--config`` and ``--opts`` are used to overwrite the default
     configuration.

     The ``--data`` and ``--output`` specifies the data path and output path.
     Resuming training from existing progress at the output directory is the
     intended default behavior.

     The ``--checkpoint_path`` specifies the checkpoint to load from.

-    The ``--device`` and ``--nprocs`` specifies how to run the training.
+    The ``--nprocs`` specifies how to run the training.

     See Also
     --------
     parakeet.training.experiment
@@ -42,33 +56,53 @@ def default_argument_parser():
         the parser
     """
     parser = argparse.ArgumentParser()
+    parser.register('action', 'extend', ExtendAction)

-    # yapf: disable
-    # data and output
-    parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
-    parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.")
-    # parser.add_argument("--data", metavar="DATA_DIR", help="path to the dataset.")
-    parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
-
-    # load from saved checkpoint
-    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
-
-    # save jit model to
-    parser.add_argument("--export_path", type=str, help="path of the jit model to save")
-
-    # save asr result to
-    parser.add_argument("--result_file", type=str, help="path of save the asr result")
-
-    # running
-    parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
-                        help="device type to use, cpu and gpu are supported.")
-    parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
-
-    # overwrite extra config and default config
-    # parser.add_argument("--opts", nargs=argparse.REMAINDER,
-    #                     help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
-    parser.add_argument("--opts", type=str, default=[], nargs='+',
-                        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
-    # yapf: enable
+    train_group = parser.add_argument_group(
+        title='Train Options', description=None)
+    train_group.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="seed to use for paddle, np and random. None or 0 for random, else set seed."
+    )
+    train_group.add_argument(
+        "--nprocs",
+        type=int,
+        default=1,
+        help="number of parallel processes. 0 for cpu.")
+    train_group.add_argument(
+        "--config", metavar="CONFIG_FILE", help="config file.")
+    train_group.add_argument(
+        "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
+    train_group.add_argument(
+        "--checkpoint_path", type=str, help="path to load checkpoint")
+    train_group.add_argument(
+        "--opts",
+        action='extend',
+        nargs=2,
+        metavar=('key', 'val'),
+        help="overwrite --config field, passing (KEY VALUE) pairs")
+    train_group.add_argument(
+        "--dump-config", metavar="FILE", help="dump config to `this` file.")

+    profile_group = parser.add_argument_group(
+        title='Benchmark Options', description=None)
+    profile_group.add_argument(
+        '--profiler-options',
+        type=str,
+        default=None,
+        help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".'
+    )
+    profile_group.add_argument(
+        '--benchmark-batch-size',
+        type=int,
+        default=None,
+        help='batch size for benchmark.')
+    profile_group.add_argument(
+        '--benchmark-max-step',
+        type=int,
+        default=None,
+        help='max iteration for benchmark.')

     return parser
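# Usage sketch for the `--opts` extend action added above (the script name is
# hypothetical): repeated occurrences accumulate into one flat list of
# KEY VALUE pairs.
#
#   python train.py --opts training.n_epoch 10 --opts training.lr 1e-4
#   # -> args.opts == ['training.n_epoch', '10', 'training.lr', '1e-4']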
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from .extension import Extension
def make_extension(trigger: Callable=None,
default_name: str=None,
priority: int=None,
finalizer: Callable=None,
initializer: Callable=None,
on_error: Callable=None):
"""Make an Extension-like object by injecting required attributes to it.
"""
if trigger is None:
trigger = Extension.trigger
if priority is None:
priority = Extension.priority
def decorator(ext):
ext.trigger = trigger
ext.default_name = default_name or ext.__name__
ext.priority = priority
ext.finalize = finalizer
ext.on_error = on_error
ext.initialize = initializer
return ext
return decorator
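# Usage sketch (illustrative only): decorate a plain function so a Trainer
# can schedule it like an Extension instance; trigger tuples follow the
# (period, unit) convention used by the Extension class in this package.
#
#   @make_extension(trigger=(1, 'epoch'))
#   def log_epoch(trainer):
#       print(trainer.updater.state.epoch)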
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from . import extension
from ..reporter import DictSummary
from ..reporter import ObsScope
from ..reporter import report
from ..timer import Timer
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class StandardEvaluator(extension.Extension):
trigger = (1, 'epoch')
default_name = 'validation'
priority = extension.PRIORITY_WRITER
name = None
def __init__(self, model: Layer, dataloader: DataLoader):
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
self.model = model
# dataloaders
self.dataloader = dataloader
def evaluate_core(self, batch):
# compute
self.model(batch) # you may report here
return
def evaluate_sync(self, data):
# dist sync `evaluate_core` outputs
if data is None:
return
numerator, denominator = data
if dist.get_world_size() > 1:
numerator = paddle.to_tensor(numerator)
denominator = paddle.to_tensor(denominator)
# the default operator in all_reduce function is sum.
dist.all_reduce(numerator)
dist.all_reduce(denominator)
value = numerator / denominator
value = float(value)
else:
value = numerator / denominator
        # used for `snapshot` to do kbest save.
report("VALID/LOSS", value)
logger.info(f"Valid: all-reduce loss {value}")
def evaluate(self):
# switch to eval mode
for model in self.models.values():
model.eval()
# to average evaluation metrics
summary = DictSummary()
for batch in self.dataloader:
observation = {}
with ObsScope(observation):
# main evaluation computation here.
with paddle.no_grad():
self.evaluate_sync(self.evaluate_core(batch))
summary.add(observation)
summary = summary.compute_mean()
# switch to train mode
for model in self.models.values():
model.train()
return summary
def __call__(self, trainer=None):
# evaluate and report the averaged metric to current observation
# if it is used to extend a trainer, the metrics is reported to
# to observation of the trainer
# or otherwise, you can use your own observation
with Timer("Eval Time Cost: {}"):
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
class Extension():
"""Extension to customize the behavior of Trainer."""
trigger = (1, 'iteration')
priority = PRIORITY_READER
name = None
@property
def default_name(self):
"""Default name of the extension, class name by default."""
return type(self).__name__
def __call__(self, trainer):
"""Main action of the extention. After each update, it is executed
when the trigger fires."""
raise NotImplementedError(
'Extension implementation must override __call__.')
def initialize(self, trainer):
"""Action that is executed once to get the corect trainer state.
It is called before training normally, but if the trainer restores
states with an Snapshot extension, this method should also be called.
"""
pass
def on_error(self, trainer, exc, tb):
"""Handles the error raised during training before finalization.
"""
pass
def finalize(self, trainer):
"""Action that is executed when training is done.
For example, visualizers would need to be closed.
"""
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
from . import extension
from ..reporter import get_observations
from ..updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog()
def load_records(records_fp):
"""Load record files (json lines.)"""
with jsonlines.open(records_fp, 'r') as reader:
records = list(reader)
return records
class Snapshot(extension.Extension):
"""An extension to make snapshot of the updater object inside
the trainer. It is done by calling the updater's `save` method.
An Updater save its state_dict by default, which contains the
updater state, (i.e. epoch and iteration) and all the model
parameters and optimizer states. If the updater inside the trainer
subclasses StandardUpdater, everything is good to go.
Parameters
----------
checkpoint_dir : Union[str, Path]
The directory to save checkpoints into.
"""
trigger = (1, 'epoch')
priority = -100
default_name = "snapshot"
def __init__(self,
mode='latest',
max_size: int=5,
indicator=None,
less_better=True,
snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = []
assert mode in ('latest', 'kbest'), mode
if mode == 'kbest':
assert indicator is not None
self.mode = mode
self.indicator = indicator
self.less_is_better = less_better
self.max_size = max_size
self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1)
self.checkpoint_dir = None
def initialize(self, trainer: Trainer):
"""Setting up this extention."""
self.checkpoint_dir = trainer.out / "checkpoints"
# load existing records
record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists():
self.records = load_records(record_path)
ckpt_path = self.records[-1]['path']
logger.info(f"Loading from an existing checkpoint {ckpt_path}")
trainer.updater.load(ckpt_path)
def on_error(self, trainer, exc, tb):
if self._snapshot_on_error:
self.save_checkpoint_and_update(trainer, 'latest')
def __call__(self, trainer: Trainer):
self.save_checkpoint_and_update(trainer, self.mode)
def full(self):
"""Whether the number of snapshots it keeps track of is greater
than the max_size."""
return (not self._save_all) and len(self.records) > self.max_size
@rank_zero_only
def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.np"
# add the new one
trainer.updater.save(path)
record = {
"time": str(datetime.now()),
'path': str(path.resolve()), # use absolute path
'iteration': iteration,
'epoch': epoch,
'indicator': get_observations()[self.indicator]
}
self.records.append(record)
        # remove the earliest
if self.full():
if mode == 'kbest':
self.records = sorted(
self.records,
key=lambda record: record['indicator'],
reverse=not self.less_is_better)
            earliest_record = self.records[0]
            os.remove(earliest_record["path"])
self.records.pop(0)
# update the record file
record_path = self.checkpoint_dir / "records.jsonl"
with jsonlines.open(record_path, 'w') as writer:
for record in self.records:
# jsonlines.open may return a Writer or a Reader
writer.write(record) # pylint: disable=no-member
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from visualdl import LogWriter
from . import extension
from ..updaters.trainer import Trainer
class VisualDL(extension.Extension):
"""A wrapper of visualdl log writer. It assumes that the metrics to be visualized
are all scalars which are recorded into the `.observation` dictionary of the
trainer object. The dictionary is created for each step, thus the visualdl log
writer uses the iteration from the updater's `iteration` as the global step to
add records.
"""
trigger = (1, 'iteration')
default_name = 'visualdl'
priority = extension.PRIORITY_READER
def __init__(self, output_dir):
self.writer = LogWriter(str(output_dir))
def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items():
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer):
self.writer.close()
@@ -27,6 +27,9 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
     def __init__(self, clip_norm):
         super().__init__(clip_norm)

+    def __repr__(self):
+        return f"{self.__class__.__name__}(global_clip_norm={self.clip_norm})"
+
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
@@ -44,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
                 sum_square = layers.reduce_sum(square)
                 sum_square_list.append(sum_square)

-                # debug log
+                # debug log; do not dump all grads, since that slows down the train process
                 if i < 10:
                     logger.debug(
                         f"Grad Before Clip: {p.name}: {float(sum_square.sqrt())}")
@@ -73,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))

-            # debug log
+            # debug log; do not dump all grads, since that slows down the train process
             if i < 10:
                 logger.debug(
                     f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Text
import paddle
from paddle.optimizer import Optimizer
from paddle.regularizer import L2Decay
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.dynamic_import import instance_class
from deepspeech.utils.log import Log
__all__ = ["OptimizerFactory"]
logger = Log(__name__).getlog()
OPTIMIZER_DICT = {
"sgd": "paddle.optimizer:SGD",
"momentum": "paddle.optimizer:Momentum",
"adadelta": "paddle.optimizer:Adadelta",
"adam": "paddle.optimizer:Adam",
"adamw": "paddle.optimizer:AdamW",
}
def register_optimizer(cls):
"""Register optimizer."""
    alias = cls.__name__.lower()
    OPTIMIZER_DICT[alias] = cls.__module__ + ":" + cls.__name__
return cls
@register_optimizer
class Noam(paddle.optimizer.Adam):
"""Seem to: espnet/nets/pytorch_backend/transformer/optimizer.py """
def __init__(self,
learning_rate=0,
beta1=0.9,
beta2=0.98,
epsilon=1e-9,
parameters=None,
weight_decay=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
super().__init__(
learning_rate=learning_rate,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
parameters=parameters,
weight_decay=weight_decay,
grad_clip=grad_clip,
lazy_mode=lazy_mode,
multi_precision=multi_precision,
name=name)
    def __repr__(self):
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"learning_rate: {self._learning_rate}, "
        echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
        echo += f"epsilon: {self._epsilon}"
        return echo
def dynamic_import_optimizer(module):
"""Import Optimizer class dynamically.
Args:
module (str): module_name:class_name or alias in `OPTIMIZER_DICT`
Returns:
type: Optimizer class
"""
module_class = dynamic_import(module, OPTIMIZER_DICT)
assert issubclass(module_class,
Optimizer), f"{module} does not implement Optimizer"
return module_class
class OptimizerFactory():
@classmethod
def from_args(cls, name: str, args: Dict[Text, Any]):
assert "parameters" in args, "parameters not in args."
assert "learning_rate" in args, "learning_rate not in args."
grad_clip = ClipGradByGlobalNormWithLog(
args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay(
args['weight_decay']) if "weight_decay" in args else None
if weight_decay:
logger.info(f'<WeightDecay - {weight_decay}>')
if grad_clip:
logger.info(f'<GradClip - {grad_clip}>')
module_class = dynamic_import_optimizer(name.lower())
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
opt = instance_class(module_class, args)
if "__repr__" in vars(opt):
logger.info(f"{opt}")
else:
logger.info(
f"<Optimizer {module_class.__module__}.{module_class.__name__}> LR: {args['learning_rate']}"
)
return opt
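# Usage sketch (hypothetical hyperparameters): grad_clip and weight_decay are
# plain floats in `args`; from_args wraps them into ClipGradByGlobalNormWithLog
# and L2Decay before instantiating the optimizer.
#
#   model = paddle.nn.Linear(10, 10)
#   optimizer = OptimizerFactory.from_args("adam", {
#       "parameters": model.parameters(),
#       "learning_rate": 1e-3,
#       "grad_clip": 5.0,
#       "weight_decay": 1e-6,
#   })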
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from collections import defaultdict
OBSERVATIONS = None
@contextlib.contextmanager
def ObsScope(observations):
# make `observation` the target to report to.
# it is basically a dictionary that stores temporary observations
global OBSERVATIONS
old = OBSERVATIONS
OBSERVATIONS = observations
try:
yield
finally:
OBSERVATIONS = old
def get_observations():
global OBSERVATIONS
return OBSERVATIONS
def report(name, value):
# a simple function to report named value
    # you can use it everywhere, it will get the default target and write to it
    # you can think of it as stdout
observations = get_observations()
if observations is None:
return
else:
observations[name] = value
class Summary():
"""Online summarization of a sequence of scalars.
Summary computes the statistics of given scalars online.
"""
def __init__(self):
self._x = 0.0
self._x2 = 0.0
self._n = 0
def add(self, value, weight=1):
"""Adds a scalar value.
Args:
value: Scalar value to accumulate. It is either a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
weight: An optional weight for the value. It is a NumPy scalar or
a zero-dimensional array (on CPU or GPU).
Default is 1 (integer).
"""
self._x += weight * value
self._x2 += weight * value * value
self._n += weight
def compute_mean(self):
"""Computes the mean."""
x, n = self._x, self._n
return x / n
def make_statistics(self):
"""Computes and returns the mean and standard deviation values.
Returns:
tuple: Mean and standard deviation values.
"""
x, n = self._x, self._n
mean = x / n
var = self._x2 / n - mean * mean
std = math.sqrt(var)
return mean, std
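# Usage sketch: Summary keeps only running sums, so the mean and (population)
# standard deviation are available at any point without storing the sequence.
#
#   s = Summary()
#   for v in [1.0, 2.0, 3.0]:
#       s.add(v)
#   s.compute_mean()     # 2.0
#   s.make_statistics()  # (2.0, ~0.816)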
class DictSummary():
"""Online summarization of a sequence of dictionaries.
``DictSummary`` computes the statistics of a given set of scalars online.
It only computes the statistics for scalar values and variables of scalar
values in the dictionaries.
"""
def __init__(self):
self._summaries = defaultdict(Summary)
def add(self, d):
"""Adds a dictionary of scalars.
Args:
d (dict): Dictionary of scalars to accumulate. Only elements of
scalars, zero-dimensional arrays, and variables of
zero-dimensional arrays are accumulated. When the value
is a tuple, the second element is interpreted as a weight.
"""
summaries = self._summaries
for k, v in d.items():
w = 1
            if isinstance(v, tuple):
                w = v[1]
                v = v[0]
summaries[k].add(v, weight=w)
def compute_mean(self):
"""Creates a dictionary of mean values.
It returns a single dictionary that holds a mean value for each entry
added to the summary.
Returns:
dict: Dictionary of mean values.
"""
return {
name: summary.compute_mean()
for name, summary in self._summaries.items()
}
def make_statistics(self):
"""Creates a dictionary of statistics.
It returns a single dictionary that holds mean and standard deviation
values for every entry added to the summary. For an entry of name
``'key'``, these values are added to the dictionary by names ``'key'``
and ``'key.std'``, respectively.
Returns:
dict: Dictionary of statistics of all entries.
"""
stats = {}
for name, summary in self._summaries.items():
mean, std = summary.make_statistics()
stats[name] = mean
stats[name + '.std'] = std
return stats
@@ -11,18 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+from typing import Dict
+from typing import Text
 from typing import Union

 from paddle.optimizer.lr import LRScheduler
 from typeguard import check_argument_types

+from deepspeech.utils.dynamic_import import dynamic_import
+from deepspeech.utils.dynamic_import import instance_class
 from deepspeech.utils.log import Log

-__all__ = ["WarmupLR"]
+__all__ = ["WarmupLR", "LRSchedulerFactory"]

 logger = Log(__name__).getlog()

+SCHEDULER_DICT = {
+    "noam": "paddle.optimizer.lr:NoamDecay",
+    "expdecaylr": "paddle.optimizer.lr:ExponentialDecay",
+    "piecewisedecay": "paddle.optimizer.lr:PiecewiseDecay",
+}
+
+
+def register_scheduler(cls):
+    """Register scheduler."""
+    alias = cls.__name__.lower()
+    SCHEDULER_DICT[alias] = cls.__module__ + ":" + cls.__name__
+    return cls
+
+
+@register_scheduler
 class WarmupLR(LRScheduler):
     """The WarmupLR scheduler
     This scheduler is almost same as NoamLR Scheduler except for following
@@ -40,7 +59,8 @@ class WarmupLR(LRScheduler):
                  warmup_steps: Union[int, float]=25000,
                  learning_rate=1.0,
                  last_epoch=-1,
-                 verbose=False):
+                 verbose=False,
+                 **kwargs):
         assert check_argument_types()
         self.warmup_steps = warmup_steps
         super().__init__(learning_rate, last_epoch, verbose)
@@ -64,3 +84,45 @@ class WarmupLR(LRScheduler):
             None
         '''
         self.step(epoch=step)
+
+
+@register_scheduler
+class ConstantLR(LRScheduler):
+    """
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
+    Returns:
+        ``ConstantLR`` instance to schedule learning rate.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, verbose=False):
+        super().__init__(learning_rate, last_epoch, verbose)
+
+    def get_lr(self):
+        return self.base_lr
+
+
+def dynamic_import_scheduler(module):
+    """Import Scheduler class dynamically.
+    Args:
+        module (str): module_name:class_name or alias in `SCHEDULER_DICT`
+    Returns:
+        type: Scheduler class
+    """
+    module_class = dynamic_import(module, SCHEDULER_DICT)
+    assert issubclass(module_class,
+                      LRScheduler), f"{module} does not implement LRScheduler"
+    return module_class
+
+
+class LRSchedulerFactory():
+    @classmethod
+    def from_args(cls, name: str, args: Dict[Text, Any]):
+        module_class = dynamic_import_scheduler(name.lower())
+        return instance_class(module_class, args)
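# Usage sketch for the factory added above (hypothetical hyperparameters;
# assumes instance_class forwards `args` as keyword arguments): aliases are
# the lower-cased class names registered in SCHEDULER_DICT.
#
#   scheduler = LRSchedulerFactory.from_args(
#       "warmuplr", {"learning_rate": 0.002, "warmup_steps": 25000})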
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import time
from deepspeech.utils.log import Log
__all__ = ["Timer"]
logger = Log(__name__).getlog()
class Timer():
    """A context-manager stopwatch. ``message`` should contain a ``{}``
    placeholder, which is filled with the formatted duration on exit:

        with Timer("Epoch done, cost: {}"):
            do_something()
    """

    def __init__(self, message=None):
        self.message = message

    def duration(self) -> str:
        elapsed_time = time.time() - self.start
        time_str = str(datetime.timedelta(seconds=elapsed_time))
        return time_str

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.message:
            logger.info(self.message.format(self.duration()))

    def __call__(self) -> float:
        return time.time() - self.start

    def __str__(self):
        return self.duration()
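

# Usage sketch (the sleep is an illustrative stand-in for real work):
#
#     with Timer("epoch done, cost: {}") as timer:
#         time.sleep(1.5)
#     # logs "epoch done, cost: 0:00:01.5..." on exit;
#     # ``timer()`` returns the elapsed seconds as a float.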
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger


def never_fail_trigger(trainer):
    """A trigger that never fires."""
    return False


def get_trigger(trigger):
    """Normalize ``trigger`` into a callable predicate.

    ``trigger`` may be None (never fires), a callable (used as-is), or a
    ``(period, unit)`` tuple, which is wrapped in an ``IntervalTrigger``.
    """
    if trigger is None:
        return never_fail_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
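

# Usage sketch:
#
#     get_trigger(None)                  # a predicate that never fires
#     get_trigger((100, "iteration"))    # fires every 100 iterations
#     get_trigger(lambda trainer: True)  # any callable passes through as-is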
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger():
    """A predicate that fires once every ``period`` iterations or epochs."""

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        self.last_index = None

    def __call__(self, trainer):
        if self.last_index is None:
            last_index = getattr(trainer.updater.state, self.unit)
            self.last_index = last_index

        last_index = self.last_index
        index = getattr(trainer.updater.state, self.unit)
        # Fire exactly when the counter crosses a multiple of ``period``.
        fire = index // self.period != last_index // self.period

        self.last_index = index
        return fire
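

# Behavior sketch (``_State``/``_Updater``/``_Trainer`` are hypothetical
# stand-ins for the real updater objects):
#
#     class _State:
#         iteration = 0
#
#     class _Updater:
#         state = _State()
#
#     class _Trainer:
#         updater = _Updater()
#
#     trigger = IntervalTrigger(3, "iteration")
#     trainer = _Trainer()
#     for i in range(7):
#         trainer.updater.state.iteration = i
#         print(i, trigger(trainer))  # True at i == 3 and i == 6, else False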
export MAIN_ROOT=${PWD}
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}:/usr/local/bin

export LC_ALL=C

# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
...
old-pd_env.txt
pd_env.txt