/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/BaseMatrix.h"
#include "SequenceToBatch.h"
#include "LstmCompute.h"
namespace paddle {

/**
 * @brief LstmLayer takes one input layer with size * 4.
 * The input layer is divided into 4 equal parts:
 *   (input_s, input_ig, input_fg, input_og)
 *
 * For each sequence [start, end] it performs the following computation:
 * @code
 * output_{i} = actState(state_{i}) * actGate(outputGate_{i})
 * state_{i} = actInput(input_s_{i} + bias_s +
 *             output_{i-1} * recurrIW) * actGate(inputGate_{i}) +
 *             actGate(forgetGate_{i}) * state_{i-1}
 * inputGate = input_ig_{i} + bias_ig + output_{i-1} * recurrIGW +
 *             state_{i-1} * inputCheck
 * outputGate = input_og_{i} + bias_og + output_{i-1} * recurrOGW +
 *             state_{i} * outputCheck
 * forgetGate = input_fg_{i} + bias_fg + output_{i-1} * recurrFGW +
 *              state_{i-1} * forgetCheck
 * @endcode
 *
 * - parameter[0] consists of (recurrIW, recurrIGW, recurrFGW, recurrOGW)
 * - biasParameter consists of
 *   (bias_s, bias_ig, bias_og, bias_fg, inputCheck, forgetCheck, outputCheck)
 *
 * - actInput is defined by config active_type.
 * - actState is defined by config active_state_type.
 * - actGate is defined by config active_gate_type.
 *
 * There are two ways to compute: one sequence at a time, or one batch at
 * a time. By default, when pre_batch_state is not set to true, it
 * computes batch by batch.
 *
 * The formulas in the paper are as follows:
 * \f[
 * i_t = \sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) \\
 * f_t = \sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) \\
 * \tilde{c_t} = tanh(W_{xc}x_{t} + W_{hc}h_{t-1} + b_c) \\
 * o_t = \sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) \\
 * c_t = f_t * c_{t-1} + i_t * \tilde{c_t} \\
 * h_t = o_t * tanh(c_t)
 * \f]
 *
 * @note The \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
 * operations on the input sequence are NOT included in LstmLayer, so
 * users should use fc_layer or mixed_layer before lstm_layer.
 *
 * The weight ([size, 4*size]) contains \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
 * The bias contains \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf}, W_{co}\f$.
 */
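/*
 * For illustration only: a minimal, self-contained sketch of the recurrence
 * above for a single time step, written with plain scalars so the data flow
 * is easy to follow. The function and variable names are hypothetical and
 * are not part of this class.
 *
 * @code
 * #include <cmath>
 *
 * struct LstmScalarState {
 *   double c;  // cell state c_{t-1}
 *   double h;  // hidden output h_{t-1}
 * };
 *
 * // in* are the four pre-projected input parts (input_s, input_ig,
 * // input_fg, input_og); w* the recurrent weights; p* the peepholes;
 * // b* the biases. The steps follow the formulas documented above.
 * LstmScalarState lstmStep(LstmScalarState prev,
 *                          double inS, double inIg, double inFg, double inOg,
 *                          double wS, double wIg, double wFg, double wOg,
 *                          double pIg, double pFg, double pOg,
 *                          double bS, double bIg, double bFg, double bOg) {
 *   auto sigma = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
 *   double i = sigma(inIg + bIg + prev.h * wIg + prev.c * pIg);  // input gate
 *   double f = sigma(inFg + bFg + prev.h * wFg + prev.c * pFg);  // forget gate
 *   double cand = std::tanh(inS + bS + prev.h * wS);             // candidate
 *   double c = f * prev.c + i * cand;                            // new cell state
 *   double o = sigma(inOg + bOg + prev.h * wOg + c * pOg);       // output gate
 *   return {c, o * std::tanh(c)};                                // {c_t, h_t}
 * }
 * @endcode
 */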
class LstmLayer : public Layer, public LstmCompute {
public:
  explicit LstmLayer(const LayerConfig &config) : Layer(config) {}

  bool init(const LayerMap &layerMap, const ParameterMap &parameterMap);

  void forward(PassType passType);

  void backward(const UpdateCallback &callback);

  void resetState();

  void setState(LayerStatePtr state);

  LayerStatePtr getState();

protected:
  /**
   * @brief Compute the LSTM forward pass one sequence at a time.
   * @param batchSize Note that this is not the batch_size in the config
   * file; it is the total number of words across all samples in this
   * forward pass.
   * @param numSequences The number of samples. It equals the batch_size
   * in the config file.
   * @param starts The start position of each sample.
   * @param inputValue The input values.
   */
  void forwardSequence(int batchSize, size_t numSequences, const int *starts,
                       MatrixPtr inputValue);
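  // Illustration (assumed layout, consistent with the @param notes above):
  // for two sequences of lengths 3 and 2 in one forward pass,
  // batchSize == 5, numSequences == 2, and starts == {0, 3, 5}
  // (numSequences + 1 entries, sequence i spanning [starts[i], starts[i + 1])).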
  /**
   * Compute the LSTM backward pass one sequence at a time.
   */
  void backwardSequence(int batchSize, size_t numSequences, const int *starts,
                        MatrixPtr inputGrad);

  /**
   * Compute the LSTM forward pass one batch at a time. The batch values
   * are reorganized by the SequenceToBatch class, and the batch output
   * values are converted back into sequence values after the forward pass
   * finishes. Here, one batch contains one word from each sample. If the
   * samples have unequal lengths, a batch is not zero-padded; it simply
   * contains fewer words. The total number of batches equals the maximum
   * sequence length. See the SequenceToBatch class for details. In GPU
   * mode, a GPU kernel is launched for each iteration of the loop:
   *
   * @code
   * for (int i = 0; i < numBatch(max_sequence_length); ++i) {
   *   compute one batch.
   * }
   * @endcode
   */
  void forwardBatch(int batchSize, size_t numSequences, const int *starts,
                    MatrixPtr inputValue);
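  // Illustration (assumed reorganization; see SequenceToBatch for the
  // authoritative behavior): with sequences A = {a0, a1, a2} and
  // B = {b0, b1}, the batches are
  //   batch 0: {a0, b0}
  //   batch 1: {a1, b1}
  //   batch 2: {a2}        // B is exhausted; no zero padding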
  /**
   * Compute the LSTM backward pass one batch at a time.
   */
  void backwardBatch(int batchSize, size_t numSequences, const int *starts,
                     MatrixPtr inputGrad);

  /**
   * This function is GPU-only. It does not need to reorganize the input
   * into batch values; it launches a single kernel that computes the
   * forward propagation of all sequences in parallel.
   */
  void forwardSeqParallel(int batchSize, size_t numSequences, const int *starts,
                          MatrixPtr inputValue);
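  // Conceptually (an assumed sketch of the parallel scheme, not the actual
  // kernel), the single kernel covers the loop nest
  //
  //   for each sequence s in [0, numSequences)   // parallel across threads
  //     for each time step t in sequence s       // sequential recurrence
  //       compute one LSTM step for (s, t)
  //
  // so independent sequences advance concurrently while each sequence's
  // recurrence stays ordered.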
  /**
   * Backward propagation corresponding to forwardSeqParallel.
   */
  void backwardSeqParallel(int batchSize, size_t numSequences,
                           const int *starts, MatrixPtr inputGrad);
  /**
   * This function is used during sequence generation to get the output
   * after forwardBatch.
   */
  void getPrevBatchOutput(size_t numSequences);
  /**
   * This function is used during sequence generation to get the state
   * after forwardBatch.
   */
  void getPrevBatchState(size_t numSequences);

protected:
  /// Learned parameters with shape (size, 4 * size), containing
  /// \f$W_{hi}, W_{hf}, W_{hc}, W_{ho}\f$.
  std::unique_ptr<Weight> weight_;
  /// Learned bias parameter with shape (1, 7 * size). The bias contains
  /// \f$b_i, b_f, b_c, b_o\f$ and \f$W_{ci}, W_{cf}, W_{co}\f$.
  std::unique_ptr<Weight> bias_;
  /// The real bias, pointing to \f$b_i, b_f, b_c, b_o\f$.
  MatrixPtr localBias_;
  /// The peephole connection for the input gate.
  MatrixPtr checkIg_;
  /// The peephole connection for the forget gate.
  MatrixPtr checkFg_;
  /// The peephole connection for the output gate.
  MatrixPtr checkOg_;
  /// The gradient of the real bias.
  MatrixPtr localBiasGrad_;
  /// The gradient of the peephole connection for the input gate.
  MatrixPtr checkIgGrad_;
  /// The gradient of the peephole connection for the forget gate.
  MatrixPtr checkFgGrad_;
  /// The gradient of the peephole connection for the output gate.
  MatrixPtr checkOgGrad_;
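  // Assumed layout (illustration only, following the biasParameter order
  // documented in the class comment): the (1, 7 * size) bias is partitioned
  // into seven size-wide segments,
  //
  //   [ bias_s | bias_ig | bias_og | bias_fg |
  //     inputCheck | forgetCheck | outputCheck ]
  //
  // with localBias_ viewing the first 4 * size entries and checkIg_,
  // checkFg_, checkOg_ viewing the three trailing peephole segments.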

  /// Stores the cell state of the previous time step, namely \f$c_{t-1}\f$.
  Argument state_;
  /// Stores the hidden output of the previous time step, namely \f$h_{t-1}\f$.
  Argument preOutput_;
  /// Stores the values and gradients of the four gates, namely
  /// \f$i_t, f_t, o_t, \tilde{c_t}\f$.
  Argument gate_;
  /// Whether this is a reversed LSTM.
  bool reversed_;
  /// Whether to use the batch method for computation.
  bool useBatch_;
  /// Whether to use the sequence-parallel method for computation.
  bool useSeqParallel_;
  /// batchValue_ is used by the batch calculation method. It stores the
  /// batch values of the reorganized input.
  std::unique_ptr<SequenceToBatch> batchValue_;
  /// The gradient of batchValue_.
  std::unique_ptr<SequenceToBatch> batchGrad_;

  /// Used in generation; stores the state of the previous time step.
  MatrixPtr prevState_;
  /// Used in generation; stores the output of the previous time step.
  MatrixPtr prevOutput_;
  MatrixPtr prevBatchOutput2_;
  /// The total state.
  MatrixPtr totalState_;
};

}  // namespace paddle