/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"

#include <memory>
#include <string>

namespace paddle {
namespace operators {

class SequencePoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePool");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequencePool");

    if (!ctx->IsRuntime()) {
      // At compile time, check that Input(X) carries at least one LoD level.
      auto in_lod_level = ctx->GetLoDLevel("X");
      PADDLE_ENFORCE_GT(
          in_lod_level,
          0,
          platform::errors::InvalidArgument("The LoD level of Input(X) must "
                                            "be greater than 0, but received "
                                            "lod level %u.",
                                            in_lod_level));
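      // Pooling consumes one sequence level, so Out has one fewer LoD level.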
      ctx->SetLoDLevel("Out", in_lod_level - 1);
    }

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
      OP_INOUT_CHECK(
          ctx->HasOutput("MaxIndex"), "Output", "MaxIndex", "SequencePool");
      ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
    }
  }
};

class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(phi::DenseTensor) The variable-length input of SequencePoolOp");
    AddOutput(
        "Out",
        "(phi::DenseTensor) The output of SequencePoolOp does not contain LoD "
        "information.");
    AddOutput("MaxIndex",
              "(phi::DenseTensor<int>) This tensor is used for the sequence "
              "max-pooling "
D
dangqingqing 已提交
65
              "to record the max indexes.")
        .AsIntermediate();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "pooltype",
        "(string, default 'AVERAGE') The pooling type of SequencePoolOp.")
        .SetDefault("AVERAGE")
        .InEnum({"AVERAGE", "SUM", "SQRT", "LAST", "FIRST", "MAX"});
    AddAttr<float>("pad_value",
                   "(float, default 0.0) The value to pad for empty sequence.")
        .SetDefault(0.0);
    AddComment(R"DOC(
Sequence Pool Operator.

The SequencePoolOp pools the features of all time-steps of each instance.
It supports six pooling types:
1. AVERAGE: $$Out[i] = \frac{\sum_j X_{ij}}{len(X_i)}$$
2. SUM:     $$Out[i] = \sum_j X_{ij}$$
3. SQRT:    $$Out[i] = \frac{\sum_j X_{ij}}{\sqrt{len(X_i)}}$$
4. LAST:    Out[i] = last instance in the i-th sequence X[i]
5. FIRST:   Out[i] = first instance in the i-th sequence X[i]
6. MAX:     $$Out[i] = max(X_i)$$

For an empty i-th sequence, Out[i] = attr(pad_value).

The following example illustrates how this works.
For a mini-batch of 3 variable-length sentences
containing 2, 3, and 2 time-steps:

Assume X is a [7,M,N] phi::DenseTensor with X->lod()[0] = [0, 2, 5, 7], where 7=2+3+2.
For simplicity, further assume M=1 and N=1,
so that X = [[1, 3], [2, 4, 6], [5, 1]].

Thus, Out is a [3,1,1] phi::DenseTensor without LoD information.
For each pooltype, the value of Out is as follows:

- AVERAGE: [2, 4, 3], where 2=(1+3)/2, 4=(2+4+6)/3, 3=(5+1)/2
- SUM: [4, 12, 6], where 4=1+3, 12=2+4+6, 6=5+1
- SQRT: [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
           6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
- MAX: [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
- LAST: [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
- FIRST: [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)

    )DOC");
  }
};
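
// A minimal illustrative sketch (not part of the operator) of how the LoD
// offsets in the doc above drive SUM pooling; `lod`, `x`, and `out` are
// hypothetical names:
//
//   std::vector<size_t> lod = {0, 2, 5, 7};        // 3 sequences
//   std::vector<float> x = {1, 3, 2, 4, 6, 5, 1};  // flattened time-steps
//   std::vector<float> out(lod.size() - 1, 0.f);
//   for (size_t i = 0; i + 1 < lod.size(); ++i) {
//     for (size_t j = lod[i]; j < lod[i + 1]; ++j) out[i] += x[j];
//   }
//   // out == {4, 12, 6}, matching the SUM example above.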

class SequencePoolGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "SequencePoolGrad");
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePoolGrad");

    auto og_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    auto x_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(og_dims.size(),
                      x_dims.size(),
                      platform::errors::InvalidArgument(
                          "The rank of Input(Out@GRAD) must equal the rank "
                          "of Input(X). But received: Out@GRAD rank %u, "
                          "shape [%s].",
                          og_dims.size(),
                          og_dims));
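    // Skip dim 0: Out@GRAD is per-sequence while X is per-time-step, so only
    // the trailing dimensions must match.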
    for (int64_t i = 1; i < og_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(
          og_dims[i],
          x_dims[i],
          platform::errors::InvalidArgument(
              "Dimension mismatch between Input(Out@GRAD) and Input(X). "
              "Received Input(Out@GRAD): rank %u, shape [%s]; "
              "Input(X): rank %u, shape [%s].",
              og_dims.size(),
              og_dims,
              x_dims.size(),
              x_dims));
    }

    ctx->ShareDim("X", /*->*/ framework::GradVarName("X"));
    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
  }

 protected:
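  // Infer the kernel's data type from Out@GRAD rather than X, whose buffer
  // may never be allocated (see the no-need-buffer inferer below).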
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
                              ctx, framework::GradVarName("Out")),
                          ctx.GetPlace());
  }
};

template <typename T>
class SequencePoolGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op_desc_ptr) const override {
    op_desc_ptr->SetType("sequence_pool_grad");
    op_desc_ptr->SetInput("X", this->Input("X"));
    if (PADDLE_GET_CONST(std::string, this->GetAttr("pooltype")) == "MAX") {
      op_desc_ptr->SetInput("MaxIndex", this->Output("MaxIndex"));
    }
    op_desc_ptr->SetInput(framework::GradVarName("Out"),
                          this->OutputGrad("Out"));
    op_desc_ptr->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op_desc_ptr->SetAttrMap(this->Attrs());
  }
};

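// The grad op uses only the dims and LoD of X, not its data, so X's buffer
// need not be kept alive for the backward pass.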
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SequencePoolGradOpNoNeedBufferVarsInferer,
                                    "X");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
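// Register the forward op together with grad op makers for both static graph
// (OpDesc) and dygraph (OpBase) modes.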
REGISTER_OPERATOR(sequence_pool,
                  ops::SequencePoolOp,
                  ops::SequencePoolOpMaker,
                  ops::SequencePoolGradOpMaker<paddle::framework::OpDesc>,
                  ops::SequencePoolGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sequence_pool_grad,
                  ops::SequencePoolGradOp,
                  ops::SequencePoolGradOpNoNeedBufferVarsInferer);

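// CPU kernel registration for the backward op, covering float and double.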
PD_REGISTER_STRUCT_KERNEL(sequence_pool_grad,
                          CPU,
                          ALL_LAYOUT,
                          ops::SequencePoolGradKernel,
                          float,
                          double) {}