提交 d55e6b5a 编写于 作者: H Haoxin Ma

Revise from_pretrained to take a DataLoader instead of a Dataset

上级 3652b87f
......@@ -29,6 +29,9 @@ from deepspeech.utils.socket_server import warm_up_test
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def init_predictor(args):
if args.model_dir is not None:
......@@ -83,7 +86,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
......
......@@ -28,6 +28,9 @@ from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
def start_server(config, args):
"""Start the ASR server"""
config.defrost()
......@@ -36,7 +39,12 @@ def start_server(config, args):
config.data.keep_transcription_text = True
dataset = ManifestDataset.from_config(config)
model = DeepSpeech2Model.from_pretrained(dataset, config,
config.collator.batch_size=1
config.collator.num_workers=0
collate_fn = SpeechCollator.from_config(config)
test_loader = DataLoader(dataset_dataset, collate_fn=collate_fn, num_workers=0)
model = DeepSpeech2Model.from_pretrained(test_loader, config,
args.checkpoint_path)
model.eval()
......
......@@ -47,7 +47,7 @@ def tune(config, args):
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
model = DeepSpeech2Model.from_pretrained(valid_loader, config,
args.checkpoint_path)
model.eval()
......
......@@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path)
self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval()
feat_dim = self.test_loader.collate_fn.feature_size
static_model = paddle.jit.to_static(
......
......@@ -506,7 +506,7 @@ class U2Tester(U2Trainer):
List[paddle.static.InputSpec]: input spec.
"""
from deepspeech.models.u2 import U2InferModel
infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(),
self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size
......
......@@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer):
cutoff_top_n, num_processes)
@classmethod
def from_pretrained(cls, dataset, config, checkpoint_path):
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Parameters
----------
dataset: paddle.io.Dataset
dataloader: paddle.io.DataLoader
config: yacs.config.CfgNode
model configs
......@@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer):
DeepSpeech2Model
The model built from pretrained result.
"""
model = cls(feat_size=dataset.feature_size,
dict_size=dataset.vocab_size,
model = cls(feat_size=dataloader.collate_fn.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
......
......@@ -876,11 +876,11 @@ class U2Model(U2BaseModel):
return model
@classmethod
def from_pretrained(cls, dataset, config, checkpoint_path):
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataset (paddle.io.Dataset): not used.
dataloader (paddle.io.DataLoader): loader whose collate_fn supplies feature_size and vocab_size.
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
......@@ -888,8 +888,8 @@ class U2Model(U2BaseModel):
U2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataset.feature_size
config.output_dim = dataset.vocab_size
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册