model.py 12.3 KB
Newer Older
1 2 3 4
"""Contains DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
X
Xinghai Sun 已提交
5

6 7 8
import sys
import os
import time
9
import logging
10
import gzip
11
from distutils.dir_util import mkpath
12
import paddle.v2 as paddle
Y
Yibing Liu 已提交
13 14 15
from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
16
from model_utils.network import deep_speech_v2_network
17

18 19 20
# Root logging layout: "[LEVEL timestamp file:line] message".
_LOG_FORMAT = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
logging.basicConfig(format=_LOG_FORMAT)

21

22
class DeepSpeech2Model(object):
    """DeepSpeech2Model class.

    :param vocab_size: Decoding vocabulary size.
    :type vocab_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_layer_size: RNN layer size (number of RNN cells).
    :type rnn_layer_size: int
    :param use_gru: Forwarded unchanged to the network builder as
                    ``use_gru``; presumably selects GRU cells instead of
                    simple RNN cells (confirm in model_utils.network).
    :type use_gru: bool
    :param pretrained_model_path: Pretrained model path. If None, will train
                                  from scratch.
    :type pretrained_model_path: basestring|None
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs. Notice
                              that for GRU, weight sharing is not supported.
    :type share_rnn_weights: bool
    """

    def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
                 rnn_layer_size, use_gru, pretrained_model_path,
                 share_rnn_weights):
        # The network must exist before parameters are created, because
        # fresh parameters are derived from the loss layer.
        self._create_network(vocab_size, num_conv_layers, num_rnn_layers,
                             rnn_layer_size, use_gru, share_rnn_weights)
        self._create_parameters(pretrained_model_path)
        # Inference engines and the external scorer are built lazily on
        # first use and then cached on the instance.
        self._inferer = None
        self._loss_inferer = None
        self._ext_scorer = None
        # NOTE: this configures the *root* logger, as the original code did.
        self.logger = logging.getLogger("")
        self.logger.setLevel(level=logging.INFO)

    def train(self,
              train_batch_reader,
              dev_batch_reader,
              feeding_dict,
              learning_rate,
              gradient_clipping,
              num_passes,
              output_model_dir,
              is_local=True,
              num_iterations_print=100,
              test_off=False):
        """Train the model.

        Parameters are saved to ``params.latest.tar.gz`` every
        ``num_iterations_print`` iterations and to
        ``params.pass-<pass_id>.tar.gz`` at the end of every pass, both
        inside ``output_model_dir``.

        :param train_batch_reader: Train data reader.
        :type train_batch_reader: callable
        :param dev_batch_reader: Validation data reader.
        :type dev_batch_reader: callable
        :param feeding_dict: Feeding is a map of field name and tuple index
                             of the data that reader returns.
        :type feeding_dict: dict|list
        :param learning_rate: Learning rate for ADAM optimizer.
        :type learning_rate: float
        :param gradient_clipping: Gradient clipping threshold.
        :type gradient_clipping: float
        :param num_passes: Number of training epochs.
        :type num_passes: int
        :param output_model_dir: Directory for saving the model (every pass).
        :type output_model_dir: basestring
        :param is_local: Set to False if running with pserver with multi-nodes.
        :type is_local: bool
        :param num_iterations_print: Number of training iterations for printing
                                     a training loss.
        :type num_iterations_print: int
        :param test_off: Turn off testing.
        :type test_off: bool
        """
        # prepare model output directory
        if not os.path.exists(output_model_dir):
            mkpath(output_model_dir)

        # prepare optimizer and trainer
        optimizer = paddle.optimizer.Adam(
            learning_rate=learning_rate,
            gradient_clipping_threshold=gradient_clipping)
        trainer = paddle.trainer.SGD(
            cost=self._loss,
            parameters=self._parameters,
            update_equation=optimizer,
            is_local=is_local)

        # Per-run bookkeeping lives in a closure-local dict rather than in
        # module globals: it cannot leak between runs, and it is always
        # initialized even if an iteration event arrives before BeginPass.
        # (A dict is used because ``nonlocal`` does not exist on Python 2.)
        state = {"start_time": time.time(), "cost_sum": 0.0, "cost_counter": 0}

        def save_parameters(filename):
            """Write the current parameters as a gzipped tar archive."""
            output_model_path = os.path.join(output_model_dir, filename)
            with gzip.open(output_model_path, 'w') as f:
                trainer.save_parameter_to_tar(f)

        # create event handler
        def event_handler(event):
            if isinstance(event, paddle.event.EndIteration):
                state["cost_sum"] += event.cost
                state["cost_counter"] += 1
                if (event.batch_id + 1) % num_iterations_print == 0:
                    save_parameters("params.latest.tar.gz")
                    print("\nPass: %d, Batch: %d, TrainCost: %f" %
                          (event.pass_id, event.batch_id + 1,
                           state["cost_sum"] / state["cost_counter"]))
                    # restart the running average for the next window
                    state["cost_sum"], state["cost_counter"] = 0.0, 0
                else:
                    # progress dot for iterations that are not reported
                    sys.stdout.write('.')
                    sys.stdout.flush()
            if isinstance(event, paddle.event.BeginPass):
                state["start_time"] = time.time()
                state["cost_sum"], state["cost_counter"] = 0.0, 0
            if isinstance(event, paddle.event.EndPass):
                elapsed = time.time() - state["start_time"]
                if test_off:
                    print("\n------- Time: %d sec,  Pass: %d" %
                          (elapsed, event.pass_id))
                else:
                    result = trainer.test(
                        reader=dev_batch_reader, feeding=feeding_dict)
                    print(
                        "\n------- Time: %d sec,  Pass: %d, "
                        "ValidationCost: %s" %
                        (elapsed, event.pass_id, result.cost))
                save_parameters("params.pass-%d.tar.gz" % event.pass_id)

        # run train
        trainer.train(
            reader=train_batch_reader,
            event_handler=event_handler,
            num_passes=num_passes,
            feeding=feeding_dict)

    def infer_loss_batch(self, infer_data):
        """Model inference. Infer the ctc loss for a batch of speech
        utterances.

        :param infer_data: List of utterances to infer, with each utterance a
                           tuple of audio features and transcription text (empty
                           string).
        :type infer_data: list
        :return: List of ctc loss.
        :rtype: List of float
        """
        # define inferer (built once, then reused across calls)
        if self._loss_inferer is None:
            self._loss_inferer = paddle.inference.Inference(
                output_layer=self._loss, parameters=self._parameters)
        # run inference
        return self._loss_inferer.infer(input=infer_data)

    def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
                    beam_size, cutoff_prob, cutoff_top_n, vocab_list,
                    language_model_path, num_processes):
        """Model inference. Infer the transcription for a batch of speech
        utterances.

        :param infer_data: List of utterances to infer, with each utterance
                           consisting of a tuple of audio features and
                           transcription text (empty string).
        :type infer_data: list
        :param decoding_method: Decoding method name, 'ctc_greedy' or
                                'ctc_beam_search'.
        :type decoding_method: string
        :param beam_alpha: Parameter associated with language model.
        :type beam_alpha: float
        :param beam_beta: Parameter associated with word count.
        :type beam_beta: float
        :param beam_size: Width for Beam search.
        :type beam_size: int
        :param cutoff_prob: Cutoff probability in pruning,
                            default 1.0, no pruning.
        :type cutoff_prob: float
        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                        characters with highest probs in vocabulary will be
                        used in beam search, default 40.
        :type cutoff_top_n: int
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        :param language_model_path: Filepath for language model.
        :type language_model_path: basestring|None
        :param num_processes: Number of processes (CPU) for decoder.
        :type num_processes: int
        :return: List of transcription texts. Empty list if ``infer_data``
                 is empty.
        :rtype: List of basestring
        :raises ValueError: If ``decoding_method`` is not supported.
        """
        # Reject an unknown decoder up front instead of after paying for a
        # full forward pass through the network.
        if decoding_method not in ("ctc_greedy", "ctc_beam_search"):
            raise ValueError("Decoding method [%s] is not supported." %
                             decoding_method)
        # Nothing to decode; also avoids dividing by zero when splitting.
        if len(infer_data) == 0:
            return []

        # define inferer (built once, then reused across calls)
        if self._inferer is None:
            self._inferer = paddle.inference.Inference(
                output_layer=self._log_probs, parameters=self._parameters)
        # run inference
        infer_results = self._inferer.infer(input=infer_data)
        # The flat result is cut into equally sized consecutive slices, one
        # per utterance.
        num_steps = len(infer_results) // len(infer_data)
        probs_split = [
            infer_results[i * num_steps:(i + 1) * num_steps]
            for i in range(len(infer_data))
        ]

        # run decoder
        if decoding_method == "ctc_greedy":
            # best path decode
            return [
                ctc_greedy_decoder(probs_seq=probs, vocabulary=vocab_list)
                for probs in probs_split
            ]

        # beam search decode
        self._prepare_ext_scorer(beam_alpha, beam_beta, language_model_path,
                                 vocab_list)
        num_processes = min(num_processes, len(probs_split))
        beam_search_results = ctc_beam_search_decoder_batch(
            probs_split=probs_split,
            vocabulary=vocab_list,
            beam_size=beam_size,
            num_processes=num_processes,
            ext_scoring_func=self._ext_scorer,
            cutoff_prob=cutoff_prob,
            cutoff_top_n=cutoff_top_n)
        # keep element [1] of the first entry returned for each utterance
        return [result[0][1] for result in beam_search_results]

    def _prepare_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
                            vocab_list):
        """Create the external scorer on first use, or retune the cached one.

        The scorer is cached on the instance together with the language model
        path it was built from; later calls only reset alpha/beta and must
        pass the same ``language_model_path``.
        """
        if self._ext_scorer is None:
            self._loaded_lm_path = language_model_path
            self.logger.info("begin to initialize the external scorer "
                             "for decoding")
            self._ext_scorer = Scorer(beam_alpha, beam_beta,
                                      language_model_path, vocab_list)

            lm_char_based = self._ext_scorer.is_character_based()
            lm_max_order = self._ext_scorer.get_max_order()
            lm_dict_size = self._ext_scorer.get_dict_size()
            self.logger.info(
                "language model: is_character_based = %d, "
                "max_order = %d, dict_size = %d",
                lm_char_based, lm_max_order, lm_dict_size)
            self.logger.info("end initializing scorer. Start decoding ...")
        else:
            self._ext_scorer.reset_params(beam_alpha, beam_beta)
            assert self._loaded_lm_path == language_model_path, (
                "language_model_path differs from the one the cached "
                "scorer was built with")

    def _create_parameters(self, model_path=None):
        """Load or create model parameters.

        :param model_path: Path of a gzipped parameter tar archive. If None,
                           new parameters are created from the loss layer.
        :type model_path: basestring|None
        """
        if model_path is None:
            self._parameters = paddle.parameters.create(self._loss)
        else:
            # close the archive once the parameters have been read
            with gzip.open(model_path) as f:
                self._parameters = paddle.parameters.Parameters.from_tar(f)

    def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
                        rnn_layer_size, use_gru, share_rnn_weights):
        """Create data layers and model network."""
        # paddle.data_type.dense_array is used for variable batch input.
        # The size 161 * 161 is only a placeholder value and the real shape
        # of input batch data will be induced during training.
        audio_data = paddle.layer.data(
            name="audio_spectrogram",
            type=paddle.data_type.dense_array(161 * 161))
        text_data = paddle.layer.data(
            name="transcript_text",
            type=paddle.data_type.integer_value_sequence(vocab_size))
        self._log_probs, self._loss = deep_speech_v2_network(
            audio_data=audio_data,
            text_data=text_data,
            dict_size=vocab_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_layer_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)