/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/gru_unit_op.h"
H
hong 已提交
16
#include <memory>
G
guosheng 已提交
17 18 19 20 21 22 23 24 25 26

namespace paddle {
namespace operators {

using framework::Tensor;

class GRUUnitOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

27
  void InferShape(framework::InferShapeContext* ctx) const override {
28 29 30 31 32 33 34 35
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnit");
    OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
                   "GRUUnit");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnit");
    OP_INOUT_CHECK(ctx->HasOutput("Gate"), "Output", "Gate", "GRUUnit");
    OP_INOUT_CHECK(ctx->HasOutput("ResetHiddenPrev"), "Output",
                   "ResetHiddenPrev", "GRUUnit");
    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRUUnit");
36 37 38
    auto input_dims = ctx->GetInputDim("Input");
    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
    auto weight_dims = ctx->GetInputDim("Weight");
G
guosheng 已提交
39 40 41 42 43
    int batch_size = input_dims[0];
    int input_size = input_dims[1];
    int frame_size = hidden_prev_dims[1];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
44 45 46 47 48 49 50 51
    if (ctx->IsRuntime() || input_size >= 0) {
      PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
                        platform::errors::InvalidArgument(
                            "The second dimension of Input(Input) must be 3 "
                            "times of frame_size in GRUUnitOp, but received %d "
                            "(Input) vs %d (frame_size).",
                            input_size, frame_size));
    }
G
guosheng 已提交
52 53
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
54 55 56 57 58
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
            "(frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));
G
guosheng 已提交
59 60
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
61 62 63 64 65 66
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3] in GRUUnitOp, but received [%d, %d] (Weight) vs [%d, %d] "
            "(frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));

Y
Yang Yang(Tony) 已提交
67
    if (ctx->HasInput("Bias")) {
G
guosheng 已提交
68 69 70
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
71 72 73 74 75 76 77 78 79 80 81 82
      PADDLE_ENFORCE_EQ(
          bias_height, 1,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
      PADDLE_ENFORCE_EQ(
          bias_width, frame_size * 3,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
G
guosheng 已提交
83
    }
84 85 86
    ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
    ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
    ctx->SetOutputDim("Hidden", {batch_size, frame_size});
G
guosheng 已提交
87 88 89 90 91
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation of
// the gru_unit operator.
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers every input/output slot and attribute of the operator together
  // with its description. Bias is the only dispensable input; Gate and
  // ResetHiddenPrev are marked as intermediate outputs.
  void Make() override {
    AddInput("Input",
             "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
             "input.");
    AddInput("HiddenPrev",
             "(Tensor) Matrix with shape [batch_size, frame_size] for the "
             "states of previous time step.");
    AddInput(
        "Weight",
        "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
        "The elements continuous in memory can be divided into two parts. "
        "The first part are weights of the update gate and reset gate "
        "with shape [frame_size, frame_size * 2], and the second part are "
        "weights of output candidate with shape [frame_size, frame_size].");
    AddInput(
        "Bias",
        "(Tensor) Bias vector with shape [1, frame_size * 3] concatenating "
        "bias of the update gate, reset gate and output candidate.")
        .AsDispensable();
    AddOutput("Gate",
              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
              "output of update gate, reset gate and output candidate.")
        .AsIntermediate();
    AddOutput("ResetHiddenPrev",
              "(Tensor) Matrix with shape [batch_size, frame_size] for the "
              "reset hidden state of previous time step.")
        .AsIntermediate();
    AddOutput("Hidden",
              "(Tensor) The GRU hidden state of the current time step "
              "with shape [batch_size, frame_size].");
    AddAttr<int>("activation",
                 "(enum int, default tanh) "
                 "The activation type used for output candidate {h}_t.")
        .SetDefault(tanh)
        .InEnum({identity, sigmoid, tanh, relu});
    AddAttr<int>("gate_activation",
                 "(enum int, default sigmoid) "
                 "The activation type used in update gate and reset gate.")
        .SetDefault(sigmoid)
        .InEnum({identity, sigmoid, tanh, relu});
    // The type tag needs its own separator: adjacent string literals are
    // concatenated verbatim, so a bare "bool" would fuse with the next word.
    AddAttr<bool>("origin_mode",
                  "(bool, default false) "
                  "use origin mode in article <Learning Phrase Representations "
                  "using RNN Encoder–Decoder\n"
                  "for Statistical Machine "
                  "Translation>(https://arxiv.org/pdf/1406.1078.pdf)")
        .SetDefault(false);
    AddComment(R"DOC(
GRUUnit Operator implements partial calculations of the GRU unit as following:

$$
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r)  \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$

which is same as one time step of GRU Operator.

@note To implement the complete GRU unit, fully-connected operator must be
used before to feed xu, xr and xc as the Input of GRUUnit operator.

)DOC");
  }
};

class GRUUnitGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

162
  void InferShape(framework::InferShapeContext* ctx) const override {
163 164 165 166 167 168 169 170 171 172
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRUUnitGrad");
    OP_INOUT_CHECK(ctx->HasInput("HiddenPrev"), "Input", "HiddenPrev",
                   "GRUUnitGrad");
    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRUUnitGrad");
    OP_INOUT_CHECK(ctx->HasInput("Gate"), "Input", "Gate", "GRUUnitGrad");
    OP_INOUT_CHECK(ctx->HasInput("ResetHiddenPrev"), "Input", "ResetHiddenPrev",
                   "GRUUnitGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Hidden")), "Input",
                   "Hidden@GRAD", "GRUUnitGrad");

173 174 175
    auto input_dims = ctx->GetInputDim("Input");
    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
    auto weight_dims = ctx->GetInputDim("Weight");
G
guosheng 已提交
176 177 178 179 180
    // int batch_size = input_dims[0];
    int input_size = input_dims[1];
    int frame_size = hidden_prev_dims[1];
    int weight_height = weight_dims[0];
    int weight_width = weight_dims[1];
181 182 183 184 185 186 187 188 189
    if (ctx->IsRuntime() || input_size >= 0) {
      PADDLE_ENFORCE_EQ(
          input_size, frame_size * 3,
          platform::errors::InvalidArgument(
              "The second dimension of Input(Input) must be 3 "
              "times of frame_size in GRUUnitGradOp, but received %d "
              "(Input) vs %d (frame_size).",
              input_size, frame_size));
    }
G
guosheng 已提交
190 191
    PADDLE_ENFORCE_EQ(
        weight_height, frame_size,
192 193 194 195 196
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
            "(frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));
G
guosheng 已提交
197 198
    PADDLE_ENFORCE_EQ(
        weight_width, frame_size * 3,
199 200 201 202 203
        platform::errors::InvalidArgument(
            "The shape of Input(Weight) matrix must be [frame_size, frame_size "
            "* 3] in GRUUnitGradOp, but received [%d, %d] (Weight) vs [%d, %d] "
            "(frame_size).",
            weight_height, weight_width, frame_size, frame_size * 3));
Y
Yu Yang 已提交
204
    if (ctx->HasInput("Bias")) {
G
guosheng 已提交
205 206 207
      auto bias_dims = ctx->GetInputDim("Bias");
      int bias_height = bias_dims[0];
      int bias_width = bias_dims[1];
208 209 210 211 212 213 214 215 216 217 218 219 220

      PADDLE_ENFORCE_EQ(
          bias_height, 1,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
      PADDLE_ENFORCE_EQ(
          bias_width, frame_size * 3,
          platform::errors::InvalidArgument(
              "The shape of Bias must be [1, frame_size * 3], but received "
              "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).",
              bias_height, bias_width, frame_size * 3));
G
guosheng 已提交
221 222 223 224
      auto bias_grad_name = framework::GradVarName("Bias");
      if (ctx->HasOutput(bias_grad_name))
        ctx->SetOutputDim(bias_grad_name, bias_dims);
    }
225
    auto input_grad_name = framework::GradVarName("Input");
G
guosheng 已提交
226 227
    if (ctx->HasOutput(input_grad_name))
      ctx->SetOutputDim(input_grad_name, input_dims);
228
    auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
G
guosheng 已提交
229 230
    if (ctx->HasOutput(hidden_prev_grad_name))
      ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
231
    auto weight_grad_name = framework::GradVarName("Weight");
G
guosheng 已提交
232 233 234
    if (ctx->HasOutput(weight_grad_name))
      ctx->SetOutputDim(weight_grad_name, weight_dims);
  }
235 236 237 238 239 240 241

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Hidden")),
                                   ctx.device_context());
  }
G
guosheng 已提交
242 243
};

H
hong 已提交
244 245
// Describes how the gru_unit_grad operator is assembled from a gru_unit
// operator. T is the operator-description type chosen at registration.
template <typename T>
class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Fills in the grad op: the forward inputs and the two intermediate forward
  // outputs become inputs, Hidden@GRAD is the incoming gradient, the forward
  // attributes are copied, and one gradient output is wired per forward input.
  void Apply(GradOpPtr<T> op) const override {
    // Passes a forward input through under the same slot name.
    const auto reuse_input = [this, &op](const char* slot) {
      op->SetInput(slot, this->Input(slot));
    };
    // Passes a forward output through as an input of the grad op.
    const auto reuse_output = [this, &op](const char* slot) {
      op->SetInput(slot, this->Output(slot));
    };
    // Wires the gradient of a forward input as an output of the grad op.
    const auto emit_input_grad = [this, &op](const char* slot) {
      op->SetOutput(framework::GradVarName(slot), this->InputGrad(slot));
    };

    op->SetType("gru_unit_grad");

    reuse_input("Input");
    reuse_input("HiddenPrev");
    reuse_input("Weight");
    reuse_input("Bias");

    reuse_output("Gate");
    reuse_output("ResetHiddenPrev");
    op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden"));

    op->SetAttrMap(this->Attrs());

    emit_input_grad("Input");
    emit_input_grad("HiddenPrev");
    emit_input_grad("Weight");
    emit_input_grad("Bias");
  }
};

272
// Declares the inferer that names "Bias" as a no-need-buffer variable; it is
// passed to the gru_unit_grad operator registration in this file.
// NOTE(review): presumably this tells the framework that the grad op does not
// read the data buffer of Bias -- confirm against the macro definition.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(GRUUnitGradOpNoNeedBufferVarInferer,
                                    "Bias");
274

G
guosheng 已提交
275 276 277 278
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
279

Y
Yang Yang 已提交
280
// Forward operator: its definition, its proto maker, and one grad-op maker
// instantiation per operator-description type (OpDesc and imperative OpBase).
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
                  ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,
                  ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);

// Backward operator, registered together with its no-need-buffer inferer.
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp,
                  ops::GRUUnitGradOpNoNeedBufferVarInferer);

// CPU kernels for the forward op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
    gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>);

// CPU kernels for the backward op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
    gru_unit_grad,
    ops::GRUUnitGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GRUUnitGradKernel<paddle::platform::CPUDeviceContext, double>);