diff --git a/deep_speech_2/model.py b/deep_speech_2/model.py index c8766deb1d33a12b9e96f00696847fd000c2b78d..c2e440b3ae4a8455d41de7b293e95123b7038931 100644 --- a/deep_speech_2/model.py +++ b/deep_speech_2/model.py @@ -120,6 +120,16 @@ class DeepSpeech2Model(object): feeding=feeding_dict) def infer_loss_batch(self, infer_data): + """Model inference. Infer the ctc loss for a batch of speech + utterances. + + :param infer_data: List of utterances to infer, with each utterance a + tuple of audio features and transcription text (empty + string). + :type infer_data: list + :return: List of ctc loss. + :rtype: List of float + """ # define inferer if self._loss_inferer == None: self._loss_inferer = paddle.inference.Inference(