/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class SequencePoolOp : public framework::OperatorWithKernel {
21 22 23
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

24
  void InferShape(framework::InferShapeContext* ctx) const override {
25 26
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePool");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequencePool");
27 28 29

    if (!ctx->IsRuntime()) {
      // Check the lod_level for compile-time.
30
      auto in_lod_level = ctx->GetLoDLevel("X");
31
      PADDLE_ENFORCE_GT(
32 33
          in_lod_level,
          0,
34 35 36 37
          platform::errors::InvalidArgument("The LoD level of Input(X) should "
                                            "be larger than 0, but received: "
                                            "lod level %u.",
                                            in_lod_level));
38
      ctx->SetLoDLevel("Out", in_lod_level - 1);
39 40
    }

Q
Qiao Longfei 已提交
41
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
42
    if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
43 44
      OP_INOUT_CHECK(
          ctx->HasOutput("MaxIndex"), "Output", "MaxIndex", "SequencePool");
45 46
      ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
    }
47 48 49
  }
};

class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
51
 public:
Y
Yu Yang 已提交
52
  void Make() override {
柠檬味~ 已提交
53 54 55 56 57 58
    AddInput("X",
             "(phi::DenseTensor) The variable-length input of SequencePoolOp");
    AddOutput(
        "Out",
        "(phi::DenseTensor) The output of SequencePoolOp does not contain LoD "
        "information.");
59
    AddOutput("MaxIndex",
柠檬味~ 已提交
60 61
              "(phi::DenseTensor<int>) This tensor is used for the sequence "
              "max-pooling "
D
dangqingqing 已提交
62
              "to record the max indexes.")
63
        .AsIntermediate();
64 65 66
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
67 68
        .SetDefault(false)
        .AsExtra();
D
dzhwinter 已提交
69 70
    AddAttr<std::string>(
        "pooltype",
L
Luo Tao 已提交
71
        "(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.")
72 73
        .SetDefault("AVERAGE")
        .InEnum({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"});
74 75 76
    AddAttr<float>("pad_value",
                   "(float, default 0.0) The value to pad for empty sequence.")
        .SetDefault(0.0);
77
    AddComment(R"DOC(
78
Sequence Pool Operator.
79

80 81
The SequencePoolOp pools features of all time-steps of each instance.
It supports six pooling types:
82 83 84
1. AVERAGE: $$Out[i] = \frac{\sum_i X_i}{N}$$
2. SUM:     $$Out[i] = \sum_jX_{ij}$$
3. SQRT:    $$Out[i] = \frac{\sum_jX_{ij}}{\sqrt{len(X_i)}}$$
85 86
4. LAST:    Out[i] = last instance in i-th sequence X[i]
5. FIRST:   Out[i] = first instance in i-th sequence X[i]
87
6. MAX:     $$Out[i] = max(X_i)$$
88

89 90
and for the empty sequence Out[i] = attr(pad_value).

91 92 93
The following example explains how this works:
For a mini-batch of 3 variable-length sentences,
containing 2, 3, and 2 time-steps:
Q
Qiao Longfei 已提交
94

柠檬味~ 已提交
95
Assume X is a [7,M,N] phi::DenseTensor, and X->lod()[0] = [0, 2, 5, 7], 7=2+3+2.
96 97
Besides, for the sake of simplicity, we assume M=1 and N=1,
and the value of X = [[1, 3], [2, 4, 6], [5, 1]].
L
Luo Tao 已提交
98

柠檬味~ 已提交
99
Thus, Out is a [3,1,1] phi::DenseTensor without LoD information.
100
And for different pooltype, the value of Out is as follows:
L
Luo Tao 已提交
101

102 103 104
- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
- SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
L
Luo Tao 已提交
105
           6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
106 107 108 109
- MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
- LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
- FIRST: [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)

110 111 112 113
    )DOC");
  }
};

class SequencePoolGradOp : public framework::OperatorWithKernel {
115 116 117
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

118
  void InferShape(framework::InferShapeContext* ctx) const override {
119 120 121 122
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "SequencePoolGrad");
123 124
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePoolGrad");

Q
Qiao Longfei 已提交
125 126
    auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    auto x_dims = ctx->GetInputDim("X");
127 128
    PADDLE_ENFORCE_EQ(og_dims.size(),
                      x_dims.size(),
129 130 131
                      platform::errors::InvalidArgument(
                          "The rank of output grad must equal to Input(X). But "
                          "received: input rank %u, input shape [%s].",
132 133
                          og_dims.size(),
                          og_dims));
134
    for (int64_t i = 1; i < og_dims.size(); ++i) {
135
      PADDLE_ENFORCE_EQ(
136 137
          og_dims[i],
          x_dims[i],
138 139 140 141 142
          platform::errors::InvalidArgument(
              "The dimension mismatch between Input(OUT@GRAD) and "
              "Input(X). Received Input(OUT@GRAD): input rank %u, "
              "input shape [%s]; received Input(X): input rank %u, "
              "input shape [%s].",
143 144 145 146
              og_dims.size(),
              og_dims,
              x_dims.size(),
              x_dims));
147
    }
148 149 150

    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
151
  }
152 153

 protected:
154
  phi::KernelKey GetExpectedKernelType(
155
      const framework::ExecutionContext& ctx) const override {
156 157 158
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Out")),
                          ctx.GetPlace());
159
  }
160 161
};

// Builds the sequence_pool_grad op description from the forward op, for both
// static graph (OpDesc) and dygraph (OpBase) modes.
template <typename T>
class SequencePoolGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op_desc_ptr) const override {
    op_desc_ptr->SetType("sequence_pool_grad");
    op_desc_ptr->SetInput("X", this->Input("X"));
    // MAX pooling needs the recorded argmax indexes to scatter the gradient
    // back to the winning elements.
    if (PADDLE_GET_CONST(std::string, this->GetAttr("pooltype")) == "MAX") {
      op_desc_ptr->SetInput("MaxIndex", this->Output("MaxIndex"));
    }
    op_desc_ptr->SetInput(framework::GradVarName("Out"),
                          this->OutputGrad("Out"));
    op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op_desc_ptr->SetAttrMap(this->Attrs());
  }
};

// The grad kernel only reads X's shape/LoD (not its contents), so X's buffer
// is not needed and can be freed early.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register the forward op with grad makers for both static and dygraph modes,
// and the grad op with its no-need-buffer inferer.
REGISTER_OPERATOR(sequence_pool,
                  ops::SequencePoolOp,
                  ops::SequencePoolOpMaker,
                  ops::SequencePoolGradOpMaker<paddle::framework::OpDesc>,
                  ops::SequencePoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pool_grad,
                  ops::SequencePoolGradOp,
                  ops::SequencePoolGradOpNoNeedBufferVarsInferer);