/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/BaseMatrix.h"
#include "SequenceToBatch.h"
#include "LstmCompute.h"
namespace paddle {

/**
 * @brief LstmLayer takes one input layer, whose size is 4 * size.
 * The input layer is divided into 4 equal parts:
 *   (input_s, input_ig, input_fg, input_og)
 *
 * For each sequence [start, end] it performs the following computation:
 * @code
 * output_{i} = actState(state_{i}) * actGate(outputGate_{i})
 * state_{i} = actInput(input_s_{i} + bias_s +
 *             output_{i-1} * recurrIW) * actGate(inputGate_{i}) +
 *             actGate(forgetGate_{i}) * state_{i-1}
 * inputGate = input_ig_{i} + bias_ig + output_{i-1} * recurrIGW +
 *             state_{i-1} * inputCheck
 * outputGate = input_og_{i} + bias_og + output_{i-1} * recurrOGW +
 *             state_{i} * outputCheck
 * forgetGate = input_fg_{i} + bias_fg + output_{i-1} * recurrFGW +
 *              state_{i-1} * forgetCheck
 * @endcode
 *
 * - parameter[0] consists of (recurrIW, recurrIGW, recurrFGW, recurrOGW)
 * - biasParameter consists of
 *   (bias_s, bias_ig, bias_og, bias_fg, inputCheck, forgetCheck, outputCheck)
 *
 * - actInput is defined by config active_type.
 * - actState is defined by config active_state_type.
 * - actGate is defined by config active_gate_type.
 *
 * There are two ways to compute: one sequence at a time, or one batch at
 * a time. By default, when pre_batch_state is not set true, it computes
 * batch by batch.
 *
 * The formula in the paper is as follows:
 * \f[
 * i_t = \sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) \\
 * f_t = \sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) \\
 * \tilde{c_t} = \tanh(W_{xc}x_{t} + W_{hc}h_{t-1} + b_c) \\
 * o_t = \sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) \\
 * c_t = f_t * c_{t-1} + i_t * \tilde{c_t} \\
 * h_t = o_t * \tanh(c_t)
 * \f]
 *
 * @note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
 * operations on the input sequence are NOT included in LstmLayer, so
 * users should apply fc_layer or mixed_layer before lstm_layer.
 *
 * The weight ([size, 4*size]) contains \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
 * The bias contains \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf}, W_{co}\f$.
 */
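// A minimal layout sketch of the pieces named above (illustrative only;
// the exact column offsets are assumptions inferred from this comment,
// not guaranteed by the header):
//
//   input         (batchSize, 4*size): [input_s | input_ig | input_fg | input_og]
//   parameter[0]  (size, 4*size):      [recurrIW | recurrIGW | recurrFGW | recurrOGW]
//   biasParameter (1, 7*size):         [bias_s | bias_ig | bias_og | bias_fg |
//                                       inputCheck | forgetCheck | outputCheck]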

class LstmLayer : public Layer, public LstmCompute {
public:
  explicit LstmLayer(const LayerConfig &config) : Layer(config) {}

  bool init(const LayerMap &layerMap, const ParameterMap &parameterMap);

  void forward(PassType passType);

  void backward(const UpdateCallback &callback);

  void resetState();

  void setState(LayerStatePtr state);

  LayerStatePtr getState();

protected:
  /**
   * @brief Compute lstm forward one sequence at a time.
   * @param batchSize Not the batch_size in the config file; it is the
   * total number of words of all samples in this forward pass.
   * @param numSequences The number of samples. It is equal to the
   * batch_size in the config file.
   * @param starts The start position of each sequence in the batch.
   * @param inputValue The input values.
   */
  void forwardSequence(int batchSize,
                       size_t numSequences,
                       const int *starts,
                       MatrixPtr inputValue);
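
  // Illustrative example of these arguments (an assumption consistent with
  // the @param descriptions above): for three sequences of lengths 4, 2
  // and 3, starts = {0, 4, 6, 9}, numSequences = 3, and batchSize = 9
  // (the total word count).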
  /**
   * Compute lstm backward one sequence at a time.
   */
  void backwardSequence(int batchSize,
                        size_t numSequences,
                        const int *starts,
                        MatrixPtr inputGrad);

  /**
   * Compute lstm forward one batch at a time. The batch values are
   * reorganized by the SequenceToBatch class, and the batch output values
   * are converted back into sequence values after the forward pass
   * finishes. Here, one batch contains one word of each sample. If the
   * sample lengths are not equal, the batch is not zero-padded and simply
   * contains fewer words. The total number of batches equals the maximum
   * sequence length. See the SequenceToBatch class for details. In GPU
   * mode, a GPU kernel is launched for each loop iteration.
   *
   * @code
   * for (int i = 0; i < numBatch(max_sequence_length); ++i) {
   *   compute one batch.
   * }
   * @endcode
   */
  void forwardBatch(int batchSize,
                    size_t numSequences,
                    const int *starts,
                    MatrixPtr inputValue);
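
  // A hedged sketch of the reorganization described above (the exact word
  // ordering inside a batch is defined by SequenceToBatch): with sequence
  // lengths {3, 2, 1}, batch 0 holds word 0 of all three sequences,
  // batch 1 holds word 1 of the two longer ones, and batch 2 holds word 2
  // of the longest one, so there are max_sequence_length = 3 batches and
  // no zero padding anywhere.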
  /**
   * Compute lstm backward one batch at a time.
   */
  void backwardBatch(int batchSize,
                     size_t numSequences,
                     const int *starts,
                     MatrixPtr inputGrad);

  /**
   * This function is GPU-only. It does not need to reorganize the input
   * into batch values; it launches a single kernel that computes forward
   * propagation in parallel at the sequence level.
   */
  void forwardSeqParallel(int batchSize,
                          size_t numSequences,
                          const int *starts,
                          MatrixPtr inputValue);
  /**
   * Backward propagation corresponding to forwardSeqParallel.
   */
  void backwardSeqParallel(int batchSize,
                           size_t numSequences,
                           const int *starts,
                           MatrixPtr inputGrad);
  /**
   * This function is used in sequence generation to get the output
   * after forwardBatch.
   */
  void getPrevBatchOutput(size_t numSequences);
  /**
   * This function is used in sequence generation to get the state
   * after forwardBatch.
   */
  void getPrevBatchState(size_t numSequences);
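
  // A hedged usage sketch for generation (inferred from the comments above
  // and the members below, not a verbatim call sequence): after each
  // forwardBatch step, getPrevBatchOutput(numSequences) and
  // getPrevBatchState(numSequences) fill prevOutput_ (\f$h_{t-1}\f$) and
  // prevState_ (\f$c_{t-1}\f$) for the next generation step.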

protected:
  /// Learned weight parameter, shape: (size, 4*size). It contains
  /// \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
  std::unique_ptr<Weight> weight_;
  /// Learned bias parameter, shape: (1, 7 * size).
  /// The bias contains \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf},
  /// W_{co}\f$.
  std::unique_ptr<Weight> bias_;
  /// The real bias, pointing to \f$b_i, b_f, b_c, b_o\f$.
  MatrixPtr localBias_;
  /// The peephole connection for input gate.
  MatrixPtr checkIg_;
  /// The peephole connection for forget gate.
  MatrixPtr checkFg_;
  /// The peephole connection for output gate.
  MatrixPtr checkOg_;
  /// The gradient of the real bias.
  MatrixPtr localBiasGrad_;
  /// The gradient of peephole connection for input gates.
  MatrixPtr checkIgGrad_;
  /// The gradient of peephole connection for forget gates.
  MatrixPtr checkFgGrad_;
  /// The gradient of peephole connection for output gates.
  MatrixPtr checkOgGrad_;

  /// Stores the cell state of the previous time step, namely \f$c_{t-1}\f$.
  Argument state_;
  /// Stores the hidden output of the previous time step, namely \f$h_{t-1}\f$.
  Argument preOutput_;
  /// Stores the values and gradients of the four gates, namely
  /// \f$i_t, f_t, o_t, c_t\f$.
  Argument gate_;
  /// Whether this is a reversed LSTM.
  bool reversed_;
  /// Whether to use the batch method for computation.
  bool useBatch_;
  /// Whether to use the sequence-parallel method for computation.
  bool useSeqParallel_;
  /// batchValue_ is used by the batch calculation method. It stores the
  /// batch values of the reorganized input.
  std::unique_ptr<SequenceToBatch> batchValue_;
  /// The gradient of batchValue_.
  std::unique_ptr<SequenceToBatch> batchGrad_;

  /// Used in generation; stores the state of the previous time step.
  MatrixPtr prevState_;
  /// Used in generation; stores the output of the previous time step.
  MatrixPtr prevOutput_;
  MatrixPtr prevBatchOutput2_;
  /// The total state.
  MatrixPtr totalState_;
};

}  // namespace paddle