/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/lstm_op.h"

#include <memory>
#include <string>

namespace paddle {
namespace operators {

class LSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
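    // Shape conventions checked below: Input is (T x 4D), Weight is (D x 4D),
    // Bias is (1 x 4D), or (1 x 7D) with peepholes; H0/C0 are (N x D) and
    // Hidden/Cell are (T x D), where T is the total time steps, N the batch
    // size and D the hidden (frame) size.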
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");

    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");

    bool is_test = ctx->Attrs().Get<bool>("is_test");

    if (!is_test) {
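      // BatchGate and BatchCellPreAct are intermediate results kept for the
      // backward pass, so they are only required when not in inference mode.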
      OP_INOUT_CHECK(
          ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
      OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"),
                     "Output",
                     "BatchCellPreAct",
                     "LSTM");
    }
    auto in_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(
        in_dims.size(),
        2,
        platform::errors::InvalidArgument(
            "Input(X)'s rank must be 2, but received %d.", in_dims.size()));

    if (ctx->HasInput("H0")) {
      PADDLE_ENFORCE_EQ(
          ctx->HasInput("C0"),
          true,
          platform::errors::NotFound("Input(C0) of LSTM should not be null "
                                     "when Input(H0) is provided."));
      auto h_dims = ctx->GetInputDim("H0");
      auto c_dims = ctx->GetInputDim("C0");
      PADDLE_ENFORCE_EQ(h_dims,
                        c_dims,
                        platform::errors::InvalidArgument(
                            "The dimension of Input(H0) and Input(C0) should "
                            "be the same, but received [%s] (H0) vs [%s] (C0).",
                            h_dims,
                            c_dims));
    }

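    // The input packs the pre-activations of the four gates along its second
    // axis, so the hidden (frame) size is a quarter of the input width.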
    int frame_size = in_dims[1] / 4;
    auto w_dims = ctx->GetInputDim("Weight");
    PADDLE_ENFORCE_EQ(
        w_dims.size(),
        2,
        platform::errors::InvalidArgument(
            "The rank of Input(Weight) should be 2, but received %d.",
            w_dims.size()));
    PADDLE_ENFORCE_EQ(w_dims[0],
                      frame_size,
                      platform::errors::InvalidArgument(
                          "The first dimension of Input(Weight) should be %d, "
                          "but received %d.",
                          frame_size,
                          w_dims[0]));
    PADDLE_ENFORCE_EQ(w_dims[1],
                      4 * frame_size,
                      platform::errors::InvalidArgument(
                          "The second dimension of Input(Weight) should be 4 * "
                          "%d, but received %d.",
                          frame_size,
                          w_dims[1]));

    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(
        b_dims.size(),
        2,
        platform::errors::InvalidArgument(
            "The rank of Input(Bias) should be 2, but received %d.",
            b_dims.size()));
    PADDLE_ENFORCE_EQ(
        b_dims[0],
        1,
        platform::errors::InvalidArgument(
            "The first dimension of Input(Bias) should be 1, but received %d.",
            b_dims[0]));

    if (ctx->Attrs().Get<bool>("use_peepholes")) {
      PADDLE_ENFORCE_EQ(
          b_dims[1],
          7 * frame_size,
          platform::errors::InvalidArgument(
              "The second dimension of Input(Bias) should be 7 * %d if the "
              "peephole connection is enabled, but received %d.",
              frame_size,
              b_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          b_dims[1],
          4 * frame_size,
          platform::errors::InvalidArgument(
              "The second dimension of Input(Bias) should be 4 * %d if the "
              "peephole connection is disabled, but received %d.",
              frame_size,
              b_dims[1]));
    }

    framework::DDim out_dims({in_dims[0], frame_size});
    ctx->SetOutputDim("Hidden", out_dims);
    ctx->SetOutputDim("Cell", out_dims);
    if (!is_test) {
      ctx->SetOutputDim("BatchGate", in_dims);
      ctx->SetOutputDim("BatchCellPreAct", out_dims);
    }
    ctx->ShareLoD("Input", "Hidden");
    ctx->ShareLoD("Input", "Cell");
  }

 protected:
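  // Select the kernel by the data type of Input and the current place.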
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
                          ctx.device_context().GetPlace());
  }
};

class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Input",
        "(phi::DenseTensor) the first input is a phi::DenseTensor, which "
        "supports variable-length input sequences. The underlying tensor in "
        "this phi::DenseTensor is a matrix with shape (T x 4D), where T is "
        "the total time steps in this mini-batch and D is the hidden size.");
    AddInput("H0",
             "(Tensor, optional) the initial hidden state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
K
kexinzhao 已提交
157
             "batch size and D is the hidden size.")
158
        .AsDispensable();
D
dangqingqing 已提交
159 160 161
    AddInput("C0",
             "(Tensor, optional) the initial cell state is an optional "
             "input. This is a tensor with shape (N x D), where N is the "
Y
Yibing Liu 已提交
162
             "batch size. `H0` and `C0` can be NULL but only at the same time.")
163
        .AsDispensable();
D
dangqingqing 已提交
164 165
    AddInput("Weight",
             "(Tensor) the learnable hidden-hidden weights."
D
dangqingqing 已提交
166 167
             " - The shape is (D x 4D), where D is the hidden size. "
             " - Weight = {W_ch, W_ih, W_fh, W_oh}");
D
dangqingqing 已提交
168 169 170
    AddInput("Bias",
             "(Tensor) the learnable weights, which contains two parts: "
             "input-hidden bias weight and peephole connections weight if "
171 172
             "setting `use_peepholes` True. "
             "1. `use_peepholes = False` "
D
dangqingqing 已提交
173 174
             " - The shape is (1 x 4D). "
             " - Bias = {b_c, b_i, b_f, b_o}."
175
             "2. `use_peepholes = True` "
D
dangqingqing 已提交
176
             " - The shape is (1 x 7D). "
177
             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
    AddOutput("Hidden",
              "(phi::DenseTensor) the hidden state of LSTM operator. "
              "The shape is (T x D), and its LoD is the same as that of "
              "`Input`.");
    AddOutput("Cell",
              "(phi::DenseTensor) the cell state of LSTM operator. "
              "The shape is (T x D), and its LoD is the same as that of "
              "`Input`.");
    AddOutput(
        "BatchGate",
        "(phi::DenseTensor) This phi::DenseTensor contains the input gate, "
        "forget gate and output gate after the nonlinear computation. It "
        "has the same shape as the reorganized input, which is also called "
        "the batch input. The LoD size is 2: the first level stores the "
        "batch offsets and the second level stores the indexes, which "
        "denote the position of each reorganized sequence in the raw "
        "input.")
        .AsIntermediate()
        .AsExtra();
    AddOutput("BatchCellPreAct",
              "(phi::DenseTensor) This phi::DenseTensor is obtained in the "
              "forward pass and used in the backward pass.")
        .AsIntermediate()
        .AsExtra();
    AddAttr<bool>("use_peepholes",
                  "(bool, default: True) "
                  "whether to enable diagonal/peephole connections.")
        .SetDefault(true);
    AddAttr<bool>("is_reverse",
                  "(bool, default: False) "
                  "whether to compute reversed LSTM.")
        .SetDefault(false);
    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
    AddAttr<std::string>(
        "gate_activation",
        "(string, default: sigmoid) "
        "The activation for input gate, forget gate and output "
        "gate, `sigmoid` by default.")
        .SetDefault("sigmoid")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddAttr<std::string>("cell_activation",
                         "(string, default: tanh) "
                         "The activation for cell output, `tanh` by default.")
        .SetDefault("tanh")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddAttr<std::string>("candidate_activation",
                         "(string, default: tanh) "
                         "The activation for candidate hidden state, "
                         "`tanh` by default.")
        .SetDefault("tanh")
        .InEnum({"sigmoid", "tanh", "relu", "identity"});
    AddComment(R"DOC(
Long Short-Term Memory (LSTM) Operator.

The default implementation uses the diagonal/peephole connection
(https://arxiv.org/pdf/1402.1128.pdf); the formulas are as follows:

$$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) $$

$$ f_t = \\sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) $$

$$ \\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) $$

$$ o_t = \\sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) $$

$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$

$$ h_t = o_t \\odot act_h(c_t) $$

- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
  of weights from the input to the input gate), $W_{ic}, W_{fc}, W_{oc}$
  are diagonal weight matrices for peephole connections. In our implementation,
  we use vectors to represent these diagonal weight matrices.
- The b terms denote bias vectors ($b_i$ is the input gate bias vector).
- $\sigma$ is a non-linear activation, such as the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
  and cell activation vectors, respectively, all of which have the same size as
  the cell output activation vector $h$.
- $\odot$ denotes the element-wise product of the vectors.
- $act_g$ and $act_h$ are the cell input and cell output activation functions;
  `tanh` is usually used for them.
- $\tilde{c_t}$ is also called the candidate hidden state,
  which is computed based on the current input and the previous hidden state.

Set `use_peepholes` to False to disable the peephole connection. The formula
is omitted here; please refer to the paper
http://www.bioinf.jku.at/publications/older/2604.pdf for details.

Note that the $W_{ix}x_{t}, W_{fx}x_{t}, W_{cx}x_{t}, W_{ox}x_{t}$
operations on the input $x_{t}$ are NOT included in this operator.
Users can choose to use a fully connected operator before the LSTM operator.

)DOC");
  }
};

class LSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");

    OP_INOUT_CHECK(
        ctx->HasInput("BatchGate"), "Input", "BatchGate", "LSTM@Grad");
    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"),
                   "Input",
                   "BatchCellPreAct",
                   "LSTM@Grad");

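    // Gradient outputs are optional; set dims only for those actually
    // requested, each matching the shape of its corresponding forward input.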
    auto SetOutGradDim = [&ctx](const std::string& name) {
      auto g_name = framework::GradVarName(name);
      if (ctx->HasOutput(g_name))
        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
    };

    SetOutGradDim("Input");
    SetOutGradDim("Weight");
    SetOutGradDim("Bias");
    SetOutGradDim("H0");
    SetOutGradDim("C0");
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
                          ctx.device_context().GetPlace());
  }
};

template <typename T>
class LSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
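    // Wire up the grad op: it consumes the forward inputs, the forward
    // outputs (including the saved BatchGate/BatchCellPreAct intermediates)
    // and the gradient of Hidden, and emits a gradient for each input.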
    op->SetType("lstm_grad");
    op->SetAttrMap(this->Attrs());
    op->SetInput("Input", this->Input("Input"));
    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

    if (this->HasInput("H0")) {
      op->SetInput("H0", this->Input("H0"));
      op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
    }

    if (this->HasInput("C0")) {
      op->SetInput("C0", this->Input("C0"));
      op->SetOutput(framework::GradVarName("C0"), this->InputGrad("C0"));
    }

    op->SetInput("Weight", this->Input("Weight"));
    op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));

    op->SetInput("Bias", this->Input("Bias"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));

    op->SetInput("Cell", this->Output("Cell"));

    op->SetInput("Hidden", this->Output("Hidden"));
    op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden"));

    op->SetInput("BatchGate", this->Output("BatchGate"));
    op->SetInput("BatchCellPreAct", this->Output("BatchCellPreAct"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lstm,
                  ops::LSTMOp,
                  ops::LSTMOpMaker,
                  ops::LSTMGradOpMaker<paddle::framework::OpDesc>,
                  ops::LSTMGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lstm_grad, ops::LSTMGradOp);
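// CPU kernels for the forward and backward ops, registered for float and
// double.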

PD_REGISTER_STRUCT_KERNEL(
    lstm, CPU, ALL_LAYOUT, ops::LSTMKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(
    lstm_grad, CPU, ALL_LAYOUT, ops::LSTMGradKernel, float, double) {}