提交 6d8a2de6 编写于 作者: P Philip

Support split words

上级 c7c95025
import codecs
class Alphabet(object):
def __init__(self, config_file):
self._config_file = config_file
self._label_to_str = []
self._str_to_label = {}
self._size = 0
self.blank_token = 1
with codecs.open(config_file, 'r', 'utf-8') as fin:
for line in fin:
if line[0:2] == '\\#':
line = '#\n'
elif line[0] == '#':
continue
self._label_to_str += line[:-1] # remove the line ending
self._str_to_label[line[:-1]] = self._size
self._size += 1
def string_from_label(self, label):
return self._label_to_str[label]
def label_from_string(self, string):
try:
return self._str_to_label[string]
except KeyError as e:
raise KeyError(
'''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
).with_traceback(e.__traceback__)
def decode(self, labels):
res = ''
for label in labels:
res += self.string_from_label(label)
return res
def size(self):
return self._size
def config_file(self):
return self._config_file
"""Contains DeepSpeech2 model."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -12,9 +12,10 @@ import copy ...@@ -12,9 +12,10 @@ import copy
import inspect import inspect
from distutils.dir_util import mkpath from distutils.dir_util import mkpath
import paddle.v2 as paddle import paddle.v2 as paddle
from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder from ds_ctcdecoder import Scorer, swigwrapper
from decoders.swig_wrapper import ctc_beam_search_decoder_batch from alphabet import Alphabet
from words import Words
from model_utils.network import deep_speech_v2_network from model_utils.network import deep_speech_v2_network
logging.basicConfig( logging.basicConfig(
...@@ -205,26 +206,8 @@ class DeepSpeech2Model(object): ...@@ -205,26 +206,8 @@ class DeepSpeech2Model(object):
] ]
return probs_split return probs_split
def decode_batch_greedy(self, probs_split, vocab_list): def init_ext_scorer(self, beam_alpha, beam_beta,
"""Decode by best path for a batch of probs matrix input. language_model_path, trie_path, alphabet):
:param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce.
:param probs_split: List of matrix
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:return: List of transcription texts.
:rtype: List of basestring
"""
results = []
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list)
results.append(output_transcription)
return results
def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
vocab_list):
"""Initialize the external scorer. """Initialize the external scorer.
:param beam_alpha: Parameter associated with language model. :param beam_alpha: Parameter associated with language model.
...@@ -236,39 +219,31 @@ class DeepSpeech2Model(object): ...@@ -236,39 +219,31 @@ class DeepSpeech2Model(object):
None, and the decoding method will be pure None, and the decoding method will be pure
beam search without scorer. beam search without scorer.
:type language_model_path: basestring|None :type language_model_path: basestring|None
:param vocab_list: List of tokens in the vocabulary, for decoding. :param trie_path: Filepath for trie
:type vocab_list: list :type trie_path: basestring\None
:param alphabet: Alphabet class for decoding.
:type alphabet: Alphabet
""" """
if language_model_path != '': if language_model_path != '':
self.logger.info("begin to initialize the external scorer " self.logger.info("begin to initialize the external scorer "
"for decoding") "for decoding")
self._ext_scorer = Scorer(beam_alpha, beam_beta, self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path, vocab_list) language_model_path, trie_path,
lm_char_based = self._ext_scorer.is_character_based() alphabet)
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
self.logger.info("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
self.logger.info("end initializing scorer") self.logger.info("end initializing scorer")
else: else:
self._ext_scorer = None self._ext_scorer = None
self.logger.info("no language model provided, " self.logger.info("no language model provided, "
"decoding by pure beam search without scorer.") "decoding by pure beam search without scorer.")
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, def decode_beam_search(self, probs_split, beam_size,
beam_size, cutoff_prob, cutoff_top_n, cutoff_prob, cutoff_top_n,
vocab_list, num_processes): alphabet):
"""Decode by beam search for a batch of probs matrix input. """Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce. of prob vectors for one speech utterancce.
:param probs_split: List of matrix :param probs_split: List of matrix
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param beam_size: Width for Beam search. :param beam_size: Width for Beam search.
:type beam_size: int :type beam_size: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
...@@ -278,28 +253,24 @@ class DeepSpeech2Model(object): ...@@ -278,28 +253,24 @@ class DeepSpeech2Model(object):
characters with highest probs in vocabulary will be characters with highest probs in vocabulary will be
used in beam search, default 40. used in beam search, default 40.
:type cutoff_top_n: int :type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding. :param alphabet: alphabet used for decoding.
:type vocab_list: list :type alphabet: Alphabet
:param num_processes: Number of processes (CPU) for decoder.
:type num_processes: int
:return: List of transcription texts. :return: List of transcription texts.
:rtype: List of basestring :rtype: List of basestring
""" """
if self._ext_scorer != None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode # beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch( metadata = swigwrapper.ctc_beam_search_decoder(
probs_split=probs_split, probs_split[0],
vocabulary=vocab_list, alphabet.config_file(),
beam_size=beam_size, beam_size,
num_processes=num_processes, cutoff_prob,
ext_scoring_func=self._ext_scorer, cutoff_top_n,
cutoff_prob=cutoff_prob, self._ext_scorer)
cutoff_top_n=cutoff_top_n)
words = Words(probs_split[0], metadata[0], alphabet)
results = [result[0][1] for result in beam_search_results]
return results return words
def _adapt_feeding_dict(self, feeding_dict): def _adapt_feeding_dict(self, feeding_dict):
"""Adapt feeding dict according to network struct. """Adapt feeding dict according to network struct.
......
import json
class Words(object):
def __init__(self, prob_split, metadata, alphabet, frame_to_sec=.03):
self.raw_output = ''
self.extended_output = []
word = ''
start_step, confidence, num_char = 0, 1.0, 0
metadata_size = metadata.tokens.size()
for i in range(metadata_size):
token = metadata.tokens[i]
letter = alphabet.string_from_label(token)
time_step = metadata.timesteps[i]
# prepare raw output
self.raw_output += letter
# prepare extended output
if token != alphabet.blank_token:
word.append(letter)
confidence *= prob_split[time_step][token]
num_char += 1
if len(word) == 1:
start_step = time_step
if token == alphabet.blank_token or i == metadata_size-1:
duration_step = time_step - start_step
if duration_step < 0:
duration_step = 0
self.extended_output.append({"word": word,
"start_time": frame_to_sec * start_step,
"duration": frame_to_sec * duration_step,
"confidence": confidence**(1.0/num_char)})
# reset
word = ''
start_step, confidence, num_char = 0, 1.0, 0
def to_json(self):
return json.dumps({"raw_output": self.raw_output,
"extended_output": self.extended_output})
def save_json(self, file_path):
with open(file_path, 'w') as outfile:
json.dump({"raw_output": self.raw_output,
"extended_output": self.extended_output},
outfile)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册