/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "LstmCompute.h"
#include "SequenceToBatch.h"
#include "paddle/math/BaseMatrix.h"
#include "paddle/math/Matrix.h"
namespace paddle {

/**
 * @brief LstmLayer takes 1 input layer with size * 4.
 * The input layer is divided into 4 equal parts:
 *   (input_s, input_ig, input_fg, input_og)
 *
 * For each sequence [start, end] it performs the following computation:
 * @code
 * output_{i} = actState(state_{i}) * actGate(outputGate_{i})
 * state_{i} = actInput(input_s_{i} + bias_s +
 *             output_{i-1} * recurrIW) * actGate(inputGate_{i}) +
 *             actGate(forgetGate_{i}) * state_{i-1}
 * inputGate = input_ig_{i} + bias_ig + output_{i-1} * recurrIGW +
 *             state_{i-1} * inputCheck
 * outputGate = input_og_{i} + bias_og + output_{i-1} * recurrOGW +
 *             state_{i} * outputCheck
 * forgetGate = input_fg_{i} + bias_fg + output_{i-1} * recurrFGW +
 *              state_{i-1} * forgetCheck
 * @endcode
 *
 * - parameter[0] consists of (recurrIW, recurrIGW, recurrFGW, recurrOGW)
 * - biasParameter consists of
 *   (bias_s, bias_ig, bias_og, bias_fg, inputCheck, forgetCheck, outputCheck)
 *
 * - actInput is defined by config active_type.
 * - actState is defined by config active_state_type.
 * - actGate is defined by config active_gate_type.
 *
 * There are two ways to compute: one sequence by one sequence, or
 * one batch by one batch. By default, when pre_batch_state is not set
 * to true, it computes batch by batch.
 *
 * The formula in the paper is as follows:
 * \f[
 * i_t = \sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) \\
 * f_t = \sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) \\
 * \tilde{c_t} = \tanh(W_{xc}x_{t} + W_{hc}h_{t-1} + b_c) \\
 * o_t = \sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_{t} + b_o) \\
 * c_t = f_t * c_{t-1} + i_t * \tilde{c_t} \\
 * h_t = o_t * \tanh(c_t)
 * \f]
 *
 * @note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
 * operations on the input sequence are NOT included in LstmLayer, so
 * users should use fc_layer or mixed_layer before lstm_layer.
 *
 * The weight ([size, 4*size]) contains \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
 * The bias contains \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf}, W_{co}\f$.
 */
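
/*
 * A minimal sketch (illustrative only, not code from this layer) of how an
 * input value of shape (batchSize, 4 * size) would split into the four parts
 * named above. Matrix::subColMatrix is assumed here to return a view over the
 * column range [start, end).
 *
 * @code
 * size_t size = getSize();
 * MatrixPtr inputS  = input->subColMatrix(0 * size, 1 * size);  // input_s
 * MatrixPtr inputIg = input->subColMatrix(1 * size, 2 * size);  // input_ig
 * MatrixPtr inputFg = input->subColMatrix(2 * size, 3 * size);  // input_fg
 * MatrixPtr inputOg = input->subColMatrix(3 * size, 4 * size);  // input_og
 * @endcode
 */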

class LstmLayer : public Layer, public LstmCompute {
 public:
  explicit LstmLayer(const LayerConfig &config) : Layer(config) {}

  bool init(const LayerMap &layerMap,
            const ParameterMap &parameterMap) override;

  void forward(PassType passType) override;

  void backward(const UpdateCallback &callback) override;

  void resetState() override;

  void setState(LayerStatePtr state) override;

  LayerStatePtr getState() override;

 protected:
  /**
   * @brief Compute lstm forward one sequence by one sequence.
   * @param batchSize The batchSize is not equal to the batch_size in
   * the config file. It is the total number of words of all samples
   * in this forward batch.
   * @param numSequences The number of samples. It is equal to the
   * batch_size in the config file.
   * @param starts The start position of each sample.
   * @param inputValue The input values.
   */
  void forwardSequence(int batchSize,
                       size_t numSequences,
                       const int *starts,
                       MatrixPtr inputValue);
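  /*
   * An illustrative sketch (not taken from the implementation) of how the
   * arguments above relate. For three sequences of lengths 2, 3 and 1 packed
   * into one forward pass:
   * @code
   * // numSequences = 3, batchSize = 2 + 3 + 1 = 6
   * // starts = {0, 2, 5, 6};
   * // sequence i occupies rows [starts[i], starts[i+1]) of inputValue
   * @endcode
   */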
  /**
   * Compute lstm backward one sequence by one sequence.
   */
  void backwardSequence(int batchSize,
                        size_t numSequences,
                        const int *starts,
                        MatrixPtr inputGrad);

  /**
   * Compute lstm forward one batch by one batch. The batch values are
   * reorganized by the SequenceToBatch class. The batch output values
   * will be converted back into sequence values after the forward pass
   * finishes. Here, one batch contains one word of each sample. If the
   * samples have different lengths, the batch does not pad zeros; it
   * simply contains fewer words. The total number of batches equals the
   * max sequence length. For details, refer to the SequenceToBatch
   * class. In GPU mode, a GPU kernel is launched for each iteration of
   * the loop.
   *
   * @code
   * for (int i = 0; i < numBatch(max_sequence_length); ++i) {
   *   compute one batch.
   * }
   * @endcode
   */
  void forwardBatch(int batchSize,
                    size_t numSequences,
                    const int *starts,
                    MatrixPtr inputValue);
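  /*
   * An illustrative sketch, assuming SequenceToBatch behaves as described
   * above: sequence A of length 3 and sequence B of length 2 are reorganized
   * into max_sequence_length = 3 batches, which shrink once the shorter
   * sequence runs out.
   * @code
   * // sequence A: a0 a1 a2    sequence B: b0 b1
   * // batch 0: {a0, b0}   batch 1: {a1, b1}   batch 2: {a2}
   * @endcode
   */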
  /**
   * Compute lstm backward one batch by one batch.
   */
  void backwardBatch(int batchSize,
                     size_t numSequences,
                     const int *starts,
                     MatrixPtr inputGrad);

  /**
   * This function only supports GPU. It does not need to reorganize the
   * input into batch values; it launches a single kernel that computes
   * the forward propagation of all sequences in parallel.
   */
  void forwardSeqParallel(int batchSize,
                          size_t numSequences,
                          const int *starts,
                          MatrixPtr inputValue);
  /**
   * Backward propagation corresponding to forwardSeqParallel.
   */
  void backwardSeqParallel(int batchSize,
                           size_t numSequences,
                           const int *starts,
                           MatrixPtr inputGrad);
  /**
   * This function is used in sequence generation to get the output
   * after forwardBatch.
   */
  void getPrevBatchOutput(size_t numSequences);
  /**
   * This function is used in sequence generation to get the state
   * after forwardBatch.
   */
  void getPrevBatchState(size_t numSequences);

 protected:
  /// Learned parameters, shape: (size, 4 * size).
  /// The weight contains \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
  std::unique_ptr<Weight> weight_;
  /// Learned bias parameter, shape: (1, 7 * size).
  /// The bias contains \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf},
  /// W_{co}\f$.
  std::unique_ptr<Weight> bias_;
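  /*
   * A hypothetical sketch of how the (1, 7 * size) bias splits into the seven
   * parts below, assuming the layout follows the order listed in the class
   * comment; the actual slicing is done at initialization.
   * @code
   * // localBias_ : columns [0, 4 * size)         -> b_i, b_f, b_c, b_o
   * // checkIg_   : columns [4 * size, 5 * size)  -> W_{ci} (inputCheck)
   * // checkFg_   : columns [5 * size, 6 * size)  -> W_{cf} (forgetCheck)
   * // checkOg_   : columns [6 * size, 7 * size)  -> W_{co} (outputCheck)
   * @endcode
   */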
  /// The real bias, pointing to \f$b_i, b_f, b_c, b_o\f$.
  MatrixPtr localBias_;
  /// The peephole connection for input gate.
  MatrixPtr checkIg_;
  /// The peephole connection for forget gate.
  MatrixPtr checkFg_;
  /// The peephole connection for output gate.
  MatrixPtr checkOg_;
  /// The gradient of the real bias.
  MatrixPtr localBiasGrad_;
  /// The gradient of peephole connection for input gates.
  MatrixPtr checkIgGrad_;
  /// The gradient of peephole connection for forget gates.
  MatrixPtr checkFgGrad_;
  /// The gradient of peephole connection for output gates.
  MatrixPtr checkOgGrad_;

  /// Stores the cell state of the previous time step, namely \f$c_{t-1}\f$.
  Argument state_;
  /// Stores the hidden state of the previous time step, namely \f$h_{t-1}\f$.
  Argument preOutput_;
  /// Stores the value and gradient of four gates, namely
  /// \f$i_t, f_t, o_t, c_t\f$.
  Argument gate_;
  /// Whether it is a reversed lstm.
  bool reversed_;
  /// Whether to use the batch method to compute.
  bool useBatch_;
  /// Whether to use the sequence-parallel method to compute.
  bool useSeqParallel_;
  /// batchValue_ is used in the batch calculation method. It stores the
  /// batch values of the reorganized input.
  std::unique_ptr<SequenceToBatch> batchValue_;
  /// The gradient of batchValue_.
  std::unique_ptr<SequenceToBatch> batchGrad_;

  /// Used in generation; stores the state of the previous time step.
  MatrixPtr prevState_;
  /// Used in generation; stores the output of the previous time step.
  MatrixPtr prevOutput_;
  MatrixPtr prevBatchOutput2_;
  /// The total state.
  MatrixPtr totalState_;
};

}  // namespace paddle