提交 5dd9e2f8 编写于 作者: H huangyuxin

先不暴露出online

上级 6079a249
...@@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler ...@@ -29,8 +29,8 @@ from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.ds2 import DeepSpeech2InferModel from deepspeech.models.ds2 import DeepSpeech2InferModel
from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline #from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline #from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.trainer import Trainer from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate from deepspeech.utils import error_rate
...@@ -122,25 +122,15 @@ class DeepSpeech2Trainer(Trainer): ...@@ -122,25 +122,15 @@ class DeepSpeech2Trainer(Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if (config.model.apply_online == False): model = DeepSpeech2Model(
model = DeepSpeech2Model( feat_size=self.train_loader.collate_fn.feature_size,
feat_size=self.train_loader.collate_fn.feature_size, dict_size=self.train_loader.collate_fn.vocab_size,
dict_size=self.train_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers,
num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers,
num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size,
rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru,
use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights)
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.train_loader.collate_fn.feature_size,
dict_size=self.train_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -347,7 +337,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -347,7 +337,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
else: else:
infer_model = DeepSpeech2InferModelOnline.from_pretrained( infer_model = DeepSpeech2InferModelOnline.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path) self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval() infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
...@@ -384,25 +374,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -384,25 +374,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
if config.model.apply_online == False: model = DeepSpeech2Model(
model = DeepSpeech2Model( feat_size=self.test_loader.collate_fn.feature_size,
feat_size=self.test_loader.collate_fn.feature_size, dict_size=self.test_loader.collate_fn.vocab_size,
dict_size=self.test_loader.collate_fn.vocab_size, num_conv_layers=config.model.num_conv_layers,
num_conv_layers=config.model.num_conv_layers, num_rnn_layers=config.model.num_rnn_layers,
num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size,
rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru,
use_gru=config.model.use_gru, share_rnn_weights=config.model.share_rnn_weights)
share_rnn_weights=config.model.share_rnn_weights)
else:
model = DeepSpeech2ModelOnline(
feat_size=self.test_loader.collate_fn.feature_size,
dict_size=self.test_loader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
self.model = model self.model = model
logger.info("Setup model!") logger.info("Setup model!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册