提交 895a086f 编写于 作者: H huangyuxin

rename the config.feat_size and the config.vocab.size to input_size and output_size

上级 d2888897
......@@ -24,4 +24,4 @@
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.725063021977743 | 0.047417 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.725063021977743 | 0.053922 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.725063021977743 | 0.053180 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.725063021977743 | 0.041026 |
\ No newline at end of file
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescoring | 6.725063021977743 | 0.041026 |
......@@ -110,8 +110,8 @@ class DeepSpeech2Tester_hub():
def setup_model(self):
config = self.config.clone()
with UpdateConfig(config):
config.model.feat_size = self.collate_fn_test.feature_size
config.model.dict_size = self.collate_fn_test.vocab_size
config.model.input_dim = self.collate_fn_test.feature_size
config.model.output_dim = self.collate_fn_test.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
......
......@@ -154,11 +154,11 @@ class DeepSpeech2Trainer(Trainer):
config = self.config.clone()
with UpdateConfig(config):
if self.train:
config.model.feat_size = self.train_loader.collate_fn.feature_size
config.model.dict_size = self.train_loader.collate_fn.vocab_size
config.model.input_dim = self.train_loader.collate_fn.feature_size
config.model.output_dim = self.train_loader.collate_fn.vocab_size
else:
config.model.feat_size = self.test_loader.collate_fn.feature_size
config.model.dict_size = self.test_loader.collate_fn.vocab_size
config.model.input_dim = self.test_loader.collate_fn.feature_size
config.model.output_dim = self.test_loader.collate_fn.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
......
......@@ -249,8 +249,8 @@ class DeepSpeech2Model(nn.Layer):
The model built from config.
"""
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
feat_size=config.input_dim,
dict_size=config.output_dim,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
......
......@@ -381,8 +381,8 @@ class DeepSpeech2ModelOnline(nn.Layer):
The model built from config.
"""
model = cls(
feat_size=config.feat_size,
dict_size=config.dict_size,
feat_size=config.input_dim,
dict_size=config.output_dim,
num_conv_layers=config.num_conv_layers,
num_rnn_layers=config.num_rnn_layers,
rnn_size=config.rnn_layer_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册