提交 6ee3033c 编写于 作者: H Haoxin Ma

finish aishell/s0

上级 7bae32f3
......@@ -102,8 +102,8 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.train_loader.dataset.feature_size,
dict_size=self.train_loader.dataset.vocab_size,
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
......@@ -199,7 +199,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.dataset.vocab_list
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode(
......@@ -272,7 +272,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path)
infer_model.eval()
feat_dim = self.test_loader.dataset.feature_size
feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
......@@ -308,8 +308,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self):
config = self.config
model = DeepSpeech2Model(
feat_size=self.test_loader.dataset.feature_size,
dict_size=self.test_loader.dataset.vocab_size,
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
......
......@@ -279,8 +279,8 @@ class U2Trainer(Trainer):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.dataset.feature_size
model_conf.output_dim = self.train_loader.dataset.vocab_size
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
model = U2Model.from_config(model_conf)
......@@ -497,7 +497,7 @@ class U2Tester(U2Trainer):
infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
self.config.model.clone(),
self.args.checkpoint_path)
feat_dim = self.test_loader.dataset.feature_size
feat_dim = self.test_loader.collate_fn.feature_size
input_spec = [
paddle.static.InputSpec(
shape=[None, feat_dim, None],
......
......@@ -104,13 +104,60 @@ class SpeechFeaturizer(object):
speech_segment.transcript)
return spec_feature, text_ids
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
return self._text_featurizer.vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._audio_featurizer.feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
@property
def text_feature(self):
"""Return the text feature object.
Returns:
TextFeaturizer: object.
"""
return self._text_featurizer
# @property
# def text_feature(self):
# """Return the text feature object.
# Returns:
# TextFeaturizer: object.
# """
# return self._text_featurizer
......@@ -283,11 +283,41 @@ class SpeechCollator():
return utts, padded_audios, audio_lens, padded_texts, text_lens
# @property
# def text_feature(self):
# return self._speech_featurizer.text_feature
# @property
# def stride_ms(self):
# return self._speech_featurizer.stride_ms
###########
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
return self._speech_featurizer.stride_ms
\ No newline at end of file
......@@ -55,10 +55,6 @@ class ManifestDataset(Dataset):
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
batch_size=32, # batch size
num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True
shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle'
))
if config is not None:
......@@ -77,7 +73,7 @@ class ManifestDataset(Dataset):
"""
assert 'manifest' in config.data
assert config.data.manifest
assert 'keep_transcription_text' in config.data
assert 'keep_transcription_text' in config.collator
if isinstance(config.data.augmentation_config, (str, bytes)):
if config.data.augmentation_config:
......@@ -171,51 +167,51 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
self._vocab_list = self._read_vocab(vocab_filepath)
# self._vocab_list = self._read_vocab(vocab_filepath)
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
vocab_dict = dict(
[(token, idx) for (idx, token) in enumerate(self._vocab_list)])
return vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._manifest[0]["feat_shape"][-1]
# @property
# def manifest(self):
# return self._manifest
# @property
# def vocab_size(self):
# """Return the vocabulary size.
# Returns:
# int: Vocabulary size.
# """
# return len(self._vocab_list)
# @property
# def vocab_list(self):
# """Return the vocabulary in list.
# Returns:
# List[str]:
# """
# return self._vocab_list
# @property
# def vocab_dict(self):
# """Return the vocabulary in dict.
# Returns:
# Dict[str, int]:
# """
# vocab_dict = dict(
# [(token, idx) for (idx, token) in enumerate(self._vocab_list)])
# return vocab_dict
# @property
# def feature_size(self):
# """Return the audio feature size.
# Returns:
# int: audio feature size.
# """
# return self._manifest[0]["feat_shape"][-1]
def __len__(self):
return len(self._manifest)
......
......@@ -5,7 +5,6 @@ data:
test_manifest: data/manifest.test
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 64 # one gpu
min_input_len: 0.0
max_input_len: 27.0 # second
......@@ -13,21 +12,26 @@ data:
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
random_seed: 0
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0
model:
num_conv_layers: 2
......
......@@ -42,7 +42,7 @@ model:
share_rnn_weights: True
training:
n_epoch: 21
n_epoch: 23
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
......
......@@ -72,7 +72,7 @@ model:
training:
n_epoch: 3
n_epoch: 21
accum_grad: 1
global_grad_clip: 5.0
optim: adam
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册