/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/lstm_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {

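// Forward LSTM operator: validates the presence of inputs/outputs, infers
// the output shapes from Input, Weight and Bias, and selects a kernel that
// matches the data type of Input.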
class LSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");

    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
                   "BatchCellPreAct", "LSTM");

    auto in_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(
        in_dims.size(), 2,
        platform::errors::InvalidArgument(
            "Input(Input)'s rank must be 2, but received %d.",
            in_dims.size()));

    if (ctx->HasInput("H0")) {
      PADDLE_ENFORCE_EQ(
          ctx->HasInput("C0"), true,
          platform::errors::NotFound("Input(C0) of LSTM should not be null "
                                     "when Input(H0) is provided."));
      auto h_dims = ctx->GetInputDim("H0");
      auto c_dims = ctx->GetInputDim("C0");
      PADDLE_ENFORCE_EQ(h_dims, c_dims,
                        platform::errors::InvalidArgument(
                            "The dimension of Input(H0) and Input(C0) should "
                            "be the same, but received [%s] (H0) vs [%s] (C0).",
                            h_dims, c_dims));
    }

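    // Input packs the four gate inputs along its second dimension
    // (shape T x 4D), so the hidden size D is one quarter of that width.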
    int frame_size = in_dims[1] / 4;
    auto w_dims = ctx->GetInputDim("Weight");
    PADDLE_ENFORCE_EQ(
        w_dims.size(), 2,
        platform::errors::InvalidArgument(
            "The rank of Input(Weight) should be 2, but received %d.",
            w_dims.size()));
    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Weight) should be %d, "
                          "but received %d.",
                          frame_size, w_dims[0]));
    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
                      platform::errors::InvalidArgument(
                          "The second dimension of Input(Weight) should be 4 * "
                          "%d, but received %d.",
                          frame_size, w_dims[1]));

    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(
        b_dims.size(), 2,
        platform::errors::InvalidArgument(
            "The rank of Input(Bias) should be 2, but received %d.",
            b_dims.size()));
    PADDLE_ENFORCE_EQ(
        b_dims[0], 1,
        platform::errors::InvalidArgument(
            "The first dimension of Input(Bias) should be 1, but received %d.",
            b_dims[0]));

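    // With peepholes enabled, Bias also stores the peephole weights
    // W_ic, W_fc and W_oc, hence width 7D instead of 4D.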
    if (ctx->Attrs().Get<bool>("use_peepholes")) {
      PADDLE_ENFORCE_EQ(
          b_dims[1], 7 * frame_size,
          platform::errors::InvalidArgument(
              "The second dimension of Input(Bias) should be 7 * %d if enable "
              "peepholes connection, but received %d.",
              frame_size, b_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          b_dims[1], 4 * frame_size,
          platform::errors::InvalidArgument(
              "The second dimension of Input(Bias) should be 4 * %d if disable "
              "peepholes connection, but received %d.",
              frame_size, b_dims[1]));
    }

    framework::DDim out_dims({in_dims[0], frame_size});
    ctx->SetOutputDim("Hidden", out_dims);
    ctx->SetOutputDim("Cell", out_dims);
    ctx->SetOutputDim("BatchGate", in_dims);
    ctx->SetOutputDim("BatchCellPreAct", out_dims);
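    // Hidden and Cell carry the same sequence information (LoD) as Input.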
    ctx->ShareLoD("Input", "Hidden");
    ctx->ShareLoD("Input", "Cell");
  }

 protected:
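  // The kernel data type follows Input; float and double CPU kernels are
  // registered at the bottom of this file.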
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input",
             "(LoDTensor) the first input is a LodTensor, which support "
             "variable-time length input sequence. The underlying tensor in "
D
dangqingqing 已提交
127
             "this LoDTensor is a matrix with shape (T X 4D), where T is the "
D
dangqingqing 已提交
128 129 130 131
             "total time steps in this mini-batch, D is the hidden size.");
    AddInput("H0",
             "(Tensor, optional) the initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
K
kexinzhao 已提交
132
             "batch size and D is the hidden size.")
        .AsDispensable();
    AddInput("C0",
             "(Tensor, optional) the initial cell state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
             "batch size. `H0` and `C0` must be provided or omitted "
             "together.")
        .AsDispensable();
    AddInput("Weight",
             "(Tensor) the learnable hidden-hidden weights."
             " - The shape is (D x 4D), where D is the hidden size. "
             " - Weight = {W_ch, W_ih, W_fh, W_oh}");
    AddInput("Bias",
             "(Tensor) the learnable biases, which contain two parts: the "
             "input-hidden bias and, if `use_peepholes` is set to True, the "
             "peephole connection weights. "
             "1. `use_peepholes = False` "
             " - The shape is (1 x 4D). "
             " - Bias = {b_c, b_i, b_f, b_o}."
             "2. `use_peepholes = True` "
             " - The shape is (1 x 7D). "
             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
    AddOutput("Hidden",
              "(LoDTensor) the hidden state of LSTM operator. "
              "The shape is (T x D), and the LoD is the same as that of "
              "`Input`.");
    AddOutput("Cell",
              "(LoDTensor) the cell state of LSTM operator. "
              "The shape is (T x D), and the LoD is the same as that of "
              "`Input`.");
    AddOutput("BatchGate",
              "(LoDTensor) This LoDTensor contains the input gate, forget "
              "gate and output gate after the nonlinear computation. This "
              "LoDTensor has the same shape as the reorganized input, which "
              "is also called the batch input. The LoD size is 2. The first "
              "LoD contains the batch offsets and the second LoD contains "
              "the indexes, which denote the positions of the reorganized "
              "sequences in the raw input.")
        .AsIntermediate();
    AddOutput("BatchCellPreAct",
              "(LoDTensor) This LoDTensor is obtained in the forward pass "
              "and used in the backward pass.")
        .AsIntermediate();
    AddAttr<bool>("use_peepholes",
                  "(bool, default: True) "
                  "whether to enable diagonal/peephole connections.")
        .SetDefault(true);
    AddAttr<bool>("is_reverse",
                  "(bool, default: False) "
                  "whether to compute reversed LSTM.")
        .SetDefault(false);
    AddAttr<std::string>(
        "gate_activation",
        "(string, default: sigmoid) "
        "The activation for input gate, forget gate and output "
        "gate, `sigmoid` by default.")
        .SetDefault("sigmoid")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddAttr<std::string>("cell_activation",
                         "(string, default: tanh) "
                         "The activation for cell output, `tanh` by default.")
        .SetDefault("tanh")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddAttr<std::string>("candidate_activation",
                         "(string, default: tanh) "
                         "The activation for candidate hidden state, "
                         "`tanh` by default.")
        .SetDefault("tanh")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddComment(R"DOC(
Long Short-Term Memory (LSTM) Operator.

The default implementation uses the diagonal/peephole connection
(https://arxiv.org/pdf/1402.1128.pdf); the formulas are as follows:
$$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) $$
$$ f_t = \\sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) $$
$$ \\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) $$
$$ o_t = \\sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) $$
$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$
$$ h_t = o_t \\odot act_h(c_t) $$
- The W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
  of weights from the input to the input gate); $W_{ic}, W_{fc}, W_{oc}$
  are diagonal weight matrices for the peephole connections. In our
  implementation, we use vectors to represent these diagonal weight matrices.
- The b terms denote bias vectors ($b_i$ is the input gate bias vector).
- $\sigma$ is a nonlinear activation function, such as the logistic sigmoid.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
  and cell activation vectors, respectively, all of which have the same size as
  the cell output activation vector $h$.
- $\odot$ denotes the element-wise product of vectors.
- $act_g$ and $act_h$ are the cell input and cell output activation functions
  and `tanh` is usually used for them.
- $\tilde{c_t}$ is also called the candidate hidden state,
  which is computed from the current input and the previous hidden state.

Set `use_peepholes` to False to disable the peephole connections. The
corresponding formulas are omitted here; please refer to the paper
http://www.bioinf.jku.at/publications/older/2604.pdf for details.

Note that the $W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}$
operations on the input $x_{t}$ are NOT included in this operator.
Users can choose to apply a fully-connected operator before the LSTM operator.

)DOC");
  }
};

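// Backward LSTM operator: checks that the forward tensors needed by the
// backward pass are present, and shapes each requested gradient output
// like its forward counterpart.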
class LSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");

    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
                   "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
                   "LSTM@Grad");
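    // A gradient output, when requested, takes the shape of the
    // corresponding forward input.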
    auto SetOutGradDim = [&ctx](const std::string& name) {
      auto g_name = framework::GradVarName(name);
      if (ctx->HasOutput(g_name))
        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
    };

    SetOutGradDim("Input");
    SetOutGradDim("Weight");
    SetOutGradDim("Bias");
    SetOutGradDim("H0");
    SetOutGradDim("C0");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

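// Generates the description of the lstm_grad op: forwards the inputs, the
// intermediate outputs, and the gradient of Hidden to the backward op.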
template <typename T>
class LSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("lstm_grad");
    op->SetAttrMap(this->Attrs());
    op->SetInput("Input", this->Input("Input"));
    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    if (this->HasInput("H0")) {
      op->SetInput("H0", this->Input("H0"));
      op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
    }

    if (this->HasInput("C0")) {
      op->SetInput("C0", this->Input("C0"));
      op->SetOutput(framework::GradVarName("C0"), this->InputGrad("C0"));
    }

    op->SetInput("Weight", this->Input("Weight"));
    op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    op->SetInput("Cell", this->Output("Cell"));
    op->SetInput("Hidden", this->Output("Hidden"));
    op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden"));
    op->SetInput("BatchGate", this->Output("BatchGate"));
    op->SetInput("BatchCellPreAct", this->Output("BatchCellPreAct"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lstm, ops::LSTMOp, ops::LSTMOpMaker,
                  ops::LSTMGradOpMaker<paddle::framework::OpDesc>,
                  ops::LSTMGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(
    lstm, ops::LSTMKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LSTMKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    lstm_grad, ops::LSTMGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LSTMGradKernel<paddle::platform::CPUDeviceContext, double>);