提交 d55e6b5a 编写于 作者: H Haoxin Ma

Revise from_pretrained to take a DataLoader instead of a Dataset

上级 3652b87f
...@@ -29,6 +29,9 @@ from deepspeech.utils.socket_server import warm_up_test ...@@ -29,6 +29,9 @@ from deepspeech.utils.socket_server import warm_up_test
from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def init_predictor(args): def init_predictor(args):
if args.model_dir is not None: if args.model_dir is not None:
...@@ -83,7 +86,12 @@ def start_server(config, args): ...@@ -83,7 +86,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config) dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config, config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()
......
...@@ -28,6 +28,9 @@ from deepspeech.utils.utility import add_arguments ...@@ -28,6 +28,9 @@ from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def start_server(config, args): def start_server(config, args):
"""Start the ASR server""" """Start the ASR server"""
config.defrost() config.defrost()
...@@ -36,7 +39,12 @@ def start_server(config, args): ...@@ -36,7 +39,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config) dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config, config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()
......
...@@ -47,7 +47,7 @@ def tune(config, args): ...@@ -47,7 +47,7 @@ def tune(config, args):
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True)) collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config, model = DeepSpeech2Model.from_pretrained(valid_loader, config,
args.checkpoint_path) args.checkpoint_path)
model.eval() model.eval()
......
...@@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): ...@@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def export(self): def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained( infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path) self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval() infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
......
...@@ -506,7 +506,7 @@ class U2Tester(U2Trainer): ...@@ -506,7 +506,7 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec. List[paddle.static.InputSpec]: input spec.
""" """
from deepspeech.models.u2 import U2InferModel from deepspeech.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(), self.config.model.clone(),
self.args.checkpoint_path) self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size feat_dim = self.test_loader.collate_fn.feature_size
......
...@@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer): ...@@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer):
cutoff_top_n, num_processes) cutoff_top_n, num_processes)
@classmethod @classmethod
def from_pretrained(cls, dataset, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model. """Build a DeepSpeech2Model model from a pretrained model.
Parameters Parameters
---------- ----------
dataset: paddle.io.Dataset dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode config: yacs.config.CfgNode
model configs model configs
...@@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer): ...@@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer):
DeepSpeech2Model DeepSpeech2Model
The model built from pretrained result. The model built from pretrained result.
""" """
model = cls(feat_size=dataset.feature_size, model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataset.vocab_size, dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers, num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers, num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size, rnn_size=config.model.rnn_layer_size,
......
...@@ -876,11 +876,11 @@ class U2Model(U2BaseModel): ...@@ -876,11 +876,11 @@ class U2Model(U2BaseModel):
return model return model
@classmethod @classmethod
def from_pretrained(cls, dataset, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a U2Model from a pretrained model. """Build a U2Model from a pretrained model.
Args: Args:
dataset (paddle.io.Dataset): provides feature_size and vocab_size. dataloader (paddle.io.DataLoader): its collate_fn provides feature_size and vocab_size.
config (yacs.config.CfgNode): model configs config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
...@@ -888,8 +888,8 @@ class U2Model(U2BaseModel): ...@@ -888,8 +888,8 @@ class U2Model(U2BaseModel):
U2Model: The model built from pretrained result.
""" """
config.defrost() config.defrost()
config.input_dim = dataset.feature_size config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataset.vocab_size config.output_dim = dataloader.collate_fn.vocab_size
config.freeze() config.freeze()
model = cls.from_config(config) model = cls.from_config(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册